From d9ca928ea2d4b34b8900f38ee23347761f62c7d8 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 12 Aug 2024 15:55:59 +0200 Subject: [PATCH] Update. --- main.py | 25 ++++++++++++++++++++++--- quiz_machine.py | 22 ++++++++++++---------- 2 files changed, 34 insertions(+), 13 deletions(-) diff --git a/main.py b/main.py index f51ab38..fbebbb9 100755 --- a/main.py +++ b/main.py @@ -574,7 +574,7 @@ def save_additional_results(model, models, science_w_quizzes): if science_w_quizzes is not None: struct = ("A", "f_A", "B", "f_B") mask = (0, 0, 0, 1) - result, correct = quiz_machine.predict( + result, correct, _ = quiz_machine.predict( model=model, quizzes=science_w_quizzes.to(main_device), struct=struct, @@ -650,14 +650,33 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 solved_c_quizzes = c_quizzes[:, None, :].expand(-1, len(models), -1).clone() + seq_logproba = torch.zeros( + c_quizzes.size(0), len(models), device=solved_c_quizzes.device + ) + for m in models: - solved_c_quizzes[:, m.id] = quiz_machine.predict( + ( + solved_c_quizzes[:, m.id], + _, + seq_logproba[:, m.id], + ) = quiz_machine.predict( m, solved_c_quizzes[:, m.id], struct=("A", "f_A", "B", "f_B"), mask=(0, 0, 0, 1), ) + #!!!!!!!!!!!!!!!!!!!! + l = quiz_machine.models_logprobas( + models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0) + ) + for s in range(seq_logproba.size(0)): + print(f"-- {s=} ----------------") + for m in range(seq_logproba.size(1)): + print("DEBUG", seq_logproba[s, m].item(), l[s, m].item()) + exit(0) + #!!!!!!!!!!!!!!!!!!!!!!!!! + # FINISH seq_logproba = quiz_machine.models_logprobas( @@ -1314,7 +1333,7 @@ for n_epoch in range(current_epoch, args.nb_epochs): record_new_c_quizzes( models, quiz_machine, - nb_errorsfor_train=args.nb_new_c_quizzes_for_train, + nb_for_train=args.nb_new_c_quizzes_for_train, nb_for_test=args.nb_new_c_quizzes_for_test, ) diff --git a/quiz_machine.py b/quiz_machine.py index 1d89cf4..6aa4e9b 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -28,7 +28,7 @@ def one_batch_masked_inplace_autoregression( model, input, ar_mask, - seq_logproba, + acc_seq_logproba, deterministic_synthesis=False, ): if input.size(0) == 0: @@ -53,7 +53,7 @@ def one_batch_masked_inplace_autoregression( all_n = torch.arange(t_next.size(0)) - seq_logproba += logits[all_n, t_next] + acc_seq_logproba += ar_mask[:, s] * logits[all_n, t_next] input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s] @@ -107,14 +107,11 @@ class QuizMachine: model, input, ar_mask, - seq_logproba=None, + seq_logproba, progress_bar_desc=None, ): assert input.size() == ar_mask.size() - if seq_logproba is None: - seq_logproba = torch.empty(input.size(0), device=self.device) - batches = zip( input.split(self.batch_size), ar_mask.split(self.batch_size), @@ -138,7 +135,7 @@ class QuizMachine: model=model, input=input, ar_mask=ar_mask, - seq_logproba=seq_logproba, + acc_seq_logproba=seq_logproba, deterministic_synthesis=False, ) @@ -190,10 +187,11 @@ class QuizMachine: ###################################################################### def predict(self, model, quizzes, struct, mask): + quizzes = quizzes.to(self.device) ar_mask = self.make_quiz_mask(quizzes=quizzes, struct=struct, mask=mask) result = quizzes * (1 - ar_mask) - seq_logproba = torch.empty(quizzes.size(0), device=self.device) + seq_logproba = torch.zeros(quizzes.size(0), device=self.device) self.autoregression( model=model, @@ -205,7 +203,11 @@ class QuizMachine: correct = (result == quizzes).min(dim=1).values.long() - return result, correct + result = result.to("cpu") + correct = correct.to("cpu") + seq_logproba = seq_logproba.to("cpu") + + return result, correct, seq_logproba ###################################################################### @@ -221,7 +223,7 @@ class QuizMachine: for struct, mask_generate, _, _ in self.test_structures: i = self.problem.indices_select(quizzes=input, struct=struct) nb += i.long().sum() - result[i], correct[i] = self.predict( + result[i], correct[i], _ = self.predict( model=model, quizzes=input[i], struct=struct, mask=mask_generate ) predicted_parts[i] = torch.tensor(mask_generate, device=self.device)[ -- 2.39.5