From 46847c131e994d426264718defbd1377051bbf44 Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Fri, 20 Sep 2024 08:38:30 +0200
Subject: [PATCH] Update.

---
 grids.py | 36 +++++++++++++++++++++---------------
 main.py  | 47 +++++++++++++++++++++++++++++------------------
 2 files changed, 50 insertions(+), 33 deletions(-)

diff --git a/grids.py b/grids.py
index 0613043..197eb5a 100755
--- a/grids.py
+++ b/grids.py
@@ -144,9 +144,9 @@ class Grids(problem.Problem):
     # background_gray=240
     # dots = False
 
-    # grid_gray = 200
+    # grid_gray = 192
     # thickness = 0
-    # background_gray = 240
+    # background_gray = 255
     # dots = True
 
     named_colors = [
@@ -287,7 +287,8 @@ class Grids(problem.Problem):
     ######################################################################
 
     def vocabulary_size(self):
-        return self.nb_colors
+        warnings.warn("hack +4 to keep the vocabulary size unchanged", RuntimeWarning)
+        return self.nb_colors + 4
 
     def grid2img(self, x, scale=15, grids=True):
         m = torch.logical_and(x >= 0, x < self.nb_colors).long()
@@ -313,13 +314,12 @@ class Grids(problem.Problem):
                 :,
                 :,
                 :,
-                scale // 2 - 2 : scale // 2 + 1,
+                scale // 2 - 1 : scale // 2 + 2,
                 :,
-                scale // 2 - 2 : scale // 2 + 1,
+                scale // 2 - 1 : scale // 2 + 2,
             ]
-            z[...] = (z == self.background_gray) * self.grid_gray + (
-                z != self.background_gray
-            ) * z
+            zz = (z == self.background_gray).min(dim=1, keepdim=True).values
+            z[...] = zz * self.grid_gray + (zz == False) * z
 
         for n in range(m.size(0)):
             for i in range(m.size(1)):
@@ -367,7 +367,7 @@ class Grids(problem.Problem):
         comment_height=48,
         nrow=4,
         grids=True,
-        margin=8,
+        margin=12,
         delta=False,
     ):
         quizzes = quizzes.to("cpu")
@@ -446,10 +446,12 @@ class Grids(problem.Problem):
             + (1 - predicted_parts[:, :, None]) * white[None, None, :]
         )
 
-        img_A = self.add_frame(img_A, colors[:, 0], thickness=8)
-        img_f_A = self.add_frame(img_f_A, colors[:, 1], thickness=8)
-        img_B = self.add_frame(img_B, colors[:, 2], thickness=8)
-        img_f_B = self.add_frame(img_f_B, colors[:, 3], thickness=8)
+        separation = 6
+
+        img_A = self.add_frame(img_A, colors[:, 0], thickness=separation)
+        img_f_A = self.add_frame(img_f_A, colors[:, 1], thickness=separation)
+        img_B = self.add_frame(img_B, colors[:, 2], thickness=separation)
+        img_f_B = self.add_frame(img_f_B, colors[:, 3], thickness=separation)
 
         img_A = self.add_frame(img_A, white[None, :], thickness=2)
         img_f_A = self.add_frame(img_f_A, white[None, :], thickness=2)
@@ -457,9 +459,13 @@ class Grids(problem.Problem):
         img_B = self.add_frame(img_B, white[None, :], thickness=2)
         img_f_B = self.add_frame(img_f_B, white[None, :], thickness=2)
         if delta:
-            img_delta_A = self.add_frame(img_delta_A, colors[:, 0], thickness=8)
+            img_delta_A = self.add_frame(
+                img_delta_A, colors[:, 0], thickness=separation
+            )
             img_delta_A = self.add_frame(img_delta_A, white[None, :], thickness=2)
-            img_delta_B = self.add_frame(img_delta_B, colors[:, 0], thickness=8)
+            img_delta_B = self.add_frame(
+                img_delta_B, colors[:, 0], thickness=separation
+            )
             img_delta_B = self.add_frame(img_delta_B, white[None, :], thickness=2)
             img = torch.cat(
                 [img_A, img_f_A, img_delta_A, img_B, img_f_B, img_delta_B], dim=3
diff --git a/main.py b/main.py
index 52505de..06dfc5e 100755
--- a/main.py
+++ b/main.py
@@ -272,9 +272,7 @@ def generate_quiz_set(nb_samples, c_quizzes, c_quiz_multiplier=1):
         c_quizzes = c_quizzes[i]
 
     w_quizzes = problem.generate_w_quizzes(nb_samples - c_quizzes.size(0))
-    w_quizzes = w_quizzes.view(w_quizzes.size(0), 4, -1)[:, :, 1:].reshape(
-        w_quizzes.size(0), -1
-    )
+    quizzes = torch.cat([w_quizzes, c_quizzes], dim=0)
 
     nb_w_quizzes = w_quizzes.size(0)
     nb_c_quizzes = c_quizzes.size(0)
@@ -383,7 +381,9 @@ def ae_predict(model, imt_set, local_device=main_device, desc="predict"):
     return torch.cat(record)
 
 
-def predict_full(model, input, with_perturbations=False, local_device=main_device):
+def predict_full(
+    model, input, with_noise=False, with_hints=False, local_device=main_device
+):
     input = input[:, None, :].expand(-1, 4, -1).reshape(-1, input.size(1))
     nb = input.size(0)
     masks = input.new_zeros(input.size())
@@ -393,8 +393,10 @@ def predict_full(model, input, with_perturbations=False, local_device=main_devic
     input = (1 - masks) * targets
     imt_set = torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
 
-    if with_perturbations:
+    if with_hints:
         imt_set = add_hints_imt(imt_set)
+
+    if with_noise:
         imt_set = add_noise_imt(imt_set)
 
     result = ae_predict(model, imt_set, local_device=local_device, desc=None)
@@ -563,7 +565,13 @@ def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device):
     problem.save_quizzes_as_image(
         args.result_dir, f"test_{n_epoch}_{model.id}.png", quizzes=quizzes
     )
-    result = predict_full(model=model, input=quizzes, local_device=local_device)
+    result = predict_full(
+        model=model,
+        input=quizzes,
+        with_noise=True,
+        with_hints=True,
+        local_device=local_device,
+    )
     problem.save_quizzes_as_image(
         args.result_dir, f"test_{n_epoch}_{model.id}_predict_full.png", quizzes=result
     )
@@ -630,27 +638,30 @@ def evaluate_quizzes(quizzes, models, local_device):
         result = predict_full(
             model=model,
             input=quizzes,
-            with_perturbations=True,
+            with_noise=False,
+            with_hints=True,
             local_device=local_device,
         )
 
-        nb_correct += (max_nb_mistakes_on_one_grid(quizzes, result) == 0).long()
+        nb_mistakes = max_nb_mistakes_on_one_grid(quizzes, result)
+        nb_correct += (nb_mistakes == 0).long()
 
-        result = predict_full(
-            model=model,
-            input=quizzes,
-            with_perturbations=False,
-            local_device=local_device,
-        )
+        # result = predict_full(
+        #     model=model,
+        #     input=quizzes,
+        #     with_noise=False,
+        #     with_hints=False,
+        #     local_device=local_device,
+        # )
 
-        nb_wrong += (
-            max_nb_mistakes_on_one_grid(quizzes, result) >= args.nb_mistakes_to_be_wrong
-        ).long()
+        nb_wrong += (nb_mistakes >= args.nb_mistakes_to_be_wrong).long()
 
     to_keep = (nb_correct >= args.nb_have_to_be_correct) & (
         nb_wrong >= args.nb_have_to_be_wrong
     )
 
+    # print("\n\n", nb_correct, nb_wrong)
+
     return to_keep, nb_correct, nb_wrong
 
 
@@ -659,7 +670,7 @@ def evaluate_quizzes(quizzes, models, local_device):
 
 def identity_quizzes(quizzes):
     quizzes = quizzes.reshape(quizzes.size(0), 4, -1)
-    return (quizzes[:, 0] == quizzes[:, 1]).min(dim=1).values & (
+    return (quizzes[:, 0] == quizzes[:, 1]).min(dim=1).values | (
        quizzes[:, 2] == quizzes[:, 3]
     ).min(dim=1).values
 
-- 
2.39.5
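
Note on the `identity_quizzes` change above: switching `&` to `|` widens the filter, so a quiz is flagged when either of its two grid pairs is left unchanged, rather than only when both are. Below is a minimal standalone sketch of that behavior. It reuses the patched function verbatim; the 3-token toy "grids" A and B are illustrative assumptions, not data from the repository.

    import torch

    # Patched version from the diff above: a quiz is four flattened grids
    # [A, f_A, B, f_B]; it is an "identity" quiz when A == f_A or B == f_B.
    def identity_quizzes(quizzes):
        quizzes = quizzes.reshape(quizzes.size(0), 4, -1)
        return (quizzes[:, 0] == quizzes[:, 1]).min(dim=1).values | (
            quizzes[:, 2] == quizzes[:, 3]
        ).min(dim=1).values

    # Hypothetical toy grids, 3 tokens each, for illustration only.
    A = torch.tensor([1, 2, 3])
    B = torch.tensor([4, 5, 6])
    quizzes = torch.stack(
        [
            torch.cat([A, A, B, B]),          # both pairs unchanged
            torch.cat([A, A, B, B + 1]),      # only the (A, f_A) pair unchanged
            torch.cat([A, A + 1, B, B + 1]),  # both pairs changed
        ]
    )
    print(identity_quizzes(quizzes))  # tensor([ True,  True, False])

Under the previous `&`, only the first row would be flagged; under `|`, the first two are.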