From 57207dc8d1e1927d3902b433f3a9731d0e6570b0 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 1 Aug 2024 19:36:19 +0200 Subject: [PATCH] Update. --- main.py | 24 ++++++++++++++++-------- quiz_machine.py | 13 ++++++++++++- 2 files changed, 28 insertions(+), 9 deletions(-) diff --git a/main.py b/main.py index 526da6f..94c030a 100755 --- a/main.py +++ b/main.py @@ -581,9 +581,9 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 # This is nb_quizzes x nb_models seq_logproba = quiz_machine.models_logprobas( - models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1) + models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0) ) + quiz_machine.models_logprobas( - models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1) + models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0) ) probas = seq_logproba.exp() @@ -648,9 +648,9 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 if vq.size(0) > 0: seq_logproba = quiz_machine.models_logprobas( - models, vq, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1) + models, vq, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0) ) + quiz_machine.models_logprobas( - models, vq, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1) + models, vq, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0) ) comments = [] @@ -753,9 +753,17 @@ def batches_for_generator(generator, quiz_machine, models, fraction_w_quizzes=1. if c_quizzes.size(0) > 0: seq_logproba = quiz_machine.models_logprobas( - models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1) + models, + c_quizzes, + ("A", "f_A", "B", "f_B"), + (0, 0, 0, 1), + (0, 0, 1, 0), ) + quiz_machine.models_logprobas( - models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1) + models, + c_quizzes, + ("f_A", "A", "f_B", "B"), + (0, 0, 0, 1), + (0, 0, 1, 0), ) probas = seq_logproba.exp() @@ -1075,9 +1083,9 @@ if args.test_generator: ) seq_logproba = quiz_machine.models_logprobas( - models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1) + models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0) ) + quiz_machine.models_logprobas( - models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1) + models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0) ) probas = seq_logproba.exp() diff --git a/quiz_machine.py b/quiz_machine.py index b7c3b09..0fdfbf6 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -325,13 +325,24 @@ class QuizMachine: ###################################################################### def models_logprobas( - self, models_for_validation, c_quizzes, struct, mask, device=None + self, + models_for_validation, + c_quizzes, + struct, + mask, + noise_mask=None, + device=None, ): if device is None: device = self.device c_quizzes = self.problem.reconfigure(c_quizzes, struct) + if self.prompt_noise > 0.0 and noise_mask is not None: + c_quizzes = self.problem.inject_noise( + c_quizzes, self.prompt_noise, struct=struct, mask=noise_mask + ) + seq_logproba = torch.zeros( c_quizzes.size(0), max([m.id for m in models_for_validation]) + 1, -- 2.39.5