Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 14 Sep 2024 20:25:11 +0000 (22:25 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 14 Sep 2024 20:25:11 +0000 (22:25 +0200)
grids.py
main.py

index 7754c43..882c113 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -440,9 +440,16 @@ class Grids(problem.Problem):
         )
 
         if delta:
+            u = (A != f_A).long()
+            img_delta_A = self.add_frame(self.grid2img(u), frame[None, :], thickness=1)
+            img_delta_A = img_delta_A.min(dim=1, keepdim=True).values.expand_as(
+                img_delta_A
+            )
             u = (B != f_B).long()
-            img_delta = self.add_frame(self.grid2img(u), frame[None, :], thickness=1)
-            img_delta = img_delta.min(dim=1, keepdim=True).values.expand_as(img_delta)
+            img_delta_B = self.add_frame(self.grid2img(u), frame[None, :], thickness=1)
+            img_delta_B = img_delta_B.min(dim=1, keepdim=True).values.expand_as(
+                img_delta_B
+            )
 
         img_A = self.add_frame(self.grid2img(A), frame[None, :], thickness=1)
         img_f_A = self.add_frame(self.grid2img(f_A), frame[None, :], thickness=1)
@@ -484,9 +491,13 @@ class Grids(problem.Problem):
         img_f_B = self.add_frame(img_f_B, white[None, :], thickness=2)
 
         if delta:
-            img_delta = self.add_frame(img_delta, colors[:, 0], thickness=8)
-            img_delta = self.add_frame(img_delta, white[None, :], thickness=2)
-            img = torch.cat([img_A, img_f_A, img_B, img_f_B, img_delta], dim=3)
+            img_delta_A = self.add_frame(img_delta_A, colors[:, 0], thickness=8)
+            img_delta_A = self.add_frame(img_delta_A, white[None, :], thickness=2)
+            img_delta_B = self.add_frame(img_delta_B, colors[:, 0], thickness=8)
+            img_delta_B = self.add_frame(img_delta_B, white[None, :], thickness=2)
+            img = torch.cat(
+                [img_A, img_f_A, img_delta_A, img_B, img_f_B, img_delta_B], dim=3
+            )
         else:
             img = torch.cat([img_A, img_f_A, img_B, img_f_B], dim=3)
 
diff --git a/main.py b/main.py
index b751374..01ce963 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -63,7 +63,7 @@ parser.add_argument("--nb_train_alien_samples", type=int, default=0)
 
 parser.add_argument("--nb_test_alien_samples", type=int, default=0)
 
-parser.add_argument("--nb_c_quizzes", type=int, default=2500)
+parser.add_argument("--nb_c_quizzes", type=int, default=10000)
 
 parser.add_argument("--nb_new_c_quizzes_for_train", type=int, default=None)
 
@@ -957,7 +957,7 @@ for i in range(args.nb_models):
         dropout=args.dropout,
     ).to(main_device)
 
-    model = torch.compile(model)
+    model = torch.compile(model)
 
     model.id = i
     model.test_accuracy = 0.0
@@ -1347,7 +1347,9 @@ for n_epoch in range(current_epoch, args.nb_epochs):
 
     # --------------------------------------------------------------------
 
-    if min([float(m.test_accuracy) for m in models]) > args.accuracy_to_make_c_quizzes:
+    lowest_test_accuracy = min([float(m.test_accuracy) for m in models])
+
+    if lowest_test_accuracy >= args.accuracy_to_make_c_quizzes:
         if c_quizzes is None:
             save_models(models, "naive")
 
@@ -1355,14 +1357,10 @@ for n_epoch in range(current_epoch, args.nb_epochs):
         nb_gpus = len(gpus)
         nb_c_quizzes_to_generate = (args.nb_c_quizzes + nb_gpus - 1) // nb_gpus
 
-        args = [(models, nb_c_quizzes_to_generate, gpu) for gpu in gpus]
-
-        # Ugly hack: Only one thread during the first epoch so that
-        # compilation of the model does not explode
-        if n_epoch == 0:
-            args = args[:1]
-
-        c_quizzes, agreements = multithread_execution(generate_ae_c_quizzes, args)
+        c_quizzes, agreements = multithread_execution(
+            generate_ae_c_quizzes,
+            [(models, nb_c_quizzes_to_generate, gpu) for gpu in gpus],
+        )
 
         save_c_quizzes_with_scores(
             models,
@@ -1378,6 +1376,16 @@ for n_epoch in range(current_epoch, args.nb_epochs):
             solvable_only=True,
         )
 
+        u = c_quizzes.reshape(c_quizzes.size(0), 4, -1)[:, 1:]
+        i = (u[:, 2] != u[:, 3]).long().sum(dim=1).sort(descending=True).indices
+
+        save_c_quizzes_with_scores(
+            models,
+            c_quizzes[i][:256],
+            f"culture_c_quiz_{n_epoch:04d}_solvable_high_delta.png",
+            solvable_only=True,
+        )
+
         log_string(f"generated_c_quizzes {c_quizzes.size()=}")
 
         for model in models:
@@ -1395,17 +1403,13 @@ for n_epoch in range(current_epoch, args.nb_epochs):
 
     # None if c_quizzes is None else c_quizzes[agreements[:, model.id]],
 
-    args = [
-        (model, quiz_machine, n_epoch, c_quizzes, gpu)
-        for model, gpu in zip(weakest_models, gpus)
-    ]
-
-    # Ugly hack: Only one thread during the first epoch so that
-    # compilation of the model does not explode
-    if n_epoch == 0:
-        args = args[:1]
-
-    multithread_execution(one_ae_epoch, args)
+    multithread_execution(
+        one_ae_epoch,
+        [
+            (model, quiz_machine, n_epoch, c_quizzes, gpu)
+            for model, gpu in zip(weakest_models, gpus)
+        ],
+    )
 
     # --------------------------------------------------------------------