From 752a2712a77f0a42091b24e40e6d210f6e2cf110 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 16 Jul 2024 19:53:11 +0200 Subject: [PATCH] Update. --- grids.py | 64 ++++++++++++++++++++++++++++++++++++++++++++++++- main.py | 18 +++++++++----- quiz_machine.py | 14 +++++++---- 3 files changed, 85 insertions(+), 11 deletions(-) diff --git a/grids.py b/grids.py index a115f93..e4d831c 100755 --- a/grids.py +++ b/grids.py @@ -1069,6 +1069,68 @@ class Grids(problem.Problem): X[i, j] = c[1] f_X[...] = (A == A[i, j]) * c[1] + ((A > 0) & (A != A[i, j])) * c[0] + # @torch.compile + def task_stack(self, A, f_A, B, f_B): + N = 5 + c = torch.randperm(len(self.colors) - 1)[:N] + 1 + for X, f_X in [(A, f_A), (B, f_B)]: + i1, j1, i2, j2 = ( + self.height // 2 - 1, + self.width // 2 - 1, + self.height // 2 + 1, + self.width // 2 + 1, + ) + op = torch.tensor((0, 1, 2, 3) * 4) + op = op[torch.randperm(op.size(0))[:9]] + for q in range(op.size(0)): + u = 3 * (q // 3) + v = 3 * (q % 3) + d = c[torch.randint(N, (1,)).item()] + # X[u+1,v+1]=d + if op[q] == 0: # right + X[u : u + 3, v + 2] = d + elif op[q] == 1: # let + X[u : u + 3, v] = d + elif op[q] == 2: # bottom + X[u + 2, v : v + 3] = d + elif op[q] == 3: # top + X[u, v : v + 3] = d + + if q == 0: + f_X[i1:i2, j1:j2] = d + elif op[q] == 0: # right + f_X[i1:i2, j2] = d + j2 += 1 + elif op[q] == 1: # let + j1 -= 1 + f_X[i1:i2, j1] = d + elif op[q] == 2: # bottom + f_X[i2, j1:j2] = d + i2 += 1 + elif op[q] == 3: # top + i1 -= 1 + f_X[i1, j1:j2] = d + + def randint(self, *m): + m = torch.tensor(m) + return (torch.rand(m.size()) * m).long() + + def task_matrices(self, A, f_A, B, f_B): + N = 6 + c = torch.randperm(len(self.colors) - 1)[:N] + 1 + + for X, f_X in [(A, f_A), (B, f_B)]: + M1 = torch.randint(2, (5, 5)) + M2 = torch.randint(2, (5, 5)) + P = M1 @ M2 + for i in range(5): + for j in range(5): + X[i, j] = c[M1[i, j]] + X[i, j + 5] = c[M2[i, j]] + f_X[i, j] = c[M1[i, j]] + f_X[i, j + 5] = c[M2[i, j]] + f_X[i + 5, j + 5] = c[P[i, j]] + ###################################################################### def trivial_prompts_and_answers(self, prompts, answers): @@ -1159,7 +1221,7 @@ if __name__ == "__main__": # nb, nrow = 8, 2 # for t in grids.all_tasks: - for t in [grids.task_count]: + for t in [grids.task_matrices]: print(t.__name__) prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t]) grids.save_quiz_illustrations( diff --git a/main.py b/main.py index b149e62..957e95a 100755 --- a/main.py +++ b/main.py @@ -88,7 +88,7 @@ parser.add_argument("--proba_understands", type=float, default=0.9) parser.add_argument("--proba_not_understands", type=float, default=0.5) -# parser.add_argument("--generation_temperature", type=float, default=2) +parser.add_argument("--generation_temperature", type=float, default=2) parser.add_argument("--c_quiz_validation_mode", type=str, default="predict") @@ -410,25 +410,31 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 start_time = time.perf_counter() + nb_validated = torch.zeros(len(models)) + while nb_validated < nb_to_create: - model_for_generation = models[torch.randint(len(models), (1,))] + # We balance the number of quizzes per model + + model_for_generation = models[nb_validated.argmin()] c_quizzes = quiz_machine.generate_c_quizzes( nb_to_generate_per_iteration, model_for_generation=model_for_generation, forward_only=args.forward_only, + generation_temperature=args.generation_temperature, ) c_quizzes = keep_good_quizzes(models, c_quizzes) - nb_validated += c_quizzes.size(0) + nb_validated[model.id] += c_quizzes.size(0) + total_nb_validated = nb_validated.sum() recorded.append(c_quizzes) duration = time.perf_counter() - start_time - if nb_validated > 0 and nb_validated < nb_to_create: - d = (nb_to_create - nb_validated) * duration / nb_validated + if total_nb_validated > 0 and total_nb_validated < nb_to_create: + d = (nb_to_create - total_nb_validated) * duration / total_nb_validated e = (datetime.datetime.now() + datetime.timedelta(seconds=d)).strftime( "%a %H:%M" ) @@ -436,7 +442,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 e = "???" log_string( - f"keep c_quizzes model {model_for_generation.id} nb_accumulated {nb_validated} / {nb_to_create} (finishes {e})" + f"keep c_quizzes model {model_for_generation.id} nb_accumulated {total_nb_validated} / {nb_to_create} (finishes {e})" ) validated_quizzes = torch.cat(recorded, dim=0) diff --git a/quiz_machine.py b/quiz_machine.py index 008e435..0b84b36 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -541,7 +541,13 @@ class QuizMachine: ############################################################### - def generate_c_quizzes(self, nb, model_for_generation, forward_only=False): + def generate_c_quizzes( + self, + nb, + model_for_generation, + forward_only=False, + generation_temperature=1.0 + ): c_quizzes = torch.empty( nb, self.prompt_len + self.answer_len + 2, @@ -561,7 +567,7 @@ class QuizMachine: input=c_quizzes, ar_mask=self.make_ar_mask(c_quizzes, first=True), seq_logproba=seq_logproba, - temperature=1.0, + temperature=generation_temperature, deterministic_synthesis=False, device=self.device, ) @@ -572,7 +578,7 @@ class QuizMachine: input=c_quizzes, ar_mask=self.make_ar_mask(c_quizzes), seq_logproba=seq_logproba, - temperature=1, + temperature=1.0 deterministic_synthesis=False, device=self.device, ) @@ -587,7 +593,7 @@ class QuizMachine: input=c_quizzes, ar_mask=self.make_ar_mask(c_quizzes, first=True), seq_logproba=seq_logproba, - temperature=1.0, + temperature=generation_temperature, deterministic_synthesis=False, device=self.device, ) -- 2.39.5