From 6bc76685ddd1230e7944bdcb436eee0f12f5b968 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 22 Aug 2024 18:23:52 +0200 Subject: [PATCH] Update. --- main.py | 9 ++++++++- quiz_machine.py | 26 +++++++++++++++++--------- 2 files changed, 25 insertions(+), 10 deletions(-) diff --git a/main.py b/main.py index 35ba763..fc480b7 100755 --- a/main.py +++ b/main.py @@ -846,7 +846,13 @@ def test_ae(local_device=main_device): model.train() nb_train_samples, acc_train_loss = 0, 0.0 - full_input, full_mask_loss = quiz_machine.data_input(args.nb_train_samples) + data_structures = [ + (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (0, 0, 0, 1)), + ] + + full_input, full_mask_loss = quiz_machine.data_input( + args.nb_train_samples, data_structures=data_structures + ) src = zip( full_input.split(args.batch_size), full_mask_loss.split(args.batch_size) @@ -866,6 +872,7 @@ def test_ae(local_device=main_device): targets = input input = (mask_loss == 0).long() * input + output = model(mygpt.BracketedSequence(input)).x loss = F.cross_entropy(output.transpose(1, 2), targets) acc_train_loss += loss.item() * input.size(0) diff --git a/quiz_machine.py b/quiz_machine.py index ceb527a..08f121a 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -140,7 +140,12 @@ class QuizMachine: ###################################################################### - def data_input(self, nb_samples, c_quiz_bags=[], c_quiz_multiplier=1): + def data_input( + self, nb_samples, c_quiz_bags=[], c_quiz_multiplier=1, data_structures=None + ): + if data_structures is None: + data_structures = self.train_structures + if len(c_quiz_bags) > 0: c_quizzes = torch.cat(c_quiz_bags, dim=0) @@ -170,21 +175,24 @@ class QuizMachine: quizzes = quizzes[i] self.randomize_configuations_inplace( - quizzes, structs=[s for s, _, _, _ in self.train_structures] + quizzes, structs=[s for s, _, _, _ in data_structures] ) quiz_mask_loss = quizzes.new_full(quizzes.size(), 1) - if self.prompt_noise > 0.0: - for struct, _, quad_noise, quad_loss in self.train_structures: - i = self.problem.indices_select(quizzes=quizzes, struct=struct) - if i.any(): + for struct, _, quad_noise, quad_loss in data_structures: + i = self.problem.indices_select(quizzes=quizzes, struct=struct) + if i.any(): + if self.prompt_noise > 0.0: quizzes[i] = self.problem.inject_noise( quizzes[i], self.prompt_noise, struct=struct, quad=quad_noise ) - quiz_mask_loss[i] = self.make_quiz_mask( - quizzes=quizzes[i], struct=struct, quad=quad_loss - ) + quiz_mask_loss[i] = self.make_quiz_mask( + quizzes=quizzes[i], struct=struct, quad=quad_loss + ) + + print("quad_loss", quad_loss) + print("quiz_mask_loss", quiz_mask_loss) return quizzes, quiz_mask_loss -- 2.39.5