From 3e1100f546e955430d87dd6808c8d148715bc50d Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 6 Jul 2024 09:49:21 +0300 Subject: [PATCH] Update. --- reasoning.py | 101 +++++++++++++++++++++++++++------------------------ 1 file changed, 53 insertions(+), 48 deletions(-) diff --git a/reasoning.py b/reasoning.py index fb208b0..058c410 100755 --- a/reasoning.py +++ b/reasoning.py @@ -89,8 +89,14 @@ class Reasoning(problem.Problem): predicted_answers=None, nrow=4, ): - prompts = prompts.reshape(prompts.size(0), self.height, -1) - answers = answers.reshape(answers.size(0), self.height, -1) + S = self.height * self.width + As = prompts[:, 0 * (S + 1) : 0 * (S + 1) + S].view(-1, self.height, self.width) + f_As = prompts[:, 1 * (S + 1) : 1 * (S + 1) + S].view( + -1, self.height, self.width + ) + Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S].view(-1, self.height, self.width) + prompts = torch.cat([As, f_As, Bs], dim=2) + answers = answers.reshape(answers.size(0), self.height, self.width) if predicted_prompts is None: predicted_prompts = 255 @@ -415,53 +421,52 @@ class Reasoning(problem.Problem): if n < nb_rec - 1: f_X[i1, j1] = c[-1] + def contact(X, i, j, q): + nq, nq_diag = 0, 0 + no = 0 + + for ii, jj in [ + (i - 1, j - 1), + (i - 1, j), + (i - 1, j + 1), + (i, j - 1), + (i, j + 1), + (i + 1, j - 1), + (i + 1, j), + (i + 1, j + 1), + ]: + if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width: + if X[ii, jj] != 0 and X[ii, jj] != q: + no += 1 + + for ii, jj in [ + (i - 1, j - 1), + (i - 1, j + 1), + (i + 1, j - 1), + (i + 1, j + 1), + ]: + if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width: + if X[ii, jj] == q and X[i, jj] != q and X[ii, j] != q: + nq_diag += 1 + + for ii, jj in [(i - 1, j), (i, j - 1), (i, j + 1), (i + 1, j)]: + if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width: + if X[ii, jj] == q: + nq += 1 + + return no, nq, nq_diag + def task_count(self, A, f_A, B, f_B): N = torch.randint(4, (1,)) + 2 c = torch.randperm(len(self.colors) - 1)[:N] + 1 for X, f_X in [(A, f_A), (B, f_B)]: - - def contact(i, j, q): - nq, nq_diag = 0, 0 - no = 0 - - for ii, jj in [ - (i - 1, j - 1), - (i - 1, j), - (i - 1, j + 1), - (i, j - 1), - (i, j + 1), - (i + 1, j - 1), - (i + 1, j), - (i + 1, j + 1), - ]: - if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width: - if X[ii, jj] != 0 and X[ii, jj] != q: - no += 1 - - for ii, jj in [ - (i - 1, j - 1), - (i - 1, j + 1), - (i + 1, j - 1), - (i + 1, j + 1), - ]: - if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width: - if X[ii, jj] == q and X[i, jj] != q and X[ii, j] != q: - nq_diag += 1 - - for ii, jj in [(i - 1, j), (i, j - 1), (i, j + 1), (i + 1, j)]: - if ii >= 0 and ii < self.height and jj >= 0 and jj < self.width: - if X[ii, jj] == q: - nq += 1 - - return no, nq, nq_diag - nb = torch.zeros(N, dtype=torch.int64) q = torch.randint(N, (self.height * self.width,)) k = torch.randperm(self.height * self.width) for p in range(self.height * self.width): i, j = k[p] % self.height, k[p] // self.height - no, nq, nq_diag = contact(i, j, c[q[p]]) + no, nq, nq_diag = contact(X, i, j, c[q[p]]) if no == 0 and nq_diag == 0: if nq == 0: if nb[q[p]] < self.width: @@ -641,9 +646,9 @@ class Reasoning(problem.Problem): if tasks is None: tasks = self.all_tasks() - prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64) - answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64) - w = self.width + S = self.height * self.width + prompts = torch.zeros(nb, 3 * S + 2, dtype=torch.int64) + answers = torch.zeros(nb, S, dtype=torch.int64) for prompt, answer in tqdm.tqdm( zip(prompts, answers), @@ -651,10 +656,10 @@ class Reasoning(problem.Problem): desc="world generation", total=prompts.size(0), ): - A = prompt[:, 0 * w : 1 * w] - f_A = prompt[:, 1 * w : 2 * w] - B = prompt[:, 2 * w : 3 * w] - f_B = answer + A = prompt[0 * (S + 1) : 0 * (S + 1) + S].view(self.height, self.width) + f_A = prompt[1 * (S + 1) : 1 * (S + 1) + S].view(self.height, self.width) + B = prompt[2 * (S + 1) : 2 * (S + 1) + S].view(self.height, self.width) + f_B = answer.view(self.height, self.width) task = tasks[torch.randint(len(tasks), (1,))] task(A, f_A, B, f_B) @@ -686,14 +691,14 @@ class Reasoning(problem.Problem): if __name__ == "__main__": import time - nb = 4 + nb = 48 reasoning = Reasoning() for t in [reasoning.task_islands]: # reasoning.all_tasks(): print(t.__name__) prompts, answers = reasoning.generate_prompts_and_answers(nb, tasks=[t]) - reasoning.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=1) + reasoning.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=4) exit(0) -- 2.39.5