From c9e638c87ad1f2a8b8d5c666a7588ee49c2c995e Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Fri, 19 Jul 2024 08:10:58 +0200 Subject: [PATCH] Update. --- grids.py | 43 ++++++++---- main.py | 22 +++--- quiz_machine.py | 180 ++++++++++++++++++++++-------------------------- 3 files changed, 122 insertions(+), 123 deletions(-) diff --git a/grids.py b/grids.py index 4f07d70..c2ff0d1 100755 --- a/grids.py +++ b/grids.py @@ -126,6 +126,8 @@ class Grids(problem.Problem): tasks=None, ): self.colors = torch.tensor([c for _, c in self.named_colors]) + self.token_forward = len(self.colors) + self.token_backward = self.token_forward + 1 self.height = 10 self.width = 10 self.cache_rec_coo = {} @@ -157,7 +159,7 @@ class Grids(problem.Problem): def frame2img(self, x, scale=15): x = x.reshape(x.size(0), self.height, -1) - m = torch.logical_and(x >= 0, x < self.nb_token_values()).long() + m = torch.logical_and(x >= 0, x < len(self.colors)).long() x = self.colors[x * m].permute(0, 3, 1, 2) s = x.shape x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale) @@ -192,13 +194,19 @@ class Grids(problem.Problem): margin=8, ): S = self.height * self.width - As = prompts[:, 0 * (S + 1) : 0 * (S + 1) + S].view(-1, self.height, self.width) - f_As = prompts[:, 1 * (S + 1) : 1 * (S + 1) + S].view( + As = prompts[:, 0 * (S + 1) + 1 : 0 * (S + 1) + S + 1].view( + -1, self.height, self.width + ) + f_As = prompts[:, 1 * (S + 1) + 1 : 1 * (S + 1) + S + 1].view( + -1, self.height, self.width + ) + Bs = prompts[:, 2 * (S + 1) + 1 : 2 * (S + 1) + S + 1].view( -1, self.height, self.width ) - Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S].view(-1, self.height, self.width) prompts = torch.cat([As, f_As, Bs], dim=2) - answers = answers.reshape(answers.size(0), self.height, self.width) + answers = answers[:, 1 : S + 1].reshape( + answers.size(0), self.height, self.width + ) if predicted_prompts is None: predicted_prompts = 255 @@ -307,7 +315,7 @@ class Grids(problem.Problem): ###################################################################### def nb_token_values(self): - return len(self.colors) + return len(self.colors) + 2 # @torch.compile def rec_coo( @@ -1180,8 +1188,9 @@ class Grids(problem.Problem): def trivial_prompts_and_answers(self, prompts, answers): S = self.height * self.width - Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S] - f_Bs = answers + Bs = prompts[:, 2 * (S + 1) + 1 : 2 * (S + 1) + S + 1] + f_Bs = answers[:, 1:] + print(f"{prompts.size()=} {answers.size()=} {Bs.size()=} {f_Bs.size()=}") return (Bs == f_Bs).long().min(dim=-1).values > 0 def generate_prompts_and_answers_(self, nb, tasks=None, progress_bar=False): @@ -1189,8 +1198,8 @@ class Grids(problem.Problem): tasks = self.all_tasks S = self.height * self.width - prompts = torch.zeros(nb, 3 * S + 2, dtype=torch.int64) - answers = torch.zeros(nb, S, dtype=torch.int64) + prompts = torch.full((nb, 3 * S + 3), self.token_forward) + answers = torch.full((nb, S + 1), self.token_forward) bunch = zip(prompts, answers) @@ -1203,10 +1212,16 @@ class Grids(problem.Problem): ) for prompt, answer in bunch: - A = prompt[0 * (S + 1) : 0 * (S + 1) + S].view(self.height, self.width) - f_A = prompt[1 * (S + 1) : 1 * (S + 1) + S].view(self.height, self.width) - B = prompt[2 * (S + 1) : 2 * (S + 1) + S].view(self.height, self.width) - f_B = answer.view(self.height, self.width) + A = prompt[0 * (S + 1) + 1 : 0 * (S + 1) + 1 + S].view( + self.height, self.width + ) + f_A = prompt[1 * (S + 1) + 1 : 1 * (S + 1) + 1 + S].view( + self.height, self.width + ) + B = prompt[2 * (S + 1) + 1 : 2 * (S + 1) + S + 1].view( + self.height, self.width + ) + f_B = answer[1 : S + 1].view(self.height, self.width) task = tasks[torch.randint(len(tasks), (1,)).item()] task(A, f_A, B, f_B) diff --git a/main.py b/main.py index 0d0d373..ab87b56 100755 --- a/main.py +++ b/main.py @@ -90,11 +90,13 @@ parser.add_argument("--proba_understands", type=float, default=0.9) parser.add_argument("--proba_not_understands", type=float, default=0.5) -parser.add_argument("--generation_temperature", type=float, default=1.5) +parser.add_argument("--temperature_hot", type=float, default=1.5) + +parser.add_argument("--temperature_cold", type=float, default=0.75) parser.add_argument("--c_quiz_validation_mode", type=str, default="predict") -parser.add_argument("--forward_only", action="store_true", default=False) +parser.add_argument("--p2a_only", action="store_true", default=False) parser.add_argument("--dirty_debug", action="store_true", default=False) @@ -374,8 +376,8 @@ def one_epoch(model, quiz_machine, local_device=main_device): acc_train_loss += loss.item() * input.size(0) loss_per_samples = loss_per_token.detach().flatten(1).mean(dim=1) - n_forward = input[:, 0] == quiz_machine.token_forward - to_store = from_w & n_forward.to("cpu") + n_p2a = input[:, 0] == quiz_machine.token_p2a + to_store = from_w & n_p2a.to("cpu") if to_store.any(): hard_w_quizzes.append( (input[to_store].to("cpu"), loss_per_samples[to_store].to("cpu")) @@ -454,13 +456,13 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 # We balance the number of quizzes per model model_for_generation = sorted(models, key=lambda m: nb_validated[m.id])[0] - print(nb_validated, "using", model_for_generation.id) c_quizzes = quiz_machine.generate_c_quizzes( nb_to_generate_per_iteration, model_for_generation=model_for_generation, - forward_only=args.forward_only, - generation_temperature=args.generation_temperature, + p2a_only=args.p2a_only, + temperature_hot=args.temperature_hot, + temperature_cold=args.temperature_cold, ) c_quizzes = keep_good_quizzes(models, c_quizzes) @@ -536,14 +538,14 @@ for k in range(args.nb_gpts): model=model, nb=args.nb_train_samples, for_train=True, - forward_only=args.forward_only, + p2a_only=args.p2a_only, ) quiz_machine.create_w_quizzes( model=model, nb=args.nb_test_samples, for_train=False, - forward_only=args.forward_only, + p2a_only=args.p2a_only, ) models.append(model) @@ -673,7 +675,7 @@ for n_epoch in range(args.nb_epochs): quiz_machine.renew_w_quizzes( model=model, for_train=True, - forward_only=args.forward_only, + p2a_only=args.p2a_only, ) if args.log_command is not None: diff --git a/quiz_machine.py b/quiz_machine.py index 032305a..51c3f08 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -138,83 +138,72 @@ def masked_inplace_autoregression( class QuizMachine: - def indices_forward_and_backward(self, quizzes): - i_forward = quizzes[:, 0] == self.token_forward - j_forward = quizzes[:, 1 + self.prompt_len] == self.token_forward - i_backward = quizzes[:, 0] == self.token_backward - j_backward = quizzes[:, 1 + self.answer_len] == self.token_backward + def indices_p2a_and_a2p(self, quizzes): + i_p2a = quizzes[:, 0] == self.problem.token_forward + j_p2a = quizzes[:, self.prompt_len] == self.problem.token_forward + i_a2p = quizzes[:, 0] == self.problem.token_backward + j_a2p = quizzes[:, self.answer_len] == self.problem.token_backward assert torch.logical_or( - torch.logical_and(i_forward, j_forward), - torch.logical_and(i_backward, j_backward), + torch.logical_and(i_p2a, j_p2a), + torch.logical_and(i_a2p, j_a2p), ).all() - return i_forward, i_backward + return i_p2a, i_a2p def non_trivial(self, quizzes): quizzes = quizzes.clone() - n_forward = quizzes[quizzes[:, 0] == self.token_forward] - n_backward = quizzes[:, 0] == self.token_backward - backward = quizzes[n_backward] - quizzes[n_backward] = self.reverse_time(quizzes[n_backward]) + n_p2a = quizzes[quizzes[:, 0] == self.problem.token_forward] + n_a2p = quizzes[:, 0] == self.problem.token_backward + a2p = quizzes[n_a2p] + quizzes[n_a2p] = self.p_a_flip(quizzes[n_a2p]) return torch.logical_not( self.problem.trivial_prompts_and_answers( - quizzes[:, 1 : 1 + self.prompt_len], - quizzes[:, 2 + self.prompt_len :], + quizzes[:, : self.prompt_len], quizzes[:, self.prompt_len :] ) ) - def reverse_time(self, quizzes): - i_forward, i_backward = self.indices_forward_and_backward(quizzes) + def p_a_flip(self, quizzes): + i_p2a, i_a2p = self.indices_p2a_and_a2p(quizzes) - forward_to_backward = torch.cat( - [ - quizzes[:, 0:1], - quizzes[:, 2 + self.prompt_len : 2 + self.prompt_len + self.answer_len], - quizzes[:, 1 + self.prompt_len : 1 + self.prompt_len + 1], - quizzes[:, 1 : 1 + self.prompt_len], - ], + p2a_to_a2p = torch.cat( + [quizzes[:, self.prompt_len :], quizzes[:, : self.prompt_len]], dim=1, ) - forward_to_backward[:, 0] = self.token_backward - forward_to_backward[:, 1 + self.answer_len] = self.token_backward + p2a_to_a2p[:, 0] = self.problem.token_backward + p2a_to_a2p[:, self.answer_len] = self.problem.token_backward - backward_to_forward = torch.cat( - [ - quizzes[:, 0:1], - quizzes[:, 2 + self.answer_len :], - quizzes[:, 1 + self.answer_len : 2 + self.answer_len], - quizzes[:, 1 : 1 + self.answer_len], - ], + a2p_to_p2a = torch.cat( + [quizzes[:, self.answer_len :], quizzes[:, : self.answer_len]], dim=1, ) - backward_to_forward[:, 0] = self.token_forward - backward_to_forward[:, 1 + self.prompt_len] = self.token_forward + a2p_to_p2a[:, 0] = self.problem.token_forward + a2p_to_p2a[:, self.prompt_len] = self.problem.token_forward - m = i_forward.long()[:, None] + m = i_p2a.long()[:, None] - return m * forward_to_backward + (1 - m) * backward_to_forward + return m * p2a_to_a2p + (1 - m) * a2p_to_p2a - def reverse_random_half_in_place(self, quizzes): + def p_a_flip_half_in_place(self, quizzes): i = torch.rand(quizzes.size(0)) < 0.5 if i.any(): - quizzes[i] = self.reverse_time(quizzes[i]) + quizzes[i] = self.p_a_flip(quizzes[i]) def make_ar_mask(self, quizzes, first=False): - i_forward, i_backward = self.indices_forward_and_backward(quizzes) + i_p2a, i_a2p = self.indices_p2a_and_a2p(quizzes) t = torch.arange(quizzes.size(1), device=quizzes.device) if first: - m_forward = (t >= 1).long() * (t < 1 + self.prompt_len).long() - m_backward = (t >= 1).long() * (t < 1 + self.answer_len).long() + m_p2a = (t >= 1).long() * (t < self.prompt_len).long() + m_a2p = (t >= 1).long() * (t < self.answer_len).long() else: - m_forward = (t >= 2 + self.prompt_len).long() - m_backward = (t >= 2 + self.answer_len).long() + m_p2a = (t >= 1 + self.prompt_len).long() + m_a2p = (t >= 1 + self.answer_len).long() - m = i_forward.long()[:, None] + m = i_p2a.long()[:, None] - return m * m_forward + (1 - m) * m_backward + return m * m_p2a + (1 - m) * m_a2p def generate_token_sequences(self, nb): prompts, answers = self.problem.generate_prompts_and_answers(nb) @@ -230,14 +219,7 @@ class QuizMachine: result = [] for prompt, answer in zip(prompts, answers): - a = [ - torch.tensor([self.token_forward]), - prompt, - torch.tensor([self.token_forward]), - answer, - ] - - result.append(torch.cat(a, dim=0)[None, :]) + result.append(torch.cat([prompt, answer], dim=0)[None, :]) return torch.cat(result, dim=0) @@ -252,10 +234,7 @@ class QuizMachine: ): super().__init__() - v = problem.nb_token_values() - self.token_forward = v - self.token_backward = v + 1 - self.nb_token_values = v + 2 + self.nb_token_values = problem.nb_token_values() self.problem = problem self.back_accuracy = back_accuracy @@ -278,14 +257,14 @@ class QuizMachine: show_part_to_predict=True, ): quizzes = quizzes.clone().to("cpu") - n_forward = quizzes[quizzes[:, 0] == self.token_forward] - n_backward = quizzes[:, 0] == self.token_backward - backward = quizzes[n_backward] - assert n_forward.size(0) + backward.size(0) == quizzes.size(0) - quizzes[n_backward] = self.reverse_time(quizzes[n_backward]) + n_p2a = quizzes[quizzes[:, 0] == self.problem.token_forward] + n_a2p = quizzes[:, 0] == self.problem.token_backward + a2p = quizzes[n_a2p] + assert n_p2a.size(0) + a2p.size(0) == quizzes.size(0) + quizzes[n_a2p] = self.p_a_flip(quizzes[n_a2p]) if show_part_to_predict: - predicted_prompts = n_backward.long() + predicted_prompts = n_a2p.long() predicted_answers = 1 - predicted_prompts if mistakes is not None: # 0/-1/+1 ~ not-to-predict / predicted wrong / predicted correct @@ -371,29 +350,27 @@ class QuizMachine: correct = torch.empty(input.size(0), dtype=torch.int64, device=input.device) - n_forward = input[:, 0] == self.token_forward - n_backward = input[:, 0] == self.token_backward + n_p2a = input[:, 0] == self.problem.token_forward + n_a2p = input[:, 0] == self.problem.token_backward - correct[n_forward] = ( - (input[n_forward] == result[n_forward]).long().min(dim=1).values - ) + correct[n_p2a] = (input[n_p2a] == result[n_p2a]).long().min(dim=1).values - if self.back_accuracy and n_backward.any(): + if self.back_accuracy and n_a2p.any(): # accuracy of B->A*->B*=B instead of B->A*=A - back_input = self.reverse_time(result[n_backward]) + back_input = self.p_a_flip(result[n_a2p]) back_input[:, 2 + self.prompt_len :] = input[ - n_backward, 1 : 1 + self.answer_len + n_a2p, 1 : 1 + self.answer_len ] - _, correct[n_backward] = compute_accuracy(back_input) + _, correct[n_a2p] = compute_accuracy(back_input) if log_prefix is not None: - forward_nb_correct = correct[n_forward].sum() - forward_nb_total = correct[n_forward].size(0) - backward_nb_correct = correct[n_backward].sum() - backward_nb_total = correct[n_backward].size(0) + p2a_nb_correct = correct[n_p2a].sum() + p2a_nb_total = correct[n_p2a].size(0) + a2p_nb_correct = correct[n_a2p].sum() + a2p_nb_total = correct[n_a2p].size(0) self.logger( - f"{log_prefix}_accuracy {n_epoch} model {model.id} forward {forward_nb_correct} / {forward_nb_total} backward {backward_nb_correct} / {backward_nb_total}" + f"{log_prefix}_accuracy {n_epoch} model {model.id} p2a {p2a_nb_correct} / {p2a_nb_total} a2p {a2p_nb_correct} / {a2p_nb_total}" ) return result, correct @@ -402,11 +379,11 @@ class QuizMachine: model.test_w_quizzes[:2000], log_prefix="test" ) - n_test_forward = model.test_w_quizzes[:2000, 0] == self.token_forward + n_test_p2a = model.test_w_quizzes[:2000, 0] == self.problem.token_forward - forward_test_correct = test_correct[n_test_forward] + p2a_test_correct = test_correct[n_test_p2a] - main_test_accuracy = forward_test_correct.sum() / forward_test_correct.size(0) + main_test_accuracy = p2a_test_correct.sum() / p2a_test_correct.size(0) ############################## @@ -421,11 +398,11 @@ class QuizMachine: ###################################################################### - def create_w_quizzes(self, model, nb, for_train=True, forward_only=False): + def create_w_quizzes(self, model, nb, for_train=True, p2a_only=False): input = self.generate_token_sequences(nb) - if not forward_only: - self.reverse_random_half_in_place(input) + if not p2a_only: + self.p_a_flip_half_in_place(input) if for_train: model.train_w_quizzes = input @@ -434,7 +411,7 @@ class QuizMachine: ###################################################################### - def renew_w_quizzes(self, model, for_train=True, forward_only=False): + def renew_w_quizzes(self, model, for_train=True, p2a_only=False): input = model.train_w_quizzes if for_train else model.test_w_quizzes if for_train and hasattr(model, "hard_w_quizzes"): @@ -458,8 +435,8 @@ class QuizMachine: else: input[...] = self.generate_token_sequences(input.size(0)) - if not forward_only: - self.reverse_random_half_in_place(input) + if not p2a_only: + self.p_a_flip_half_in_place(input) ###################################################################### @@ -553,20 +530,25 @@ class QuizMachine: ############################################################### def generate_c_quizzes( - self, nb, model_for_generation, forward_only=False, generation_temperature=1.0 + self, + nb, + model_for_generation, + p2a_only=False, + temperature_hot=1.0, + temperature_cold=1.0, ): c_quizzes = torch.empty( nb, - self.prompt_len + self.answer_len + 2, + self.prompt_len + self.answer_len, device=self.device, dtype=torch.int64, ) seq_logproba = torch.zeros(nb, device=self.device) - if forward_only: - c_quizzes[:, 0] = self.token_forward - c_quizzes[:, 1 + self.prompt_len] = self.token_forward + if p2a_only: + c_quizzes[:, 0] = self.problem.token_forward + c_quizzes[:, self.prompt_len] = self.problem.token_forward masked_inplace_autoregression( model=model_for_generation, @@ -574,7 +556,7 @@ class QuizMachine: input=c_quizzes, ar_mask=self.make_ar_mask(c_quizzes, first=True), seq_logproba=seq_logproba, - temperature=generation_temperature, + temperature=temperature_hot, deterministic_synthesis=False, device=self.device, ) @@ -585,14 +567,14 @@ class QuizMachine: input=c_quizzes, ar_mask=self.make_ar_mask(c_quizzes), seq_logproba=seq_logproba, - temperature=1.0, + temperature=temperature_cold, deterministic_synthesis=False, device=self.device, ) else: - c_quizzes[:, 0] = self.token_backward - c_quizzes[:, 1 + self.answer_len] = self.token_backward + c_quizzes[:, 0] = self.problem.token_backward + c_quizzes[:, self.answer_len] = self.problem.token_backward masked_inplace_autoregression( model=model_for_generation, @@ -600,7 +582,7 @@ class QuizMachine: input=c_quizzes, ar_mask=self.make_ar_mask(c_quizzes, first=True), seq_logproba=seq_logproba, - temperature=generation_temperature, + temperature=temperature_hot, deterministic_synthesis=False, device=self.device, ) @@ -611,12 +593,12 @@ class QuizMachine: input=c_quizzes, ar_mask=self.make_ar_mask(c_quizzes), seq_logproba=seq_logproba, - temperature=0.75, + temperature=temperature_cold, deterministic_synthesis=False, device=self.device, ) - c_quizzes = self.reverse_time(c_quizzes) + c_quizzes = self.p_a_flip(c_quizzes) masked_inplace_autoregression( model=model_for_generation, @@ -624,7 +606,7 @@ class QuizMachine: input=c_quizzes, ar_mask=self.make_ar_mask(c_quizzes), seq_logproba=seq_logproba, - temperature=0.75, + temperature=temperature_cold, deterministic_synthesis=False, device=self.device, ) -- 2.39.5