From 2c68a24dccf22c153c4ece8584c1725c4be72720 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 1 Sep 2024 13:59:51 +0200 Subject: [PATCH] Update. --- grids.py | 1 + main.py | 25 ++++++++++++++----------- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/grids.py b/grids.py index 9441811..9372922 100755 --- a/grids.py +++ b/grids.py @@ -393,6 +393,7 @@ class Grids(problem.Problem): if delta: u = (B != f_B).long() img_delta = self.add_frame(self.grid2img(u), frame[None, :], thickness=1) + img_delta = img_delta.min(dim=1, keepdim=True).values.expand_as(img_delta) img_A = self.add_frame(self.grid2img(A), frame[None, :], thickness=1) img_f_A = self.add_frame(self.grid2img(f_A), frame[None, :], thickness=1) diff --git a/main.py b/main.py index 2731f25..e533802 100755 --- a/main.py +++ b/main.py @@ -1311,7 +1311,7 @@ for i in range(args.nb_models): def c_quiz_criterion_one_good_one_bad(probas): - return (probas.max(dim=1).values >= 0.8) & (probas.min(dim=1).values <= 0.2) + return (probas.max(dim=1).values >= 0.75) & (probas.min(dim=1).values <= 0.25) def c_quiz_criterion_diff(probas): @@ -1323,8 +1323,8 @@ def c_quiz_criterion_diff2(probas): return (v[:, -2] - v[:, 0]) >= 0.5 -def c_quiz_criterion_two_certains(probas): - return ((probas >= 0.99).long().sum(dim=1) >= 2) & (probas.min(dim=1).values <= 0.5) +def c_quiz_criterion_two_good(probas): + return ((probas >= 0.5).long().sum(dim=1) >= 2) & (probas.min(dim=1).values <= 0.2) def c_quiz_criterion_some(probas): @@ -1336,10 +1336,10 @@ def c_quiz_criterion_some(probas): def generate_ae_c_quizzes(models, local_device=main_device): criteria = [ c_quiz_criterion_one_good_one_bad, - c_quiz_criterion_diff, + # c_quiz_criterion_diff, # c_quiz_criterion_diff2, - c_quiz_criterion_two_certains, - c_quiz_criterion_some, + # c_quiz_criterion_two_good, + # c_quiz_criterion_some, ] for m in models: @@ -1357,8 +1357,11 @@ def generate_ae_c_quizzes(models, local_device=main_device): duration_max = 4 * 3600 - wanted_nb = 16 # 0000 - nb_to_save = 16 + # wanted_nb = 240 + # nb_to_save = 240 + + wanted_nb = args.nb_train_samples // 4 + nb_to_save = 128 with torch.autograd.no_grad(): records = [[] for _ in criteria] @@ -1369,11 +1372,10 @@ def generate_ae_c_quizzes(models, local_device=main_device): time.perf_counter() < start_time + duration_max and min([bag_len(bag) for bag in records]) < wanted_nb ): - bl = [bag_len(bag) for bag in records] - log_string(f"bag_len {bl}") - model = models[torch.randint(len(models), (1,)).item()] result = ae_generate(model, template, mask_generate) + bl = [bag_len(bag) for bag in records] + log_string(f"bag_len {bl} model {model.id}") to_keep = quiz_machine.problem.trivial(result) == False result = result[to_keep] @@ -1417,6 +1419,7 @@ def generate_ae_c_quizzes(models, local_device=main_device): # correct_parts=correct_parts, comments=comments, delta=True, + nrow=12, ) log_string(f"wrote {filename}") -- 2.39.5