Update.
authorFrançois Fleuret <francois@fleuret.org>
Sun, 1 Sep 2024 11:59:51 +0000 (13:59 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sun, 1 Sep 2024 11:59:51 +0000 (13:59 +0200)
grids.py
main.py

index 9441811..9372922 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -393,6 +393,7 @@ class Grids(problem.Problem):
         if delta:
             u = (B != f_B).long()
             img_delta = self.add_frame(self.grid2img(u), frame[None, :], thickness=1)
+            img_delta = img_delta.min(dim=1, keepdim=True).values.expand_as(img_delta)
 
         img_A = self.add_frame(self.grid2img(A), frame[None, :], thickness=1)
         img_f_A = self.add_frame(self.grid2img(f_A), frame[None, :], thickness=1)
diff --git a/main.py b/main.py
index 2731f25..e533802 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -1311,7 +1311,7 @@ for i in range(args.nb_models):
 
 
 def c_quiz_criterion_one_good_one_bad(probas):
-    return (probas.max(dim=1).values >= 0.8) & (probas.min(dim=1).values <= 0.2)
+    return (probas.max(dim=1).values >= 0.75) & (probas.min(dim=1).values <= 0.25)
 
 
 def c_quiz_criterion_diff(probas):
@@ -1323,8 +1323,8 @@ def c_quiz_criterion_diff2(probas):
     return (v[:, -2] - v[:, 0]) >= 0.5
 
 
-def c_quiz_criterion_two_certains(probas):
-    return ((probas >= 0.99).long().sum(dim=1) >= 2) & (probas.min(dim=1).values <= 0.5)
+def c_quiz_criterion_two_good(probas):
+    return ((probas >= 0.5).long().sum(dim=1) >= 2) & (probas.min(dim=1).values <= 0.2)
 
 
 def c_quiz_criterion_some(probas):
@@ -1336,10 +1336,10 @@ def c_quiz_criterion_some(probas):
 def generate_ae_c_quizzes(models, local_device=main_device):
     criteria = [
         c_quiz_criterion_one_good_one_bad,
-        c_quiz_criterion_diff,
+        c_quiz_criterion_diff,
         # c_quiz_criterion_diff2,
-        c_quiz_criterion_two_certains,
-        c_quiz_criterion_some,
+        # c_quiz_criterion_two_good,
+        c_quiz_criterion_some,
     ]
 
     for m in models:
@@ -1357,8 +1357,11 @@ def generate_ae_c_quizzes(models, local_device=main_device):
 
     duration_max = 4 * 3600
 
-    wanted_nb = 16  # 0000
-    nb_to_save = 16
+    # wanted_nb = 240
+    # nb_to_save = 240
+
+    wanted_nb = args.nb_train_samples // 4
+    nb_to_save = 128
 
     with torch.autograd.no_grad():
         records = [[] for _ in criteria]
@@ -1369,11 +1372,10 @@ def generate_ae_c_quizzes(models, local_device=main_device):
             time.perf_counter() < start_time + duration_max
             and min([bag_len(bag) for bag in records]) < wanted_nb
         ):
-            bl = [bag_len(bag) for bag in records]
-            log_string(f"bag_len {bl}")
-
             model = models[torch.randint(len(models), (1,)).item()]
             result = ae_generate(model, template, mask_generate)
+            bl = [bag_len(bag) for bag in records]
+            log_string(f"bag_len {bl} model {model.id}")
 
             to_keep = quiz_machine.problem.trivial(result) == False
             result = result[to_keep]
@@ -1417,6 +1419,7 @@ def generate_ae_c_quizzes(models, local_device=main_device):
                 # correct_parts=correct_parts,
                 comments=comments,
                 delta=True,
+                nrow=12,
             )
 
             log_string(f"wrote {filename}")