From a4e9fa2ee2f964169aac25be98034d49602e37fd Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 16 Jul 2024 20:10:29 +0200 Subject: [PATCH] Update. --- main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/main.py b/main.py index 6df33bd..2b71950 100755 --- a/main.py +++ b/main.py @@ -412,7 +412,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 nb_validated = torch.zeros(len(models)) - while nb_validated < nb_to_create: + while nb_validated.sum() < nb_to_create: # We balance the number of quizzes per model model_for_generation = models[nb_validated.argmin()] @@ -427,7 +427,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10 c_quizzes = keep_good_quizzes(models, c_quizzes) nb_validated[model.id] += c_quizzes.size(0) - total_nb_validated = nb_validated.sum() + total_nb_validated = nb_validated.sum().item() recorded.append(c_quizzes) -- 2.39.5