From 5c5668b0e52e2ae579d49ba8a44fafe2339ad8c0 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 1 Jul 2024 10:11:09 +0300 Subject: [PATCH] Update. --- quizz_machine.py | 51 ++++++++++++++++++++++++++++++++---------------- 1 file changed, 34 insertions(+), 17 deletions(-) diff --git a/quizz_machine.py b/quizz_machine.py index 84bb558..7b0b877 100755 --- a/quizz_machine.py +++ b/quizz_machine.py @@ -312,9 +312,7 @@ class QuizzMachine: else: self.test_c_quizzes.append(new_c_quizzes) - def comput_correctness(self, c_quizzes, models_for_validation): - # Create the reverse quizzes - + def reverse_time(self, c_quizzes): token_forward, token_backward = self.problem.direction_tokens() l = (c_quizzes.size(1) - 1) // 2 @@ -322,9 +320,11 @@ class QuizzMachine: direction = self.problem.token_forward * ( direction == self.problem.token_backward ) + self.problem.token_backward * (direction == self.problem.token_forward) - reverse_c_quizzes = torch.cat( - [c_quizzes[:, l + 1 :], direction, c_quizzes[:, :l]], dim=1 - ) + + return torch.cat([c_quizzes[:, l + 1 :], direction, c_quizzes[:, :l]], dim=1) + + def comput_correctness(self, c_quizzes, models_for_validation): + reversed_c_quizzes = self.reverse_time(c_quizzes) ar_mask = self.make_ar_mask(c_quizzes) seq_logproba = torch.empty(ar_mask.size(0), device=self.device) @@ -350,12 +350,12 @@ class QuizzMachine: correct = (c_quizzes == result).long().min(dim=-1).values - reverse_result = reverse_c_quizzes.clone() + reversed_result = reversed_c_quizzes.clone() masked_inplace_autoregression( model=model, batch_size=self.batch_size, - input=reverse_result, + input=reversed_result, ar_mask=ar_mask, seq_logproba=seq_logproba, temperature=1.0, @@ -364,17 +364,19 @@ class QuizzMachine: device=self.device, ) - reverse_correct = ( - (reverse_c_quizzes == reverse_result).long().min(dim=-1).values + reversed_correct = ( + (reversed_c_quizzes == reversed_result).long().min(dim=-1).values ) - nb_correct += correct * reverse_correct + nb_correct += correct * reversed_correct return nb_correct ############################################################### - def generate_quizzes(self, nb, model_for_generation, min_ave_seq_logproba): + def generate_quizzes( + self, nb, model_for_generation, min_ave_seq_logproba, reverse_cleanup=False + ): c_quizzes = torch.empty( nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64 ) @@ -384,11 +386,12 @@ class QuizzMachine: ar_mask_solve = 1 - ar_mask_prompt seq_logproba = torch.empty(ar_mask_prompt.size(0), device=self.device) - warnings.warn("noise injection", RuntimeWarning) + # warnings.warn("noise injection", RuntimeWarning) temperature = 1 - noise_std = torch.rand(1).item() - self.logger(f"{noise_std=}") - mygpt.set_noise_injection(model_for_generation, noise_std) + # noise_std = torch.rand(1).item() + # self.logger(f"{noise_std=}") + + # mygpt.set_noise_injection(model_for_generation, noise_std) masked_inplace_autoregression( model=model_for_generation, @@ -402,6 +405,8 @@ class QuizzMachine: device=self.device, ) + # mygpt.set_noise_injection(model_for_generation, 0.0) + ave_seq_logproba = seq_logproba.mean() masked_inplace_autoregression( @@ -416,7 +421,19 @@ class QuizzMachine: device=self.device, ) - mygpt.set_noise_injection(model_for_generation, 0.0) + if reverse_cleanup: + c_quizzes = self.reverse_time(c_quizzes) + masked_inplace_autoregression( + model=model_for_generation, + batch_size=self.batch_size, + input=c_quizzes, + ar_mask=ar_mask_solve, + seq_logproba=seq_logproba, + temperature=temperature, + deterministic_synthesis=True, + # progress_bar_desc="sampling c_quizzes", + device=self.device, + ) return c_quizzes, seq_logproba.mean() -- 2.39.5