Update.
authorFrançois Fleuret <francois@fleuret.org>
Wed, 11 Sep 2024 09:25:52 +0000 (11:25 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Wed, 11 Sep 2024 09:25:52 +0000 (11:25 +0200)
main.py

diff --git a/main.py b/main.py
index 9e1726a..b83cabd 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -886,11 +886,7 @@ def run_ae_test(
             c_quizzes=c_quizzes,
             desc="test",
         ):
-            result = ae_generate(
-                model,
-                (1 - mask_generate) * x_0,
-                mask_generate,
-            )
+            result = ae_generate(model, (1 - mask_generate) * x_0, mask_generate)
             correct = (result == x_0).min(dim=1).values.long()
             predicted_parts = mask_generate.reshape(mask_generate.size(0), 4, -1)[
                 :, :, 1
@@ -1057,6 +1053,26 @@ def quiz_validation(
     nb_hints=0,
     nb_runs=1,
 ):
+    if c_quizzes.size(0) > args.inference_batch_size:
+        record = []
+        for q in c_quizzes.split(args.inference_batch_size):
+            record.append(
+                quiz_validation(
+                    models,
+                    q,
+                    local_device,
+                    nb_have_to_be_correct,
+                    nb_have_to_be_wrong,
+                    nb_mistakes_to_be_wrong,
+                    nb_hints=0,
+                    nb_runs=1,
+                )
+            )
+
+        return (torch.cat([tk for tk, _ in record], dim=0)), (
+            torch.cat([w for _, w in record], dim=0)
+        )
+
     record_wrong = []
     nb_correct, nb_wrong = 0, 0
 
@@ -1086,7 +1102,7 @@ def quiz_validation(
 
                 result = ae_generate(
                     model=model,
-                    x_0=(1 - mask_generate) * c_quizzes,
+                    x_0=c_quizzes,
                     mask_generate=mask_generate,
                     mask_hints=mask_hints,
                 )