From 425f319d117cc333db298610c67ffd6a65c630e9 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 18 Sep 2024 10:30:59 +0200 Subject: [PATCH] Update. --- main.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/main.py b/main.py index 195afa8..77dcd2f 100755 --- a/main.py +++ b/main.py @@ -592,7 +592,7 @@ def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device): quiz_machine.problem.save_quizzes_as_image( args.result_dir, f"test_{n_epoch}_{model.id}.png", quizzes=quizzes ) - result = predict_full(model, quizzes, local_device=local_device) + result = predict_full(model=model, input=quizzes, local_device=local_device) quiz_machine.problem.save_quizzes_as_image( args.result_dir, f"test_{n_epoch}_{model.id}_predict_full.png", quizzes=result ) @@ -670,14 +670,14 @@ for i in range(args.nb_models): ###################################################################### -def evaluate_quizzes(c_quizzes, models, fraction_with_hints, local_device): +def evaluate_quizzes(quizzes, models, fraction_with_hints, local_device): nb_correct, nb_wrong = 0, 0 for model in models: model = copy.deepcopy(model).to(local_device).eval() result = predict_full( model=model, - quizzes=c_quizzes, + input=c_quizzes, fraction_with_hints=fraction_with_hints, local_device=local_device, ) @@ -767,7 +767,9 @@ def save_quiz_image( ): c_quizzes = c_quizzes.to(local_device) - to_keep, nb_correct, nb_wrong = evaluate_quizzes(c_quizzes, models, local_device) + to_keep, nb_correct, nb_wrong = evaluate_quizzes( + quizzes=c_quizzes, models=models, local_device=local_device + ) if solvable_only: c_quizzes = c_quizzes[to_keep] -- 2.39.5