Update.
authorFrançois Fleuret <francois@fleuret.org>
Tue, 10 Sep 2024 08:09:35 +0000 (10:09 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 10 Sep 2024 08:09:35 +0000 (10:09 +0200)
main.py

diff --git a/main.py b/main.py
index fe1aed1..fed8abc 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -1200,7 +1200,10 @@ def thread_generate_ae_c_quizzes(models, nb, record, local_device=main_device):
 
 
 def save_c_quizzes_with_scores(models, c_quizzes, filename):
-    l = [model_ae_proba_solutions(model, c_quizzes) for model in models]
+    l = []
+    for model in models:
+        model.eval().to(main_device)
+        l.append(model_ae_proba_solutions(model, c_quizzes))
 
     probas = torch.cat([x[:, None] for x in l], dim=1)