Update.
authorFrançois Fleuret <francois@fleuret.org>
Wed, 18 Sep 2024 15:37:55 +0000 (17:37 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Wed, 18 Sep 2024 15:37:55 +0000 (17:37 +0200)
main.py

diff --git a/main.py b/main.py
index 4a44fd3..a7f9c9e 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -708,6 +708,13 @@ def evaluate_quizzes(quizzes, models, fraction_with_hints, local_device):
 ######################################################################
 
 
+def identity_quizzes(quizzes):
+    quizzes = quizzes.reshape(quizzes.size(0), 4, -1)
+    return (quizzes[:, 0] == quizzes[:, 1]).min(dim=1).values & (
+        quizzes[:, 2] == quizzes[:, 3]
+    ).min(dim=1).values
+
+
 def generate_c_quizzes(models, nb_to_generate, local_device=main_device):
     record = []
     nb_validated = 0
@@ -726,18 +733,21 @@ def generate_c_quizzes(models, nb_to_generate, local_device=main_device):
             model=model, nb=args.eval_batch_size * 10, local_device=local_device
         )
 
-        # Select the ones that are solved properly by some models and
-        # not understood by others
+        c_quizzes = c_quizzes[identity_quizzes(c_quizzes) == False]
 
-        to_keep, nb_correct, nb_wrong = evaluate_quizzes(
-            quizzes=c_quizzes,
-            models=models,
-            fraction_with_hints=1.0,
-            local_device=local_device,
-        )
+        if c_quizzes.size(0) > 0:
+            # Select the ones that are solved properly by some models and
+            # not understood by others
+
+            to_keep, nb_correct, nb_wrong = evaluate_quizzes(
+                quizzes=c_quizzes,
+                models=models,
+                fraction_with_hints=1.0,
+                local_device=local_device,
+            )
 
-        nb_validated += to_keep.long().sum().item()
-        record.append(c_quizzes[to_keep])
+            nb_validated += to_keep.long().sum().item()
+            record.append(c_quizzes[to_keep])
 
         #####################