From d408fc33d95cf9fba97b4db24aeddbe1a927bda6 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 18 Sep 2024 20:27:50 +0200 Subject: [PATCH] Update. --- grids.py | 95 +++++++++++++++++++++++++++++++++++--------------------- 1 file changed, 60 insertions(+), 35 deletions(-) diff --git a/grids.py b/grids.py index 9424496..6b2ea23 100755 --- a/grids.py +++ b/grids.py @@ -136,6 +136,7 @@ def grow_islands(nb, height, width, nb_seeds, nb_iterations): class Grids(problem.Problem): named_colors = [ ("white", [255, 255, 255]), + # ("white", [224, 224, 224]), ("red", [255, 0, 0]), ("green", [0, 192, 0]), ("blue", [0, 0, 255]), @@ -371,15 +372,16 @@ class Grids(problem.Problem): ###################################################################### - def grid2img(self, x, scale=15): + def grid2img(self, x, scale=15, grids=True): m = torch.logical_and(x >= 0, x < self.nb_colors).long() y = self.colors[x * m].permute(0, 3, 1, 2) s = y.shape y = y[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale) y = y.reshape(s[0], s[1], s[2] * scale, s[3] * scale) - y[:, :, :, torch.arange(0, y.size(3), scale)] = 64 - y[:, :, torch.arange(0, y.size(2), scale), :] = 64 + if grids: + y[:, :, :, torch.arange(0, y.size(3), scale)] = 64 + y[:, :, torch.arange(0, y.size(2), scale), :] = 64 for n in range(m.size(0)): for i in range(m.size(1)): @@ -394,15 +396,18 @@ class Grids(problem.Problem): return y def add_frame(self, img, colors, thickness): - result = img.new( - img.size(0), - img.size(1), - img.size(2) + 2 * thickness, - img.size(3) + 2 * thickness, - ) + if thickness > 0: + result = img.new( + img.size(0), + img.size(1), + img.size(2) + 2 * thickness, + img.size(3) + 2 * thickness, + ) - result[...] = colors[:, :, None, None] - result[:, :, thickness:-thickness, thickness:-thickness] = img + result[...] = colors[:, :, None, None] + result[:, :, thickness:-thickness, thickness:-thickness] = img + else: + result = img return result @@ -462,22 +467,36 @@ class Grids(problem.Problem): device=quizzes.device, ) + thickness = 1 if grids else 0 + if delta: u = (A != f_A).long() - img_delta_A = self.add_frame(self.grid2img(u), frame[None, :], thickness=1) + img_delta_A = self.add_frame( + self.grid2img(u, grids=grids), frame[None, :], thickness=thickness + ) img_delta_A = img_delta_A.min(dim=1, keepdim=True).values.expand_as( img_delta_A ) u = (B != f_B).long() - img_delta_B = self.add_frame(self.grid2img(u), frame[None, :], thickness=1) + img_delta_B = self.add_frame( + self.grid2img(u, grids=grids), frame[None, :], thickness=thickness + ) img_delta_B = img_delta_B.min(dim=1, keepdim=True).values.expand_as( img_delta_B ) - img_A = self.add_frame(self.grid2img(A), frame[None, :], thickness=1) - img_f_A = self.add_frame(self.grid2img(f_A), frame[None, :], thickness=1) - img_B = self.add_frame(self.grid2img(B), frame[None, :], thickness=1) - img_f_B = self.add_frame(self.grid2img(f_B), frame[None, :], thickness=1) + img_A = self.add_frame( + self.grid2img(A, grids=grids), frame[None, :], thickness=thickness + ) + img_f_A = self.add_frame( + self.grid2img(f_A, grids=grids), frame[None, :], thickness=thickness + ) + img_B = self.add_frame( + self.grid2img(B, grids=grids), frame[None, :], thickness=thickness + ) + img_f_B = self.add_frame( + self.grid2img(f_B, grids=grids), frame[None, :], thickness=thickness + ) # predicted_parts Nx4 # correct_parts Nx4 @@ -1878,6 +1897,29 @@ if __name__ == "__main__": grids = Grids() + nb, nrow = 64, 4 + # nb, nrow = 8, 2 + + # for t in grids.all_tasks: + + for t in [ + grids.task_replace_color, + grids.task_translate, + grids.task_grow, + grids.task_frame, + ]: + print(t.__name__) + w_quizzes = grids.generate_w_quizzes_(nb, tasks=[t]) + grids.save_quizzes_as_image( + "/tmp", + t.__name__ + ".png", + w_quizzes, + # grids=False + # comments=[f"{t.__name__} #{k}" for k in range(w_quizzes.size(0))], + ) + + exit(0) + q = grids.text2quiz( """ @@ -1933,24 +1975,7 @@ vvvv.aaaaa rrrr.aaaaa .......aaa .......aaa """ ) - grids.save_quizzes_as_image("/tmp", "test.png", q, nrow=1) - - exit(0) - - nb, nrow = 128, 4 - # nb, nrow = 8, 2 - - # for t in grids.all_tasks: - - for t in [grids.task_symmetry]: - print(t.__name__) - w_quizzes = grids.generate_w_quizzes_(nb, tasks=[t]) - grids.save_quizzes_as_image( - "/tmp", - t.__name__ + ".png", - w_quizzes, - comments=[f"{t.__name__} #{k}" for k in range(w_quizzes.size(0))], - ) + grids.save_quizzes_as_image("/tmp", "test.png", q, nrow=1, grids=False) exit(0) -- 2.39.5