From: François Fleuret Date: Mon, 12 Aug 2024 14:44:46 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=refs%2Fheads%2Fdev;p=culture.git Update. --- diff --git a/main.py b/main.py index f9dc35d..b23a52b 100755 --- a/main.py +++ b/main.py @@ -503,14 +503,14 @@ def model_transformer_cold(model): c_quizzes_procedure = [ - # (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_transformer_hot), - # (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), model_transformer_cold), - # (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_transformer_cold), - # (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), model_transformer_cold), + (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_transformer_hot), + (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), model_transformer_cold), + (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_transformer_cold), + (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), model_transformer_cold), # Generate the full thing at high temp - (("B", "f_B", "A", "f_A"), (1, 1, 1, 1), model_transformer_hot), + # (("B", "f_B", "A", "f_A"), (1, 1, 1, 1), model_transformer_hot), # Fix A and B - (("f_B", "B", "f_A", "A"), (0, 1, 0, 1), model_transformer_cold), + # (("f_B", "B", "f_A", "A"), (0, 1, 0, 1), model_transformer_cold), # Fix f_B # (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_transformer_cold), # Fix f_A diff --git a/quiz_machine.py b/quiz_machine.py index 92da03d..5bab1e5 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -53,7 +53,7 @@ def one_batch_masked_inplace_autoregression( all_n = torch.arange(t_next.size(0)) - seq_logproba += logits[all_n, t_next] + seq_logproba += logits.log_softmax(dim=1)[all_n, t_next] input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]