Update.
author François Fleuret <francois@fleuret.org>
Wed, 11 Sep 2024 06:52:17 +0000 (08:52 +0200)
committer François Fleuret <francois@fleuret.org>
Wed, 11 Sep 2024 06:52:17 +0000 (08:52 +0200)
main.py

diff --git a/main.py b/main.py
index c1ef5bc..ed83a5c 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -111,20 +111,12 @@ parser.add_argument("--diffusion_epsilon", type=float, default=0.05)
 
 parser.add_argument("--min_succeed_to_validate", type=int, default=2)
 
-parser.add_argument("--max_fail_to_validate", type=int, default=3)
-
 parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.95)
 
-parser.add_argument("--proba_understands", type=float, default=0.95)
-
-parser.add_argument("--proba_not_understands", type=float, default=0.1)
-
-parser.add_argument("--temperature_hot", type=float, default=1.5)
-
-parser.add_argument("--temperature_cold", type=float, default=1)
-
 parser.add_argument("--prompt_noise", type=float, default=0.05)
 
+parser.add_argument("--nb_hints", type=int, default=5)
+
 parser.add_argument("--dirty_debug", action="store_true", default=False)
 
 parser.add_argument("--test", type=str, default=None)
@@ -705,19 +697,6 @@ def logits_hat_x_0_from_random_iteration(model, x_0, mask_generate, prompt_noise
 
     x_t = (1 - mask_generate) * x_0 + mask_generate * x_t
 
-    #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-    # filename = f"debug.png"
-
-    # quiz_machine.problem.save_quizzes_as_image(
-    # args.result_dir,
-    # filename,
-    # quizzes=x_t,
-    # )
-
-    # log_string(f"wrote {filename}")
-    # exit(0)
-    #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-
     # We may inject noise to prevent high-complexity non-structure
     # signal to be generated as a way of "increasing reasoning
     # complexity"
@@ -741,10 +720,14 @@ def logits_hat_x_0_from_random_iteration(model, x_0, mask_generate, prompt_noise
 ######################################################################
 
 
-def ae_generate(model, x_0, mask_generate, nb_iterations_max=50):
+def ae_generate(model, x_0, mask_generate, nb_iterations_max=50, mask_hints=None):
     noise = torch.randint(quiz_machine.problem.nb_colors, x_0.size(), device=x_0.device)
 
-    x_t = (1 - mask_generate) * x_0 + mask_generate * noise
+    if mask_hints is None:
+        x_t = (1 - mask_generate) * x_0 + mask_generate * noise
+    else:
+        mask = mask_generate * (1 - mask_hints)
+        x_t = (1 - mask) * x_0 + mask * noise
 
     one_iteration_prediction = deterministic(mask_generate)[:, None]
 
@@ -925,7 +908,7 @@ def run_ae_test(
 
         # Save some images
 
-        if n_epoch < 50:
+        if n_epoch < 100:
             for f, record in [("prediction", record_d), ("generation", record_nd)]:
                 result, predicted_parts, correct_parts = bag_to_tensors(record)
 
@@ -1062,28 +1045,52 @@ def save_badness_statistics(
 ######################################################################
 
 
-def quiz_validation(models, c_quizzes, local_device):
-    nb_have_to_be_correct = 3
-    nb_have_to_be_wrong = 1
-    nb_mistakes_to_be_wrong = 5
-
+def quiz_validation(
+    models,
+    c_quizzes,
+    local_device,
+    nb_have_to_be_correct=3,
+    nb_have_to_be_not_correct=0,
+    nb_have_to_be_wrong=1,
+    nb_mistakes_to_be_wrong=5,
+    nb_hints=0,
+    nb_runs=1,
+):
     record_wrong = []
     nb_correct, nb_wrong = 0, 0
 
     for i, model in enumerate(models):
         assert i == model.id  # a bit of paranoia
         model = copy.deepcopy(model).to(local_device).eval()
-
         correct, wrong = True, False
-
         for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
             mask_generate = quiz_machine.make_quiz_mask(
-                quizzes=c_quizzes, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
+                quizzes=c_quizzes,
+                quad_order=("A", "f_A", "B", "f_B"),
+                quad_mask=quad,
             )
-            result = ae_generate(model, (1 - mask_generate) * c_quizzes, mask_generate)
-            nb_mistakes = (result != c_quizzes).long().sum(dim=1)
-            correct = correct & (nb_mistakes == 0)
-            wrong = wrong | (nb_mistakes >= nb_mistakes_to_be_wrong)
+            for _ in range(nb_runs):
+                if nb_hints == 0:
+                    mask_hints = None
+                else:
+                    u = (
+                        torch.rand(mask_generate.size(), device=mask_generate.device)
+                        * mask_generate
+                    )
+                    mask_hints = (
+                        u > u.sort(dim=1, descending=True).values[:, nb_hints, None]
+                    ).long()
+
+                result = ae_generate(
+                    model=model,
+                    x_0=(1 - mask_generate) * c_quizzes,
+                    mask_generate=mask_generate,
+                    mask_hints=mask_hints,
+                )
+
+                nb_mistakes = (result != c_quizzes).long().sum(dim=1)
+                correct = correct & (nb_mistakes == 0)
+                wrong = wrong | (nb_mistakes >= nb_mistakes_to_be_wrong)
 
         record_wrong.append(wrong[:, None])
         nb_correct += correct.long()
@@ -1131,25 +1138,13 @@ def generate_ae_c_quizzes(models, nb, local_device=main_device):
 
             c_quizzes = ae_generate(model, template, mask_generate)
 
-            #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-            ## for quad in [(0, 1, 0, 0), (0, 0, 0, 1)]:
-            ## mask_generate = quiz_machine.make_quiz_mask(
-            ## quizzes=c_quizzes,
-            ## quad_order=("A", "f_A", "B", "f_B"),
-            ## quad_mask=quad,
-            ## )
-            ## c_quizzes = ae_generate(
-            ## model,
-            ## (1 - mask_generate) * c_quizzes,
-            ## mask_generate,
-            ## )
-            #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-
             to_keep = quiz_machine.problem.trivial(c_quizzes) == False
             c_quizzes = c_quizzes[to_keep]
 
             if c_quizzes.size(0) > 0:
-                to_keep, record_wrong = quiz_validation(models, c_quizzes, local_device)
+                to_keep, record_wrong = quiz_validation(
+                    models, c_quizzes, local_device, nb_hints=args.nb_hints
+                )
                 q = c_quizzes[to_keep]
 
                 if q.size(0) > 0:
@@ -1195,9 +1190,24 @@ def thread_generate_ae_c_quizzes(models, nb, record, local_device=main_device):
 ######################################################################
 
 
-def save_c_quizzes_with_scores(models, c_quizzes, filename):
+def save_c_quizzes_with_scores(models, c_quizzes, nb, filename, solvable_only=False):
     l = []
 
+    if solvable_only:
+        to_keep, _ = quiz_validation(
+            models,
+            c_quizzes,
+            main_device,
+            nb_have_to_be_correct=1,
+            nb_have_to_be_wrong=0,
+            nb_hints=0,
+        )
+        c_quizzes = c_quizzes[to_keep]
+
+    c_quizzes = c_quizzes[
+        torch.randperm(c_quizzes.size(0), device=c_quizzes.device)[:nb]
+    ]
+
     with torch.autograd.no_grad():
         for model in models:
             model = copy.deepcopy(model).to(main_device).eval()
@@ -1389,7 +1399,11 @@ for n_epoch in range(current_epoch, args.nb_epochs):
         # --------------------------------------------------------------------
 
         filename = f"culture_c_quiz_{n_epoch:04d}.png"
-        save_c_quizzes_with_scores(models, c_quizzes[:128], filename)
+        save_c_quizzes_with_scores(
+            models, c_quizzes, 256, filename, solvable_only=False
+        )
+        filename = f"culture_c_quiz_{n_epoch:04d}_solvable.png"
+        save_c_quizzes_with_scores(models, c_quizzes, 256, filename, solvable_only=True)
 
         log_string(f"generated_c_quizzes {c_quizzes.size()=}")