From 33f49356af24d00171dd4b041d11b1683c70b12d Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 13 Aug 2024 18:01:40 +0200 Subject: [PATCH] Update. --- main.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/main.py b/main.py index b2a9591..df29152 100755 --- a/main.py +++ b/main.py @@ -602,7 +602,8 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test): for s in range(proba_own_solution.size(0)): dont_get_this_quiz = proba_own_solution[s, :] < args.proba_understands - if not dont_get_this_quiz.all(): + nb_fails = dont_get_this_quiz.long().sum() + if nb_fails >= 1 and nb_fails <= args.max_fail_to_validate: for model in models: if dont_get_this_quiz[model.id]: assert proba_own_solution[s, model.id] < args.proba_understands -- 2.39.5