Update.
authorFrançois Fleuret <francois@fleuret.org>
Tue, 10 Sep 2024 07:08:13 +0000 (09:08 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 10 Sep 2024 07:08:13 +0000 (09:08 +0200)
main.py

diff --git a/main.py b/main.py
index b7050df..fe1aed1 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -1064,7 +1064,7 @@ def save_badness_statistics(
 
 def quiz_validation(models, c_quizzes, local_device):
     nb_have_to_be_correct = 3
-    nb_have_to_be_wrong = 3
+    nb_have_to_be_wrong = 1
     nb_mistakes_to_be_wrong = 5
 
     record_wrong = []
@@ -1136,17 +1136,17 @@ def generate_ae_c_quizzes(models, nb, local_device=main_device):
             c_quizzes = ae_generate(model, template, mask_generate)
 
             #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-            for quad in [(0, 1, 0, 0), (0, 0, 0, 1)]:
-                mask_generate = quiz_machine.make_quiz_mask(
-                    quizzes=c_quizzes,
-                    quad_order=("A", "f_A", "B", "f_B"),
-                    quad_mask=quad,
-                )
-                c_quizzes = ae_generate(
-                    model,
-                    (1 - mask_generate) * c_quizzes,
-                    mask_generate,
-                )
+            ## for quad in [(0, 1, 0, 0), (0, 0, 0, 1)]:
+            ## mask_generate = quiz_machine.make_quiz_mask(
+            ## quizzes=c_quizzes,
+            ## quad_order=("A", "f_A", "B", "f_B"),
+            ## quad_mask=quad,
+            ## )
+            ## c_quizzes = ae_generate(
+            ## model,
+            ## (1 - mask_generate) * c_quizzes,
+            ## mask_generate,
+            ## )
             #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
 
             to_keep = quiz_machine.problem.trivial(c_quizzes) == False
@@ -1347,8 +1347,7 @@ for n_epoch in range(current_epoch, args.nb_epochs):
                     },
                     os.path.join(args.result_dir, filename),
                 )
-
-            log_string(f"wrote {filename}")
+                log_string(f"wrote {filename}")
 
         # --------------------------------------------------------------------
 
@@ -1461,7 +1460,6 @@ for n_epoch in range(current_epoch, args.nb_epochs):
             },
             os.path.join(args.result_dir, filename),
         )
-
         log_string(f"wrote {filename}")
 
     # --------------------------------------------------------------------