From a2e623102ecd20491a8ba89bd119bffa0b34da1f Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 14 Jul 2024 10:35:29 +0200 Subject: [PATCH] Update. --- grids.py | 110 ++++++++++++++++++++++++++----------------------------- 1 file changed, 52 insertions(+), 58 deletions(-) diff --git a/grids.py b/grids.py index 8d144cf..aa21543 100755 --- a/grids.py +++ b/grids.py @@ -67,26 +67,31 @@ def grow_islands(nb, height, width, nb_seeds, nb_iterations): M = F.conv2d(Z[:, None, :, :], w, padding=1) M = torch.cat([M[:, :1], M[:, 1:].min(dim=1, keepdim=True).values], dim=1) M = ((M[:, 0] == 0) & (Z == 0)).long() + Q = (M.flatten(1).max(dim=1).values > 0).long()[:, None] M = M * torch.rand(M.size()) M = M.flatten(1) M = F.one_hot(M.argmax(dim=1), num_classes=M.size(1)) - U += M + U += M * Q for _ in range(nb_iterations): M = F.conv2d(Z[:, None, :, :], w, padding=1) M = torch.cat([M[:, :1], M[:, 1:].min(dim=1, keepdim=True).values], dim=1) M = ((M[:, 1] >= 0) & (Z == 0)).long() + Q = (M.flatten(1).max(dim=1).values > 0).long()[:, None] M = M * torch.rand(M.size()) M = M.flatten(1) M = F.one_hot(M.argmax(dim=1), num_classes=M.size(1)) U = Z.flatten(1) - U += M + U += M * Q M = Z.clone() Z = Z * (torch.arange(Z.size(1) * Z.size(2)) + 1).reshape(1, Z.size(1), Z.size(2)) - for _ in range(100): + while True: + W = Z.clone() Z = F.max_pool2d(Z, 3, 1, 1) * M + if Z.equal(W): + break Z = Z.long() U = Z.flatten(1) @@ -609,61 +614,50 @@ class Grids(problem.Problem): return no, nq, nq_diag def task_count(self, A, f_A, B, f_B): - N = torch.randint(4, (1,)).item() + 2 - c = torch.randperm(len(self.colors) - 1)[:N] + 1 - - for X, f_X in [(A, f_A), (B, f_B)]: - l_q = torch.randperm(self.height * self.width)[ - : self.height * self.width // 20 - ] - l_d = torch.randint(N, l_q.size()) - nb = torch.zeros(N, dtype=torch.int64) - - for q, e in zip(l_q, l_d): - d = c[e] - i, j = q % self.height, q // self.height - if ( - nb[e] < self.width - and X[max(0, i - 1) : i + 2, max(0, j - 1) : j + 2] == 0 - ).all(): - X[i, j] = d - nb[e] += 1 - - l_q = torch.randperm((self.height - 2) * (self.width - 2))[ - : self.height * self.width // 2 - ] - l_d = torch.randint(N, l_q.size()) - for q, e in zip(l_q, l_d): - d = c[e] - i, j = q % (self.height - 2) + 1, q // (self.height - 2) + 1 - a1, a2, a3 = X[i - 1, j - 1 : j + 2] - a8, a4 = X[i, j - 1], X[i, j + 1] - a7, a6, a5 = X[i + 1, j - 1 : j + 2] - if ( - X[i, j] == 0 - and nb[e] < self.width - and (a2 == 0 or a2 == d) - and (a4 == 0 or a4 == d) - and (a6 == 0 or a6 == d) - and (a8 == 0 or a8 == d) - and (a1 == 0 or a2 == d or a8 == d) - and (a3 == 0 or a4 == d or a2 == d) - and (a5 == 0 or a6 == d or a4 == d) - and (a7 == 0 or a8 == d or a6 == d) - ): - o = ( - (a2 != 0).long() - + (a4 != 0).long() - + (a6 != 0).long() - + (a8 != 0).long() + while True: + error = False + + N = torch.randint(5, (1,)).item() + 1 + c = torch.zeros(N + 1) + c[1:] = torch.randperm(len(self.colors) - 1)[:N] + 1 + + for X, f_X in [(A, f_A), (B, f_B)]: + if not hasattr(self, "cache_count") or len(self.cache_count) == 0: + self.cache_count = list( + grow_islands( + 1000, + self.height, + self.width, + nb_seeds=self.height * self.width // 9, + nb_iterations=self.height * self.width // 20, + ) ) - if o <= 1: - X[i, j] = d - nb[e] += 1 - o - for e in range(N): - for j in range(nb[e]): - f_X[e, j] = c[e] + X[...] = self.cache_count.pop() + + k = (X.max() + 1 + (c.size(0) - 1)).item() + V = torch.arange(k) // (c.size(0) - 1) + V = (V + torch.rand(V.size())).sort().indices[: X.max() + 1] % ( + c.size(0) - 1 + ) + 1 + V[0] = 0 + X[...] = c[V[X]] + + if F.one_hot(X.flatten()).max(dim=0).values.sum().item() == N + 1: + f_X[...] = 0 + for e in range(1, N + 1): + for j in range((X == c[e]).sum() + 1): + if j < self.width: + f_X[e - 1, j] = c[e] + else: + error = True + break + else: + error = True + break + + if not error: + break # @torch.compile def task_trajectory(self, A, f_A, B, f_B): @@ -1214,12 +1208,12 @@ if __name__ == "__main__": "/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=nrow ) - exit(0) + # exit(0) nb = 1000 # for t in grids.all_tasks: - for t in [grids.task_islands]: + for t in [grids.task_count]: start_time = time.perf_counter() prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t]) delay = time.perf_counter() - start_time -- 2.39.5