Update.
authorFrançois Fleuret <francois@fleuret.org>
Mon, 12 Aug 2024 14:56:58 +0000 (16:56 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Mon, 12 Aug 2024 14:56:58 +0000 (16:56 +0200)
main.py
quiz_machine.py

diff --git a/main.py b/main.py
index fbebbb9..e516a77 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -667,13 +667,16 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
             )
 
         #!!!!!!!!!!!!!!!!!!!!
-        l = quiz_machine.models_logprobas(
-            models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
-        )
-        for s in range(seq_logproba.size(0)):
-            print(f"-- {s=} ----------------")
-            for m in range(seq_logproba.size(1)):
-                print("DEBUG", seq_logproba[s, m].item(), l[s, m].item())
+        for m in range(seq_logproba.size(1)):
+            l = quiz_machine.models_logprobas(
+                [models[m]],
+                solved_c_quizzes[:, m, :],
+                ("A", "f_A", "B", "f_B"),
+                (0, 0, 0, 1),
+                (0, 0, 0, 0),
+            )
+            for s in range(seq_logproba.size(0)):
+                print("DEBUG", seq_logproba[s, m].item(), l[s, 0].item())
         exit(0)
         #!!!!!!!!!!!!!!!!!!!!!!!!!
 
index 6aa4e9b..b2287b8 100755 (executable)
@@ -53,7 +53,7 @@ def one_batch_masked_inplace_autoregression(
 
         all_n = torch.arange(t_next.size(0))
 
-        acc_seq_logproba += ar_mask[:, s] * logits[all_n, t_next]
+        acc_seq_logproba += ar_mask[:, s] * logits.log_softmax(dim=1)[all_n, t_next]
 
         input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]