Update.
authorFrançois Fleuret <francois@fleuret.org>
Sun, 1 Sep 2024 10:38:55 +0000 (12:38 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sun, 1 Sep 2024 10:38:55 +0000 (12:38 +0200)
grids.py
main.py

index 98a0581..9441811 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -359,6 +359,7 @@ class Grids(problem.Problem):
         comment_height=48,
         nrow=4,
         margin=8,
+        delta=False,
     ):
         quizzes = quizzes.to("cpu")
 
@@ -389,6 +390,10 @@ class Grids(problem.Problem):
             device=quizzes.device,
         )
 
+        if delta:
+            u = (B != f_B).long()
+            img_delta = self.add_frame(self.grid2img(u), frame[None, :], thickness=1)
+
         img_A = self.add_frame(self.grid2img(A), frame[None, :], thickness=1)
         img_f_A = self.add_frame(self.grid2img(f_A), frame[None, :], thickness=1)
         img_B = self.add_frame(self.grid2img(B), frame[None, :], thickness=1)
@@ -428,7 +433,12 @@ class Grids(problem.Problem):
         img_B = self.add_frame(img_B, white[None, :], thickness=2)
         img_f_B = self.add_frame(img_f_B, white[None, :], thickness=2)
 
-        img = torch.cat([img_A, img_f_A, img_B, img_f_B], dim=3)
+        if delta:
+            img_delta = self.add_frame(img_delta, colors[:, 0], thickness=8)
+            img_delta = self.add_frame(img_delta, white[None, :], thickness=2)
+            img = torch.cat([img_A, img_f_A, img_B, img_f_B, img_delta], dim=3)
+        else:
+            img = torch.cat([img_A, img_f_A, img_B, img_f_B], dim=3)
 
         if comments is not None:
             comment_img = [text_img(comment_height, img.size(3), t) for t in comments]
diff --git a/main.py b/main.py
index 58d6287..2731f25 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -1357,8 +1357,8 @@ def generate_ae_c_quizzes(models, local_device=main_device):
 
     duration_max = 4 * 3600
 
-    wanted_nb = 128  # 0000
-    nb_to_save = 128
+    wanted_nb = 16  # 0000
+    nb_to_save = 16
 
     with torch.autograd.no_grad():
         records = [[] for _ in criteria]
@@ -1394,9 +1394,7 @@ def generate_ae_c_quizzes(models, local_device=main_device):
 
         duration = time.perf_counter() - start_time
 
-        log_string(
-            f"generate_c_quizz_generation_speed {int(3600 * wanted_nb / duration)}/h"
-        )
+        log_string(f"generate_c_quizz_speed {int(3600 * wanted_nb / duration)}/h")
 
         for n, u in enumerate(records):
             quizzes = torch.cat(u, dim=0)[:nb_to_save]
@@ -1418,6 +1416,7 @@ def generate_ae_c_quizzes(models, local_device=main_device):
                 # predicted_parts=predicted_parts,
                 # correct_parts=correct_parts,
                 comments=comments,
+                delta=True,
             )
 
             log_string(f"wrote {filename}")