From: François Fleuret Date: Mon, 1 Jul 2024 09:06:31 +0000 (+0300) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=60bf08d4197f2dd3a58bd900401c11d47225b0df;p=culture.git Update. --- diff --git a/main.py b/main.py index fd8ab41..67c57c0 100755 --- a/main.py +++ b/main.py @@ -437,7 +437,8 @@ def create_c_quizzes( for n in range(nb_correct.max() + 1): recorded[n].append(new_c_quizzes[nb_correct == n].clone()) - nv = [recorded[n][-1].size(0) for n in recorded.keys()] + nv = F.one_hot(nb_correct, num_classes=len(models) + 1).sum(0) + nv = " ".join([str(x.item()) for x in nv]) log_string(f"keep c_quizzes kept {nv} total {nb_validated()} / {nb_to_create}") diff --git a/quizz_machine.py b/quizz_machine.py index 806dde7..6f7492d 100755 --- a/quizz_machine.py +++ b/quizz_machine.py @@ -386,8 +386,11 @@ class QuizzMachine: ar_mask_solve = 1 - ar_mask_prompt seq_logproba = torch.empty(ar_mask_prompt.size(0), device=self.device) - warnings.warn("very high temperature with reversed cleanup", RuntimeWarning) - temperature = 10 + if reverse_cleanup: + warnings.warn("very high temperature with reversed cleanup", RuntimeWarning) + temperature = 10.0 + else: + temperature = 1.0 # warnings.warn("noise injection", RuntimeWarning) # noise_std = torch.rand(1).item()