From f4d12501685fe9b46a75e3768115f86ea9b75fa6 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 4 Jul 2024 04:48:26 +0300 Subject: [PATCH] Update. --- main.py | 3 ++ quizz_machine.py | 31 ++++++++++--- reasoning.py | 115 ++++++++++++++++++++++++++++------------------- 3 files changed, 99 insertions(+), 50 deletions(-) diff --git a/main.py b/main.py index a954af6..be0d8e0 100755 --- a/main.py +++ b/main.py @@ -249,8 +249,10 @@ if args.problem == "sky": nb_iterations=args.sky_nb_iterations, speed=args.sky_speed, ) + back_accuracy = False elif args.problem == "reasoning": problem = reasoning.Reasoning(device=device) + back_accuracy = True else: raise ValueError @@ -258,6 +260,7 @@ quizz_machine = quizz_machine.QuizzMachine( problem=problem, nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, + back_accuracy=back_accuracy, batch_size=args.physical_batch_size, result_dir=args.result_dir, logger=log_string, diff --git a/quizz_machine.py b/quizz_machine.py index 90f288e..6e57fb4 100755 --- a/quizz_machine.py +++ b/quizz_machine.py @@ -202,6 +202,7 @@ class QuizzMachine: problem, nb_train_samples, nb_test_samples, + back_accuracy, batch_size, result_dir, logger, @@ -215,6 +216,7 @@ class QuizzMachine: self.nb_token_values = v + 2 self.problem = problem + self.back_accuracy = back_accuracy self.batch_size = batch_size self.device = device self.logger = logger @@ -308,7 +310,6 @@ class QuizzMachine: self, n_epoch, model, result_dir, deterministic_synthesis, nmax=1000 ): def compute_accuracy(input): - input = input[:nmax] ar_mask = self.make_ar_mask(input) result = input.clone() * (1 - ar_mask) seq_logproba = torch.empty(input.size(0), device=self.device) @@ -325,18 +326,38 @@ class QuizzMachine: device=self.device, ) - nb_total = input.size(0) - nb_correct = (input == result).long().min(dim=1).values.sum() + if self.back_accuracy: + n_forward = input[:, 0] == self.token_forward + nb_total = input[n_forward].size(0) + nb_correct = ( + (input[n_forward] == result[n_forward]) + .long() + .min(dim=1) + .values.sum() + ) + + n_backward = input[:, 0] == self.token_backward + back_input = self.reverse_time(result[n_backward]) + if back_input.size(0) > 0: + back_input[:, 2 + self.prompt_len :] = input[ + n_backward, 2 + self.prompt_len : + ] + back_nb_total, back_nb_correct = compute_accuracy(back_input) + nb_total += back_nb_total + nb_correct += back_nb_correct + else: + nb_total = input.size(0) + nb_correct = (input == result).long().min(dim=1).values.sum() return nb_total, nb_correct - train_nb_total, train_nb_correct = compute_accuracy(self.train_w_quizzes) + train_nb_total, train_nb_correct = compute_accuracy(self.train_w_quizzes[:nmax]) self.logger( f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%" ) - test_nb_total, test_nb_correct = compute_accuracy(self.test_w_quizzes) + test_nb_total, test_nb_correct = compute_accuracy(self.test_w_quizzes[:nmax]) self.logger( f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%" diff --git a/reasoning.py b/reasoning.py index 57e8056..768c15c 100755 --- a/reasoning.py +++ b/reasoning.py @@ -42,6 +42,31 @@ class Reasoning(problem.Problem): ###################################################################### def frame2img(self, x, scale=15): + x = x.reshape(x.size(0), self.height, -1) + m = torch.logical_and(x >= 0, x < self.nb_token_values()).long() + x = self.colors[x * m].permute(0, 3, 1, 2) + s = x.shape + x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale) + x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale) + + x[:, :, :, torch.arange(0, x.size(3), scale)] = 0 + x[:, :, torch.arange(0, x.size(2), scale), :] = 0 + x = x[:, :, 1:, 1:] + + for n in range(m.size(0)): + for i in range(m.size(1)): + for j in range(m.size(2)): + if m[n, i, j] == 0: + for k in range(2, scale - 2): + for l in [0, 1]: + x[n, :, i * scale + k, j * scale + k - l] = 0 + x[ + n, :, i * scale + scale - 1 - k, j * scale + k - l + ] = 0 + + return x + + def frame2img_(self, x, scale=15): x = x.reshape(x.size(0), self.height, -1) x = self.colors[x].permute(0, 3, 1, 2) s = x.shape @@ -173,14 +198,13 @@ class Reasoning(problem.Problem): # non-overlapping rectangles quickly, but made the generation of # 100k samples go from 1h50 with a lame pure python code to 3min30s # with this one. - def rec_coo(self, x, n, min_height=3, min_width=3): - K = 3 - N = 200 + def rec_coo(self, nb_rec, min_height=3, min_width=3): + nb_trials = 200 while True: v = ( ( - torch.rand(N * K, self.height + 1, device=self.device) + torch.rand(nb_trials * nb_rec, self.height + 1, device=self.device) .sort(dim=-1) .indices < 2 @@ -192,7 +216,7 @@ class Reasoning(problem.Problem): h = ( ( - torch.rand(N * K, self.width + 1, device=self.device) + torch.rand(nb_trials * nb_rec, self.width + 1, device=self.device) .sort(dim=-1) .indices < 2 @@ -207,10 +231,10 @@ class Reasoning(problem.Problem): ) v, h = v[i], h[i] - v = v[: v.size(0) - v.size(0) % K] - h = h[: h.size(0) - h.size(0) % K] - v = v.reshape(v.size(0) // K, K, -1) - h = h.reshape(h.size(0) // K, K, -1) + v = v[: v.size(0) - v.size(0) % nb_rec] + h = h[: h.size(0) - h.size(0) % nb_rec] + v = v.reshape(v.size(0) // nb_rec, nb_rec, -1) + h = h.reshape(h.size(0) // nb_rec, nb_rec, -1) r = v[:, :, :, None] * h[:, :, None, :] @@ -260,23 +284,23 @@ class Reasoning(problem.Problem): ###################################################################### def task_replace_color(self, A, f_A, B, f_B): - N = 3 - c = torch.randperm(len(self.colors) - 1)[: N + 1] + 1 + nb_rec = 3 + c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1 for X, f_X in [(A, f_A), (B, f_B)]: - r = self.rec_coo(X, N) - for n in range(N): + r = self.rec_coo(nb_rec) + for n in range(nb_rec): i1, j1, i2, j2 = r[n] X[i1:i2, j1:j2] = c[n] f_X[i1:i2, j1:j2] = c[n if n > 0 else -1] def task_move(self, A, f_A, B, f_B): di, dj = torch.randint(2, (2,)) * 2 - 1 - N = 3 - c = torch.randperm(len(self.colors) - 1)[:N] + 1 + nb_rec = 3 + c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1 for X, f_X in [(A, f_A), (B, f_B)]: while True: - r = self.rec_coo(X, N) - i1, j1, i2, j2 = r[N - 1] + r = self.rec_coo(nb_rec) + i1, j1, i2, j2 = r[nb_rec - 1] if ( i1 + di >= 0 and i2 + di < X.size(0) @@ -285,29 +309,29 @@ class Reasoning(problem.Problem): ): break - for n in range(N): + for n in range(nb_rec): i1, j1, i2, j2 = r[n] X[i1:i2, j1:j2] = c[n] - if n == N - 1: + if n == nb_rec - 1: f_X[i1 + di : i2 + di, j1 + dj : j2 + dj] = c[n] else: f_X[i1:i2, j1:j2] = c[n] def task_grow(self, A, f_A, B, f_B): di, dj = torch.randint(2, (2,)) * 2 - 1 - N = 3 - c = torch.randperm(len(self.colors) - 1)[:N] + 1 + nb_rec = 3 + c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1 direction = torch.randint(2, (1,)) for X, f_X in [(A, f_A), (B, f_B)]: while True: - r = self.rec_coo(X, N) - i1, j1, i2, j2 = r[N - 1] + r = self.rec_coo(nb_rec) + i1, j1, i2, j2 = r[nb_rec - 1] if i1 + 3 < i2 and j1 + 3 < j2: break - for n in range(N): + for n in range(nb_rec): i1, j1, i2, j2 = r[n] - if n == N - 1: + if n == nb_rec - 1: if direction == 0: X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c[n] f_X[i1:i2, j1:j2] = c[n] @@ -320,12 +344,12 @@ class Reasoning(problem.Problem): def task_color_grow(self, A, f_A, B, f_B): di, dj = torch.randint(2, (2,)) * 2 - 1 - N = 3 - c = torch.randperm(len(self.colors) - 1)[: 2 * N] + 1 + nb_rec = 3 + c = torch.randperm(len(self.colors) - 1)[: 2 * nb_rec] + 1 direction = torch.randint(4, (1,)) for X, f_X in [(A, f_A), (B, f_B)]: - r = self.rec_coo(X, N) - for n in range(N): + r = self.rec_coo(nb_rec) + for n in range(nb_rec): i1, j1, i2, j2 = r[n] X[i1:i2, j1:j2] = c[2 * n] f_X[i1:i2, j1:j2] = c[2 * n] @@ -333,53 +357,54 @@ class Reasoning(problem.Problem): if direction == 0: i = (i1 + i2) // 2 X[i : i + 1, j1:j2] = c[2 * n + 1] - if n == N - 1: + if n == nb_rec - 1: f_X[i:i2, j1:j2] = c[2 * n + 1] else: f_X[i : i + 1, j1:j2] = c[2 * n + 1] elif direction == 1: i = (i1 + i2 - 1) // 2 X[i : i + 1, j1:j2] = c[2 * n + 1] - if n == N - 1: + if n == nb_rec - 1: f_X[i1 : i + 1, j1:j2] = c[2 * n + 1] else: f_X[i : i + 1, j1:j2] = c[2 * n + 1] elif direction == 2: j = (j1 + j2) // 2 X[i1:i2, j : j + 1] = c[2 * n + 1] - if n == N - 1: + if n == nb_rec - 1: f_X[i1:i2, j:j2] = c[2 * n + 1] else: f_X[i1:i2, j : j + 1] = c[2 * n + 1] elif direction == 3: j = (j1 + j2 - 1) // 2 X[i1:i2, j : j + 1] = c[2 * n + 1] - if n == N - 1: + if n == nb_rec - 1: f_X[i1:i2, j1 : j + 1] = c[2 * n + 1] else: f_X[i1:i2, j : j + 1] = c[2 * n + 1] def task_frame(self, A, f_A, B, f_B): - N = 3 - c = torch.randperm(len(self.colors) - 1)[: N + 1] + 1 + nb_rec = 3 + c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1 for X, f_X in [(A, f_A), (B, f_B)]: - r = self.rec_coo(X, N) - for n in range(N): + r = self.rec_coo(nb_rec) + for n in range(nb_rec): i1, j1, i2, j2 = r[n] X[i1:i2, j1:j2] = c[n] f_X[i1:i2, j1:j2] = c[n] - if n == N - 1: + if n == nb_rec - 1: f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = 0 def task_detect(self, A, f_A, B, f_B): - N = 3 - c = torch.randperm(len(self.colors) - 1)[: N + 1] + 1 + nb_rec = 3 + c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1 for X, f_X in [(A, f_A), (B, f_B)]: - r = self.rec_coo(X, N) - for n in range(N): + r = self.rec_coo(nb_rec) + for n in range(nb_rec): i1, j1, i2, j2 = r[n] X[i1:i2, j1:j2] = c[n] - f_X[i1, j1] = c[-1] + if n < nb_rec - 1: + f_X[i1, j1] = c[-1] ###################################################################### @@ -448,8 +473,8 @@ if __name__ == "__main__": reasoning.save_quizzes( "/tmp", "test", - prompts[:36], - answers[:36], + prompts[:64], + answers[:64], # You can add a bool to put a frame around the predicted parts # predicted_prompts, predicted_answers ) -- 2.39.5