From bbd49211639b15c8aa080a5b33bdd5be98339444 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 14 Sep 2024 22:25:11 +0200 Subject: [PATCH] Update. --- grids.py | 21 ++++++++++++++++----- main.py | 48 ++++++++++++++++++++++++++---------------------- 2 files changed, 42 insertions(+), 27 deletions(-) diff --git a/grids.py b/grids.py index 7754c43..882c113 100755 --- a/grids.py +++ b/grids.py @@ -440,9 +440,16 @@ class Grids(problem.Problem): ) if delta: + u = (A != f_A).long() + img_delta_A = self.add_frame(self.grid2img(u), frame[None, :], thickness=1) + img_delta_A = img_delta_A.min(dim=1, keepdim=True).values.expand_as( + img_delta_A + ) u = (B != f_B).long() - img_delta = self.add_frame(self.grid2img(u), frame[None, :], thickness=1) - img_delta = img_delta.min(dim=1, keepdim=True).values.expand_as(img_delta) + img_delta_B = self.add_frame(self.grid2img(u), frame[None, :], thickness=1) + img_delta_B = img_delta_B.min(dim=1, keepdim=True).values.expand_as( + img_delta_B + ) img_A = self.add_frame(self.grid2img(A), frame[None, :], thickness=1) img_f_A = self.add_frame(self.grid2img(f_A), frame[None, :], thickness=1) @@ -484,9 +491,13 @@ class Grids(problem.Problem): img_f_B = self.add_frame(img_f_B, white[None, :], thickness=2) if delta: - img_delta = self.add_frame(img_delta, colors[:, 0], thickness=8) - img_delta = self.add_frame(img_delta, white[None, :], thickness=2) - img = torch.cat([img_A, img_f_A, img_B, img_f_B, img_delta], dim=3) + img_delta_A = self.add_frame(img_delta_A, colors[:, 0], thickness=8) + img_delta_A = self.add_frame(img_delta_A, white[None, :], thickness=2) + img_delta_B = self.add_frame(img_delta_B, colors[:, 0], thickness=8) + img_delta_B = self.add_frame(img_delta_B, white[None, :], thickness=2) + img = torch.cat( + [img_A, img_f_A, img_delta_A, img_B, img_f_B, img_delta_B], dim=3 + ) else: img = torch.cat([img_A, img_f_A, img_B, img_f_B], dim=3) diff --git a/main.py b/main.py index b751374..01ce963 100755 --- a/main.py +++ b/main.py @@ -63,7 +63,7 @@ parser.add_argument("--nb_train_alien_samples", type=int, default=0) parser.add_argument("--nb_test_alien_samples", type=int, default=0) -parser.add_argument("--nb_c_quizzes", type=int, default=2500) +parser.add_argument("--nb_c_quizzes", type=int, default=10000) parser.add_argument("--nb_new_c_quizzes_for_train", type=int, default=None) @@ -957,7 +957,7 @@ for i in range(args.nb_models): dropout=args.dropout, ).to(main_device) - model = torch.compile(model) + # model = torch.compile(model) model.id = i model.test_accuracy = 0.0 @@ -1347,7 +1347,9 @@ for n_epoch in range(current_epoch, args.nb_epochs): # -------------------------------------------------------------------- - if min([float(m.test_accuracy) for m in models]) > args.accuracy_to_make_c_quizzes: + lowest_test_accuracy = min([float(m.test_accuracy) for m in models]) + + if lowest_test_accuracy >= args.accuracy_to_make_c_quizzes: if c_quizzes is None: save_models(models, "naive") @@ -1355,14 +1357,10 @@ for n_epoch in range(current_epoch, args.nb_epochs): nb_gpus = len(gpus) nb_c_quizzes_to_generate = (args.nb_c_quizzes + nb_gpus - 1) // nb_gpus - args = [(models, nb_c_quizzes_to_generate, gpu) for gpu in gpus] - - # Ugly hack: Only one thread during the first epoch so that - # compilation of the model does not explode - if n_epoch == 0: - args = args[:1] - - c_quizzes, agreements = multithread_execution(generate_ae_c_quizzes, args) + c_quizzes, agreements = multithread_execution( + generate_ae_c_quizzes, + [(models, nb_c_quizzes_to_generate, gpu) for gpu in gpus], + ) save_c_quizzes_with_scores( models, @@ -1378,6 +1376,16 @@ for n_epoch in range(current_epoch, args.nb_epochs): solvable_only=True, ) + u = c_quizzes.reshape(c_quizzes.size(0), 4, -1)[:, 1:] + i = (u[:, 2] != u[:, 3]).long().sum(dim=1).sort(descending=True).indices + + save_c_quizzes_with_scores( + models, + c_quizzes[i][:256], + f"culture_c_quiz_{n_epoch:04d}_solvable_high_delta.png", + solvable_only=True, + ) + log_string(f"generated_c_quizzes {c_quizzes.size()=}") for model in models: @@ -1395,17 +1403,13 @@ for n_epoch in range(current_epoch, args.nb_epochs): # None if c_quizzes is None else c_quizzes[agreements[:, model.id]], - args = [ - (model, quiz_machine, n_epoch, c_quizzes, gpu) - for model, gpu in zip(weakest_models, gpus) - ] - - # Ugly hack: Only one thread during the first epoch so that - # compilation of the model does not explode - if n_epoch == 0: - args = args[:1] - - multithread_execution(one_ae_epoch, args) + multithread_execution( + one_ae_epoch, + [ + (model, quiz_machine, n_epoch, c_quizzes, gpu) + for model, gpu in zip(weakest_models, gpus) + ], + ) # -------------------------------------------------------------------- -- 2.39.5