From aeb45199c2fa89f505dac537924f564e1dc9c215 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 1 Aug 2024 23:13:22 +0200 Subject: [PATCH] Update. --- main.py | 3 +++ quiz_machine.py | 60 +++++++++++++++++++++++++++---------------------- 2 files changed, 36 insertions(+), 27 deletions(-) diff --git a/main.py b/main.py index 94c030a..69452e6 100755 --- a/main.py +++ b/main.py @@ -104,6 +104,8 @@ parser.add_argument("--temperature_cold", type=float, default=1) parser.add_argument("--prompt_noise", type=float, default=0.0) +parser.add_argument("--nb_averaging_rounds", type=int, default=1) + parser.add_argument("--dirty_debug", action="store_true", default=False) parser.add_argument("--test_generator", action="store_true", default=False) @@ -343,6 +345,7 @@ quiz_machine = quiz_machine.QuizMachine( batch_size=args.inference_batch_size, result_dir=args.result_dir, prompt_noise=args.prompt_noise, + nb_averaging_rounds=args.nb_averaging_rounds, logger=log_string, device=main_device, ) diff --git a/quiz_machine.py b/quiz_machine.py index 0fdfbf6..cfab73a 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -68,6 +68,7 @@ class QuizMachine: batch_size, result_dir, prompt_noise, + nb_averaging_rounds, logger, device=torch.device("cpu"), ): @@ -79,7 +80,11 @@ class QuizMachine: self.logger = logger self.prompt_len = None self.answer_len = None + + assert prompt_noise > 0 or nb_averaging_rounds == 1 + self.prompt_noise = prompt_noise + self.nb_averaging_rounds = nb_averaging_rounds self.understood_structures = [ (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)), @@ -338,39 +343,40 @@ class QuizMachine: c_quizzes = self.problem.reconfigure(c_quizzes, struct) - if self.prompt_noise > 0.0 and noise_mask is not None: - c_quizzes = self.problem.inject_noise( - c_quizzes, self.prompt_noise, struct=struct, mask=noise_mask - ) - seq_logproba = torch.zeros( c_quizzes.size(0), max([m.id for m in models_for_validation]) + 1, device=device, ) - for model in models_for_validation: - with torch.autograd.no_grad(): - t = model.training - model.eval() - - for input, l in zip( - c_quizzes.split(self.batch_size), - seq_logproba.split(self.batch_size), - ): - input = input.to(device) - ar_mask = self.make_ar_mask(input, struct=struct, mask=mask) - output = model(mygpt.BracketedSequence(input)).x - l[:, model.id] = ( - -F.cross_entropy( - output.transpose(1, 2), input, reduction="none" - ) - * ar_mask - ).sum(dim=1) - - model.train(t) - - return seq_logproba.to("cpu") + for a in range(self.nb_averaging_rounds): + if self.prompt_noise > 0.0 and noise_mask is not None: + c_quizzes = self.problem.inject_noise( + c_quizzes, self.prompt_noise, struct=struct, mask=noise_mask + ) + + for model in models_for_validation: + with torch.autograd.no_grad(): + t = model.training + model.eval() + + for input, l in zip( + c_quizzes.split(self.batch_size), + seq_logproba.split(self.batch_size), + ): + input = input.to(device) + ar_mask = self.make_ar_mask(input, struct=struct, mask=mask) + output = model(mygpt.BracketedSequence(input)).x + l[:, model.id] += ( + -F.cross_entropy( + output.transpose(1, 2), input, reduction="none" + ) + * ar_mask + ).sum(dim=1) + + model.train(t) + + return seq_logproba.div(self.nb_averaging_rounds).to("cpu") ###################################################################### -- 2.39.5