From 354a99c2e20a6e5fea923a45477d0ea9ac3306a0 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 10 Sep 2024 10:09:35 +0200 Subject: [PATCH] Update. --- main.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/main.py b/main.py index fe1aed1..fed8abc 100755 --- a/main.py +++ b/main.py @@ -1200,7 +1200,10 @@ def thread_generate_ae_c_quizzes(models, nb, record, local_device=main_device): def save_c_quizzes_with_scores(models, c_quizzes, filename): - l = [model_ae_proba_solutions(model, c_quizzes) for model in models] + l = [] + for model in models: + model.eval().to(main_device) + l.append(model_ae_proba_solutions(model, c_quizzes)) probas = torch.cat([x[:, None] for x in l], dim=1) -- 2.39.5