Update.
author François Fleuret <francois@fleuret.org>
Wed, 11 Sep 2024 07:22:43 +0000 (09:22 +0200)
committer François Fleuret <francois@fleuret.org>
Wed, 11 Sep 2024 07:22:43 +0000 (09:22 +0200)
main.py

diff --git a/main.py b/main.py
index ed83a5c..9e1726a 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -117,6 +117,8 @@ parser.add_argument("--prompt_noise", type=float, default=0.05)
 
 parser.add_argument("--nb_hints", type=int, default=5)
 
+parser.add_argument("--nb_runs", type=int, default=5)
+
 parser.add_argument("--dirty_debug", action="store_true", default=False)
 
 parser.add_argument("--test", type=str, default=None)
@@ -1050,7 +1052,6 @@ def quiz_validation(
     c_quizzes,
     local_device,
     nb_have_to_be_correct=3,
-    nb_have_to_be_not_correct=0,
     nb_have_to_be_wrong=1,
     nb_mistakes_to_be_wrong=5,
     nb_hints=0,
@@ -1069,6 +1070,8 @@ def quiz_validation(
                 quad_order=("A", "f_A", "B", "f_B"),
                 quad_mask=quad,
             )
+
+            sub_correct, sub_wrong = True, True
             for _ in range(nb_runs):
                 if nb_hints == 0:
                     mask_hints = None
@@ -1089,8 +1092,11 @@ def quiz_validation(
                 )
 
                 nb_mistakes = (result != c_quizzes).long().sum(dim=1)
-                correct = correct & (nb_mistakes == 0)
-                wrong = wrong | (nb_mistakes >= nb_mistakes_to_be_wrong)
+                sub_correct = sub_correct | (nb_mistakes == 0)
+                sub_wrong = sub_wrong & (nb_mistakes >= nb_mistakes_to_be_wrong)
+
+            correct = correct & sub_correct
+            wrong = wrong | sub_wrong
 
         record_wrong.append(wrong[:, None])
         nb_correct += correct.long()
@@ -1143,7 +1149,11 @@ def generate_ae_c_quizzes(models, nb, local_device=main_device):
 
             if c_quizzes.size(0) > 0:
                 to_keep, record_wrong = quiz_validation(
-                    models, c_quizzes, local_device, nb_hints=args.nb_hints
+                    models,
+                    c_quizzes,
+                    local_device,
+                    nb_hints=args.nb_hints,
+                    nb_runs=args.nb_runs,
                 )
                 q = c_quizzes[to_keep]
 
@@ -1193,22 +1203,22 @@ def thread_generate_ae_c_quizzes(models, nb, record, local_device=main_device):
 def save_c_quizzes_with_scores(models, c_quizzes, nb, filename, solvable_only=False):
     l = []
 
-    if solvable_only:
-        to_keep, _ = quiz_validation(
-            models,
-            c_quizzes,
-            main_device,
-            nb_have_to_be_correct=1,
-            nb_have_to_be_wrong=0,
-            nb_hints=0,
-        )
-        c_quizzes = c_quizzes[to_keep]
+    with torch.autograd.no_grad():
+        if solvable_only:
+            to_keep, _ = quiz_validation(
+                models,
+                c_quizzes,
+                main_device,
+                nb_have_to_be_correct=1,
+                nb_have_to_be_wrong=0,
+                nb_hints=0,
+            )
+            c_quizzes = c_quizzes[to_keep]
 
-    c_quizzes = c_quizzes[
-        torch.randperm(c_quizzes.size(0), device=c_quizzes.device)[:nb]
-    ]
+        c_quizzes = c_quizzes[
+            torch.randperm(c_quizzes.size(0), device=c_quizzes.device)[:nb]
+        ]
 
-    with torch.autograd.no_grad():
         for model in models:
             model = copy.deepcopy(model).to(main_device).eval()
             l.append(model_ae_proba_solutions(model, c_quizzes))