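Update: drop the with_perturbations parameter of evaluate_quizzes (it now calls predict_full once with perturbations and once without), and remove remove_old_problematic.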
author: François Fleuret <francois@fleuret.org>
Thu, 19 Sep 2024 13:40:50 +0000 (15:40 +0200)
committer: François Fleuret <francois@fleuret.org>
Thu, 19 Sep 2024 13:40:50 +0000 (15:40 +0200)
main.py

diff --git a/main.py b/main.py
index ef340ea..7bdd09e 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -662,7 +662,7 @@ for i in range(args.nb_models):
 ######################################################################
 
 
-def evaluate_quizzes(quizzes, models, with_perturbations, local_device):
+def evaluate_quizzes(quizzes, models, local_device):
     nb_correct, nb_wrong = 0, 0
 
     for model in models:
@@ -670,11 +670,17 @@ def evaluate_quizzes(quizzes, models, with_perturbations, local_device):
         result = predict_full(
             model=model,
             input=quizzes,
-            with_perturbations=with_perturbations,
+            with_perturbations=True,
             local_device=local_device,
         )
-        nb_mistakes = (result != quizzes).long().sum(dim=1)
         nb_correct += (nb_mistakes == 0).long()
+        result = predict_full(
+            model=model,
+            input=quizzes,
+            with_perturbations=False,
+            local_device=local_device,
+        )
+        nb_mistakes = (result != quizzes).long().sum(dim=1)
         nb_wrong += nb_mistakes >= args.nb_mistakes_to_be_wrong
 
     to_keep = (nb_correct >= args.nb_have_to_be_correct) & (
@@ -687,26 +693,6 @@ def evaluate_quizzes(quizzes, models, with_perturbations, local_device):
 ######################################################################
 
 
-def remove_old_problematic(c_quizzes, models, nb_to_remove, local_device):
-    nb_removed = 0
-    for input in c_quizzes.split(args.eval_batch_size):
-        _, nb_correct, nb_wrong = evaluate_quizzes(
-            quizzes=input,
-            models=models,
-            with_perturbations=False,
-            local_device=local_device,
-        )
-
-        to_remove = nb_wrong > 0
-        nb_removed += to_remove.long().sum()
-
-        if nb_removed >= nb_to_remove:
-            break
-
-
-######################################################################
-
-
 def identity_quizzes(quizzes):
     quizzes = quizzes.reshape(quizzes.size(0), 4, -1)
     return (quizzes[:, 0] == quizzes[:, 1]).min(dim=1).values & (
@@ -741,7 +727,6 @@ def generate_c_quizzes(models, nb_to_generate, local_device=main_device):
             to_keep, nb_correct, nb_wrong = evaluate_quizzes(
                 quizzes=c_quizzes,
                 models=models,
-                with_perturbations=True,
                 local_device=local_device,
             )
 
@@ -787,7 +772,6 @@ def save_quiz_image(models, c_quizzes, filename, local_device=main_device):
     to_keep, nb_correct, nb_wrong = evaluate_quizzes(
         quizzes=c_quizzes,
         models=models,
-        with_perturbations=False,
         local_device=local_device,
     )