From 3836142b977d37c398a4c0cf8f6ff9405896e794 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 18 Sep 2024 17:37:55 +0200 Subject: [PATCH] Update. --- main.py | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/main.py b/main.py index 4a44fd3..a7f9c9e 100755 --- a/main.py +++ b/main.py @@ -708,6 +708,13 @@ def evaluate_quizzes(quizzes, models, fraction_with_hints, local_device): ###################################################################### +def identity_quizzes(quizzes): + quizzes = quizzes.reshape(quizzes.size(0), 4, -1) + return (quizzes[:, 0] == quizzes[:, 1]).min(dim=1).values & ( + quizzes[:, 2] == quizzes[:, 3] + ).min(dim=1).values + + def generate_c_quizzes(models, nb_to_generate, local_device=main_device): record = [] nb_validated = 0 @@ -726,18 +733,21 @@ def generate_c_quizzes(models, nb_to_generate, local_device=main_device): model=model, nb=args.eval_batch_size * 10, local_device=local_device ) - # Select the ones that are solved properly by some models and - # not understood by others + c_quizzes = c_quizzes[identity_quizzes(c_quizzes) == False] - to_keep, nb_correct, nb_wrong = evaluate_quizzes( - quizzes=c_quizzes, - models=models, - fraction_with_hints=1.0, - local_device=local_device, - ) + if c_quizzes.size(0) > 0: + # Select the ones that are solved properly by some models and + # not understood by others + + to_keep, nb_correct, nb_wrong = evaluate_quizzes( + quizzes=c_quizzes, + models=models, + fraction_with_hints=1.0, + local_device=local_device, + ) - nb_validated += to_keep.long().sum().item() - record.append(c_quizzes[to_keep]) + nb_validated += to_keep.long().sum().item() + record.append(c_quizzes[to_keep]) ##################### -- 2.39.5