Update. dev
authorFrançois Fleuret <francois@fleuret.org>
Mon, 12 Aug 2024 14:44:46 +0000 (16:44 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Mon, 12 Aug 2024 14:44:46 +0000 (16:44 +0200)
main.py
quiz_machine.py

diff --git a/main.py b/main.py
index f9dc35d..b23a52b 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -503,14 +503,14 @@ def model_transformer_cold(model):
 
 
 c_quizzes_procedure = [
-    (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_transformer_hot),
-    (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), model_transformer_cold),
-    (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_transformer_cold),
-    (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), model_transformer_cold),
+    (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_transformer_hot),
+    (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), model_transformer_cold),
+    (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_transformer_cold),
+    (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), model_transformer_cold),
     # Generate the full thing at high temp
-    (("B", "f_B", "A", "f_A"), (1, 1, 1, 1), model_transformer_hot),
+    (("B", "f_B", "A", "f_A"), (1, 1, 1, 1), model_transformer_hot),
     # Fix A and B
-    (("f_B", "B", "f_A", "A"), (0, 1, 0, 1), model_transformer_cold),
+    (("f_B", "B", "f_A", "A"), (0, 1, 0, 1), model_transformer_cold),
     # Fix f_B
     # (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_transformer_cold),
     # Fix f_A
index 92da03d..5bab1e5 100755 (executable)
@@ -53,7 +53,7 @@ def one_batch_masked_inplace_autoregression(
 
         all_n = torch.arange(t_next.size(0))
 
-        seq_logproba += logits[all_n, t_next]
+        seq_logproba += logits.log_softmax(dim=1)[all_n, t_next]
 
         input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]