From: François Fleuret Date: Sat, 3 Aug 2024 04:58:21 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=b080ee3874c6aa76b89268e12fcfdd87ead8bb92;p=culture.git Update. --- diff --git a/main.py b/main.py index 059a29d..63597b4 100755 --- a/main.py +++ b/main.py @@ -470,6 +470,8 @@ c_quizzes_procedure = [ (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_transformer_hot), (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), model_transformer_cold), (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_transformer_cold), + # (("B", "f_B", "A", "f_A"), (0, 0, 1, 1), model_transformer_cold), + # (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_transformer_cold), ] ###################################################################### @@ -477,15 +479,30 @@ c_quizzes_procedure = [ def save_additional_results(models, science_w_quizzes): for model in models: + recorder = [] + c_quizzes = quiz_machine.generate_c_quizzes( - 128, model_for_generation=model, procedure=c_quizzes_procedure + 32, + model_for_generation=model, + procedure=c_quizzes_procedure, + recorder=recorder, ) + c_quizzes = torch.cat([c[:, None, :] for c, _, in recorder], dim=1) + predicted_parts = torch.cat([t[:, None, :] for _, t in recorder], dim=1) + nrow = c_quizzes.size(1) + c_quizzes = c_quizzes.reshape(-1, c_quizzes.size(-1)) + predicted_parts = predicted_parts.reshape(-1, predicted_parts.size(-1)) + + filename = f"non_validated_{n_epoch:04d}_{model.id:02d}.png" quiz_machine.problem.save_quizzes_as_image( args.result_dir, - f"non_validated_{n_epoch:04d}_{model.id:02d}.png", - c_quizzes, + filename, + quizzes=c_quizzes, + predicted_parts=predicted_parts, + nrow=nrow, ) + log_string(f"wrote {filename}") ###################################################################### diff --git a/quiz_machine.py b/quiz_machine.py index 015f6d2..3fc1066 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -374,7 +374,9 @@ class QuizMachine: ###################################################################### - def generate_c_quizzes(self, nb, model_for_generation, procedure, to_recycle=None): + def generate_c_quizzes( + self, nb, model_for_generation, procedure, to_recycle=None, recorder=None + ): seq_logproba = torch.zeros(nb, device=self.device) c_quizzes = None @@ -399,6 +401,13 @@ class QuizMachine: model_for_generation.reset_transformations() + if recorder is not None: + x = c_quizzes.clone() + t = torch.tensor(m, device=x.device)[None, :].expand(x.size(0), -1) + recorder.append( + self.problem.reconfigure([x, t], ("A", "f_A", "B", "f_B")) + ) + if to_recycle is not None and to_recycle.size(0) > 0: to_recycle = self.problem.reconfigure(to_recycle, s) c_quizzes[: to_recycle.size(0)] = to_recycle