From a46d2b46c88683dad8aed4ef048da46ce747d306 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 19 Sep 2024 23:14:15 +0200 Subject: [PATCH] Update. --- grids.py | 18 +++++++++--------- main.py | 8 ++++---- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/grids.py b/grids.py index fb31c7d..0613043 100755 --- a/grids.py +++ b/grids.py @@ -134,20 +134,20 @@ def grow_islands(nb, height, width, nb_seeds, nb_iterations): class Grids(problem.Problem): - # grid_gray = 64 - # thickness = 1 - # background_gray = 255 - # dots = False + grid_gray = 64 + thickness = 1 + background_gray = 255 + dots = False # grid_gray=240 # thickness=1 # background_gray=240 # dots = False - grid_gray = 200 - thickness = 0 - background_gray = 240 - dots = True + # grid_gray = 200 + # thickness = 0 + # background_gray = 240 + # dots = True named_colors = [ ("white", [background_gray, background_gray, background_gray]), @@ -1835,7 +1835,7 @@ if __name__ == "__main__": print(t.__name__) w_quizzes = grids.generate_w_quizzes_(nb, tasks=[t]) - w_quizzes[:5] = torch.randint(grids.vocabulary_size(), w_quizzes[:5].size()) + # w_quizzes[:5] = torch.randint(grids.vocabulary_size(), w_quizzes[:5].size()) grids.save_quizzes_as_image( "/tmp", diff --git a/main.py b/main.py index 5493b7d..52505de 100755 --- a/main.py +++ b/main.py @@ -250,7 +250,7 @@ assert args.nb_test_samples % args.batch_size == 0 ###################################################################### -def quiz_set(nb_samples, c_quizzes, c_quiz_multiplier=1): +def generate_quiz_set(nb_samples, c_quizzes, c_quiz_multiplier=1): if c_quizzes is None: quizzes = problem.generate_w_quizzes(nb_samples) nb_w_quizzes = quizzes.size(0) @@ -486,7 +486,7 @@ def ae_generate(model, nb, local_device=main_device): def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True): - quizzes = quiz_set( + quizzes = generate_quiz_set( args.nb_train_samples if train else args.nb_test_samples, c_quizzes, args.c_quiz_multiplier, @@ -559,7 +559,7 @@ def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device): # Save some original world quizzes and the full prediction (the four grids) - quizzes = quiz_set(25, c_quizzes, args.c_quiz_multiplier).to(local_device) + quizzes = generate_quiz_set(25, c_quizzes, args.c_quiz_multiplier).to(local_device) problem.save_quizzes_as_image( args.result_dir, f"test_{n_epoch}_{model.id}.png", quizzes=quizzes ) @@ -570,7 +570,7 @@ def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device): # Save some images of the prediction results - quizzes = quiz_set(args.nb_test_samples, c_quizzes, args.c_quiz_multiplier) + quizzes = generate_quiz_set(args.nb_test_samples, c_quizzes, args.c_quiz_multiplier) imt_set = samples_for_prediction_imt(quizzes.to(local_device)) result = ae_predict(model, imt_set, local_device=local_device).to("cpu") masks = imt_set[:, 1].to("cpu") -- 2.39.5