From 84abd2208e7112e96c68ca62b9905dd9f7d013dd Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 18 Sep 2024 10:28:38 +0200 Subject: [PATCH] Update. --- main.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/main.py b/main.py index 86a3ae9..195afa8 100755 --- a/main.py +++ b/main.py @@ -428,7 +428,7 @@ def predict(model, imt_set, local_device=main_device, desc="predict"): return torch.cat(record) -def predict_full(model, input, fraction_with_hints=0.0, local_device=main_device): +def predict_full(model, input, fraction_with_hints, local_device=main_device): input = input[:, None, :].expand(-1, 4, -1).reshape(-1, input.size(1)) nb = input.size(0) masks = input.new_zeros(input.size()) @@ -670,12 +670,17 @@ for i in range(args.nb_models): ###################################################################### -def evaluate_quizzes(c_quizzes, models, local_device): +def evaluate_quizzes(c_quizzes, models, fraction_with_hints, local_device): nb_correct, nb_wrong = 0, 0 for model in models: model = copy.deepcopy(model).to(local_device).eval() - result = predict_full(model, c_quizzes, local_device=local_device) + result = predict_full( + model=model, + quizzes=c_quizzes, + fraction_with_hints=fraction_with_hints, + local_device=local_device, + ) nb_mistakes = (result != c_quizzes).long().sum(dim=1) nb_correct += (nb_mistakes == 0).long() nb_wrong += nb_mistakes >= args.nb_mistakes_to_be_wrong @@ -715,7 +720,10 @@ def generate_c_quizzes(models, nb, local_device=main_device): # not understood by others to_keep, nb_correct, nb_wrong = evaluate_quizzes( - c_quizzes, models, local_device + quizzes=c_quizzes, + models=models, + fraction_with_hints=1.0, + local_device=local_device, ) nb_validated += to_keep.long().sum().item() -- 2.39.5