Update.
authorFrançois Fleuret <francois@fleuret.org>
Wed, 18 Sep 2024 08:28:38 +0000 (10:28 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Wed, 18 Sep 2024 08:28:38 +0000 (10:28 +0200)
main.py

diff --git a/main.py b/main.py
index 86a3ae9..195afa8 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -428,7 +428,7 @@ def predict(model, imt_set, local_device=main_device, desc="predict"):
     return torch.cat(record)
 
 
-def predict_full(model, input, fraction_with_hints=0.0, local_device=main_device):
+def predict_full(model, input, fraction_with_hints, local_device=main_device):
     input = input[:, None, :].expand(-1, 4, -1).reshape(-1, input.size(1))
     nb = input.size(0)
     masks = input.new_zeros(input.size())
@@ -670,12 +670,17 @@ for i in range(args.nb_models):
 ######################################################################
 
 
-def evaluate_quizzes(c_quizzes, models, local_device):
+def evaluate_quizzes(c_quizzes, models, fraction_with_hints, local_device):
     nb_correct, nb_wrong = 0, 0
 
     for model in models:
         model = copy.deepcopy(model).to(local_device).eval()
-        result = predict_full(model, c_quizzes, local_device=local_device)
+        result = predict_full(
+            model=model,
+            quizzes=c_quizzes,
+            fraction_with_hints=fraction_with_hints,
+            local_device=local_device,
+        )
         nb_mistakes = (result != c_quizzes).long().sum(dim=1)
         nb_correct += (nb_mistakes == 0).long()
         nb_wrong += nb_mistakes >= args.nb_mistakes_to_be_wrong
@@ -715,7 +720,10 @@ def generate_c_quizzes(models, nb, local_device=main_device):
         # not understood by others
 
         to_keep, nb_correct, nb_wrong = evaluate_quizzes(
-            c_quizzes, models, local_device
+            quizzes=c_quizzes,
+            models=models,
+            fraction_with_hints=1.0,
+            local_device=local_device,
         )
 
         nb_validated += to_keep.long().sum().item()