From 51bab563ee70462a71aac0b5326d2615dc531dec Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Mon, 29 Jul 2024 09:49:25 +0200
Subject: [PATCH] Update.

---
 quiz_machine.py | 25 +++++++++++++++----------
 1 file changed, 15 insertions(+), 10 deletions(-)

diff --git a/quiz_machine.py b/quiz_machine.py
index 34f6b62..ca71c95 100755
--- a/quiz_machine.py
+++ b/quiz_machine.py
@@ -444,40 +444,45 @@ class QuizMachine:
 
     ###############################################################
 
-    def optimize_quizzes(self, quiz, nb_variants, nb_iterations, struct, mask):
+    def optimize_quizzes(
+        self, models, quiz, nb_variants, nb_iterations, struct, mask, proba_understands
+    ):
         for _ in range(nb_iterations):
-            candidates = quizzes[None].expand(nb_variants, -1)
+            candidates = quiz[None, :].expand(nb_variants, -1).clone()
             r = torch.rand(candidates.size(), device=candidates.device)
             u = r.reshape(r.size(0), 4, candidates.size(1) // 4)
 
             # Only change the part indicated by the mask and do not
             # touch the special tokens
             u[:, :, 0] = 0
-            u = u * torch.tensor(mask, device=u.device)[None, :, None]
-            random_mask = (r.sort(dim=0, descending=True).indices == 0).long()
+            u = u * (1 - torch.tensor(mask, device=u.device)[None, :, None])
+            random_mask = F.one_hot(r.argmax(dim=1), num_classes=r.size(1))
 
             # Keep the first unchanged
-            random_mask[:, 0, :] = 0
+            random_mask[0, :] = 0
             # Reshape without the 4 parts
             candidates.reshape(-1, candidates.size(-1))
             random_mask.reshape(candidates.size())
             random_tokens = torch.randint(
-                self.problem.nb_token_values - 4, random_mask.size()
+                self.problem.nb_token_values - 4,
+                random_mask.size(),
+                device=candidates.device,
             )
             # Apply the noise
             candidates = (1 - random_mask) * candidates + random_mask * random_tokens
-            seq_logproba = quiz_machine.models_logprobas(
+            seq_logproba = self.models_logprobas(
                 models, candidates, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1)
-            ) + quiz_machine.models_logprobas(
+            ) + self.models_logprobas(
                 models, candidates, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1)
             )
             sorted_logprobas = seq_logproba.sort(dim=1).values.exp()
             lowest, second_lowest = sorted_logprobas[:, 0], sorted_logprobas[:, 1]
             score = second_lowest - lowest
-            score = score * (second_lowest > args.proba_understands)
+            # score = score * (second_lowest > proba_understands)
             quiz = candidates[score.argmax()]
+            print(score.max())
 
-        return quiz
+        return quiz.to("cpu")
 
     def generate_c_quizzes(self, nb, model_for_generation, procedure, to_recycle=None):
         seq_logproba = torch.zeros(nb, device=self.device)
-- 
2.39.5
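
Note for readers of this patch: below is a minimal, self-contained sketch of the hill-climbing step that optimize_quizzes roughly implements: expand the quiz into nb_variants rows, replace one randomly chosen token in each variant with a random token (keeping the first row unchanged), score the variants, and keep the best one. The function name mutate_and_select, the score_fn argument, and the dummy score are illustrative assumptions, not part of the patch; in the patch the score is derived from models_logprobas over the ("A", "f_A", "B", "f_B") and ("f_A", "A", "f_B", "B") orders, and masked parts and special tokens are excluded from mutation.

import torch


def mutate_and_select(quiz, nb_variants, nb_token_values, score_fn):
    # quiz: 1D tensor of token ids; returns the best-scoring one-token mutation
    candidates = quiz[None, :].expand(nb_variants, -1).clone()

    # Choose one position per variant and build a 0/1 mask of the same shape
    positions = torch.randint(quiz.size(0), (nb_variants,))
    random_mask = torch.zeros_like(candidates)
    random_mask[torch.arange(nb_variants), positions] = 1

    # Keep the first variant unchanged so the current quiz stays in the pool
    random_mask[0, :] = 0

    # Overwrite the selected positions with random tokens
    random_tokens = torch.randint(nb_token_values, candidates.size())
    candidates = (1 - random_mask) * candidates + random_mask * random_tokens

    # Keep the highest-scoring variant
    return candidates[score_fn(candidates).argmax()]


# Hypothetical usage with a dummy score (sum of token ids)
quiz = torch.randint(10, (16,))
best = mutate_and_select(quiz, 32, 10, lambda c: c.float().sum(dim=1))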