From c7d06243f8b4da35991b1c4f5a7a135a95df2958 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 29 Jul 2024 10:16:09 +0200 Subject: [PATCH] Update. --- main.py | 6 ++++-- mygpt.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/main.py b/main.py index e553278..0d4cd10 100755 --- a/main.py +++ b/main.py @@ -453,11 +453,13 @@ def one_epoch(model, quiz_machine, local_device=main_device): def model_transformer_hot(model): - model.temperature = args.temperature_hot + # model.temperature = args.temperature_hot + model.set_noise_injection(1.0, ("ffw", 2)) def model_transformer_cold(model): - model.temperature = args.temperature_cold + pass + # model.temperature = args.temperature_cold c_quizzes_procedure = [ diff --git a/mygpt.py b/mygpt.py index c073113..7c51bae 100755 --- a/mygpt.py +++ b/mygpt.py @@ -371,7 +371,7 @@ class MyGPT(nn.Module): m.noise_std = 0.0 def set_noise_injection(self, noise_std, identifier=None): - for m in model.modules(): + for m in self.modules(): if isinstance(m, NoiseInjector): if identifier is None or identifier == m.identifier: m.noise_std = noise_std -- 2.39.5