From a6cfa2a3f4d38bc10218b98b064798ed6f3899fe Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 7 Jul 2024 01:20:13 +0300 Subject: [PATCH] Update. --- grids.py | 56 +++++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 55 insertions(+), 1 deletion(-) diff --git a/grids.py b/grids.py index ed72099..2d1293c 100755 --- a/grids.py +++ b/grids.py @@ -614,6 +614,60 @@ class Grids(problem.Problem): f_X[i[0] : i[0] + delta, j[0] : j[0] + delta] = c[q] + def task_ortho(self, A, f_A, B, f_B): + nb_rec = 3 + di, dj = torch.randint(3, (2,)) - 1 + o = torch.tensor([[0.0, 1.0], [-1.0, 0.0]]) + m = torch.eye(2) + for _ in range(torch.randint(4, (1,))): + m = m @ o + if torch.rand(1) < 0.5: + m[0, :] = -m[0, :] + + ci, cj = (self.height - 1) / 2, (self.width - 1) / 2 + + for X, f_X in [(A, f_A), (B, f_B)]: + while True: + X[...] = 0 + f_X[...] = 0 + + c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1 + + for r in range(nb_rec): + while True: + i1, i2 = torch.randint(self.height - 2, (2,)) + 1 + j1, j2 = torch.randint(self.width - 2, (2,)) + 1 + if ( + i2 >= i1 + and j2 >= j1 + and max(i2 - i1, j2 - j1) >= 2 + and min(i2 - i1, j2 - j1) <= 3 + ): + break + X[i1 : i2 + 1, j1 : j2 + 1] = c[r] + + i1, j1, i2, j2 = i1 - ci, j1 - cj, i2 - ci, j2 - cj + + i1, j1 = m[0, 0] * i1 + m[0, 1] * j1, m[1, 0] * i1 + m[1, 1] * j1 + i2, j2 = m[0, 0] * i2 + m[0, 1] * j2, m[1, 0] * i2 + m[1, 1] * j2 + + i1, j1, i2, j2 = i1 + ci, j1 + cj, i2 + ci, j2 + cj + i1, i2 = i1.long() + di, i2.long() + di + j1, j2 = j1.long() + dj, j2.long() + dj + if i1 > i2: + i1, i2 = i2, i1 + if j1 > j2: + j1, j2 = j2, j1 + + f_X[i1 : i2 + 1, j1 : j2 + 1] = c[r] + + n = F.one_hot(X.flatten()).sum(dim=0)[1:] + if ( + n.sum() > self.height * self.width // 4 + and (n > 0).long().sum() == nb_rec + ): + break + def task_islands(self, A, f_A, B, f_B): pass @@ -703,7 +757,7 @@ if __name__ == "__main__": grids = Grids() # for t in grids.all_tasks(): - for t in [grids.task_islands]: + for t in [grids.task_ortho]: print(t.__name__) prompts, answers = grids.generate_prompts_and_answers(nb, tasks=[t]) grids.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=4) -- 2.39.5