From 03fcaec34516db7ac941059d6c48737d378e9ff5 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 22 Jul 2024 09:02:05 +0200 Subject: [PATCH] Update. --- quiz_machine.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/quiz_machine.py b/quiz_machine.py index c73b6d0..a5f9a89 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -529,8 +529,13 @@ class QuizMachine: seq_logproba = torch.zeros(nb, device=self.device) - def heater(T): - return lambda s, logits: logits / T + lt_noisy = lambda s, logits: logits / temperature_hot + lt_clean = lambda s, logits: logits / temperature_cold + + # lt_noisy = lambda s, logits: logits / ( + # 1 + 4 * (torch.rand(logits.size(), device=logits.device) < 1e-2).long() + # ) + # lt_clean = None if p2a_only: c_quizzes[...] = self.problem.token_forward @@ -541,7 +546,7 @@ class QuizMachine: input=c_quizzes, ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_012_bck_0"), seq_logproba=seq_logproba, - logit_transformer=heater(temperature_hot), + logit_transformer=lt_noisy, deterministic_synthesis=False, device=self.device, ) @@ -552,7 +557,7 @@ class QuizMachine: input=c_quizzes, ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"), seq_logproba=seq_logproba, - logit_transformer=heater(temperature_cold), + logit_transformer=lt_clean, deterministic_synthesis=False, device=self.device, ) @@ -566,7 +571,7 @@ class QuizMachine: input=c_quizzes, ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_012_bck_0"), seq_logproba=seq_logproba, - logit_transformer=heater(temperature_hot), + logit_transformer=lt_noisy, deterministic_synthesis=False, device=self.device, ) @@ -577,7 +582,7 @@ class QuizMachine: input=c_quizzes, ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"), seq_logproba=seq_logproba, - logit_transformer=heater(temperature_cold), + logit_transformer=lt_clean, deterministic_synthesis=False, device=self.device, ) @@ -590,7 +595,7 @@ class QuizMachine: input=c_quizzes, ar_mask=self.problem.make_ar_mask(c_quizzes, shape="fwd_3_bck_123"), seq_logproba=seq_logproba, - logit_transformer=heater(temperature_cold), + logit_transformer=lt_clean, deterministic_synthesis=False, device=self.device, ) -- 2.39.5