From 3dca75c7144421022e45cea9288cd87957ff5867 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 30 Jun 2024 13:10:56 +0300 Subject: [PATCH] Update. --- quizz_machine.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/quizz_machine.py b/quizz_machine.py index c5870d0..18d0e0b 100755 --- a/quizz_machine.py +++ b/quizz_machine.py @@ -379,13 +379,14 @@ class QuizzMachine: ) ar_mask_prompt = torch.zeros(c_quizzes.size(), device=self.device) - ar_mask_prompt[:, ar_mask_prompt.size(1) // 2 + 1] = 1 + ar_mask_prompt[:, : ar_mask_prompt.size(1) // 2 + 1] = 1 ar_mask_solve = 1 - ar_mask_prompt - seq_logproba = torch.empty(ar_mask.size(0), device=self.device) + seq_logproba = torch.empty(ar_mask_prompt.size(0), device=self.device) # bracketing of the temperature to get the target logproba - temperature = 1 + warnings.warn("high temperature!", RuntimeWarning) + temperature = 2 d_temperature = 1 / 3 while True: -- 2.39.5