Update.
authorFrançois Fleuret <francois@fleuret.org>
Wed, 18 Sep 2024 18:27:50 +0000 (20:27 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Wed, 18 Sep 2024 18:27:50 +0000 (20:27 +0200)
grids.py

index 9424496..6b2ea23 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -136,6 +136,7 @@ def grow_islands(nb, height, width, nb_seeds, nb_iterations):
 class Grids(problem.Problem):
     named_colors = [
         ("white", [255, 255, 255]),
+        # ("white", [224, 224, 224]),
         ("red", [255, 0, 0]),
         ("green", [0, 192, 0]),
         ("blue", [0, 0, 255]),
@@ -371,15 +372,16 @@ class Grids(problem.Problem):
 
     ######################################################################
 
-    def grid2img(self, x, scale=15):
+    def grid2img(self, x, scale=15, grids=True):
         m = torch.logical_and(x >= 0, x < self.nb_colors).long()
         y = self.colors[x * m].permute(0, 3, 1, 2)
         s = y.shape
         y = y[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale)
         y = y.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
 
-        y[:, :, :, torch.arange(0, y.size(3), scale)] = 64
-        y[:, :, torch.arange(0, y.size(2), scale), :] = 64
+        if grids:
+            y[:, :, :, torch.arange(0, y.size(3), scale)] = 64
+            y[:, :, torch.arange(0, y.size(2), scale), :] = 64
 
         for n in range(m.size(0)):
             for i in range(m.size(1)):
@@ -394,15 +396,18 @@ class Grids(problem.Problem):
         return y
 
     def add_frame(self, img, colors, thickness):
-        result = img.new(
-            img.size(0),
-            img.size(1),
-            img.size(2) + 2 * thickness,
-            img.size(3) + 2 * thickness,
-        )
+        if thickness > 0:
+            result = img.new(
+                img.size(0),
+                img.size(1),
+                img.size(2) + 2 * thickness,
+                img.size(3) + 2 * thickness,
+            )
 
-        result[...] = colors[:, :, None, None]
-        result[:, :, thickness:-thickness, thickness:-thickness] = img
+            result[...] = colors[:, :, None, None]
+            result[:, :, thickness:-thickness, thickness:-thickness] = img
+        else:
+            result = img
 
         return result
 
@@ -462,22 +467,36 @@ class Grids(problem.Problem):
             device=quizzes.device,
         )
 
+        thickness = 1 if grids else 0
+
         if delta:
             u = (A != f_A).long()
-            img_delta_A = self.add_frame(self.grid2img(u), frame[None, :], thickness=1)
+            img_delta_A = self.add_frame(
+                self.grid2img(u, grids=grids), frame[None, :], thickness=thickness
+            )
             img_delta_A = img_delta_A.min(dim=1, keepdim=True).values.expand_as(
                 img_delta_A
             )
             u = (B != f_B).long()
-            img_delta_B = self.add_frame(self.grid2img(u), frame[None, :], thickness=1)
+            img_delta_B = self.add_frame(
+                self.grid2img(u, grids=grids), frame[None, :], thickness=thickness
+            )
             img_delta_B = img_delta_B.min(dim=1, keepdim=True).values.expand_as(
                 img_delta_B
             )
 
-        img_A = self.add_frame(self.grid2img(A), frame[None, :], thickness=1)
-        img_f_A = self.add_frame(self.grid2img(f_A), frame[None, :], thickness=1)
-        img_B = self.add_frame(self.grid2img(B), frame[None, :], thickness=1)
-        img_f_B = self.add_frame(self.grid2img(f_B), frame[None, :], thickness=1)
+        img_A = self.add_frame(
+            self.grid2img(A, grids=grids), frame[None, :], thickness=thickness
+        )
+        img_f_A = self.add_frame(
+            self.grid2img(f_A, grids=grids), frame[None, :], thickness=thickness
+        )
+        img_B = self.add_frame(
+            self.grid2img(B, grids=grids), frame[None, :], thickness=thickness
+        )
+        img_f_B = self.add_frame(
+            self.grid2img(f_B, grids=grids), frame[None, :], thickness=thickness
+        )
 
         # predicted_parts Nx4
         # correct_parts Nx4
@@ -1878,6 +1897,29 @@ if __name__ == "__main__":
 
     grids = Grids()
 
+    nb, nrow = 64, 4
+    # nb, nrow = 8, 2
+
+    # for t in grids.all_tasks:
+
+    for t in [
+        grids.task_replace_color,
+        grids.task_translate,
+        grids.task_grow,
+        grids.task_frame,
+    ]:
+        print(t.__name__)
+        w_quizzes = grids.generate_w_quizzes_(nb, tasks=[t])
+        grids.save_quizzes_as_image(
+            "/tmp",
+            t.__name__ + ".png",
+            w_quizzes,
+            # grids=False
+            # comments=[f"{t.__name__} #{k}" for k in range(w_quizzes.size(0))],
+        )
+
+    exit(0)
+
     q = grids.text2quiz(
         """
 
@@ -1933,24 +1975,7 @@ vvvv.aaaaa rrrr.aaaaa .......aaa .......aaa
 """
     )
 
-    grids.save_quizzes_as_image("/tmp", "test.png", q, nrow=1)
-
-    exit(0)
-
-    nb, nrow = 128, 4
-    # nb, nrow = 8, 2
-
-    # for t in grids.all_tasks:
-
-    for t in [grids.task_symmetry]:
-        print(t.__name__)
-        w_quizzes = grids.generate_w_quizzes_(nb, tasks=[t])
-        grids.save_quizzes_as_image(
-            "/tmp",
-            t.__name__ + ".png",
-            w_quizzes,
-            comments=[f"{t.__name__} #{k}" for k in range(w_quizzes.size(0))],
-        )
+    grids.save_quizzes_as_image("/tmp", "test.png", q, nrow=1, grids=False)
 
     exit(0)