Update.
authorFrançois Fleuret <francois@fleuret.org>
Mon, 12 Aug 2024 22:07:15 +0000 (00:07 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Mon, 12 Aug 2024 22:07:15 +0000 (00:07 +0200)
main.py

diff --git a/main.py b/main.py
index 0a79323..bd46948 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -540,8 +540,7 @@ def save_additional_results(model, models, science_w_quizzes):
         for model in models
     ]
 
-    seq_logprobas = torch.cat([x[None, :] for x in l])
-
+    seq_logprobas = torch.cat([x[:, None] for x in l], dim=1)
     probas = seq_logprobas.exp()
 
     comments = []