Update.
authorFrançois Fleuret <francois@fleuret.org>
Thu, 5 Sep 2024 06:28:26 +0000 (08:28 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Thu, 5 Sep 2024 06:28:26 +0000 (08:28 +0200)
main.py

diff --git a/main.py b/main.py
index 02c9fc6..934940e 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -57,11 +57,13 @@ parser.add_argument("--nb_train_samples", type=int, default=25000)
 
 parser.add_argument("--nb_test_samples", type=int, default=1000)
 
+parser.add_argument("--nb_c_quizzes", type=int, default=2500)
+
 parser.add_argument("--nb_new_c_quizzes_for_train", type=int, default=None)
 
 parser.add_argument("--nb_new_c_quizzes_for_test", type=int, default=None)
 
-parser.add_argument("--c_quiz_multiplier", type=int, default=10)
+parser.add_argument("--c_quiz_multiplier", type=int, default=1)
 
 parser.add_argument("--learning_rate", type=float, default=5e-4)
 
@@ -115,7 +117,7 @@ parser.add_argument("--temperature_hot", type=float, default=1.5)
 
 parser.add_argument("--temperature_cold", type=float, default=1)
 
-parser.add_argument("--prompt_noise", type=float, default=0.0)
+parser.add_argument("--prompt_noise", type=float, default=0.05)
 
 parser.add_argument("--dirty_debug", action="store_true", default=False)
 
@@ -754,10 +756,10 @@ def deterministic(mask_generate):
     return (mask_generate.sum(dim=1) < mask_generate.size(1) // 2).long()
 
 
-# This function returns a tensor of same shape as low, full of uniform
-# random values in [0,1], such that the values corresponding to the
-# True in low are all lesser than the values corresponding to the
-# False.
+# This function returns a 2d tensor of same shape as low, full of
+# uniform random values in [0,1], such that, in every row, the values
+# corresponding to the True in low are all lesser than the values
+# corresponding to the False.
 
 
 def prioritized_rand(low):
@@ -840,7 +842,7 @@ def model_ae_proba_solutions(model, input, log_proba=False):
 nb_diffusion_iterations = 25
 
 
-def degrade_input_to_generate(input, mask_generate, nb_iterations):
+def degrade_input_to_generate(input, mask_generate, steps_nb_iterations):
     noise = torch.randint(
         quiz_machine.problem.nb_colors, input.size(), device=input.device
     )
@@ -849,7 +851,7 @@ def degrade_input_to_generate(input, mask_generate, nb_iterations):
 
     result = []
 
-    for n in nb_iterations:
+    for n in steps_nb_iterations:
         proba_erased = 1 - (1 - args.diffusion_noise_proba) ** n
         mask_erased = mask_generate * (r <= proba_erased[:, None]).long()
         x = (1 - mask_erased) * input + mask_erased * noise
@@ -929,8 +931,8 @@ def run_ae_test(model, quiz_machine, n_epoch, c_quizzes=None, local_device=main_
             args.nb_test_samples,
             data_structures,
             local_device,
-            c_quizzes,
-            "test",
+            c_quizzes=c_quizzes,
+            desc="test",
         ):
             targets = input.clone()
             result = ae_generate(
@@ -1080,6 +1082,39 @@ for i in range(args.nb_models):
 ######################################################################
 
 
+def save_badness_statistics(
+    n_epoch, models, c_quizzes, suffix=None, local_device=main_device
+):
+    for model in models:
+        model.eval().to(local_device)
+    c_quizzes = c_quizzes.to(local_device)
+    with torch.autograd.no_grad():
+        log_probas = sum(
+            [model_ae_proba_solutions(model, c_quizzes) for model in models]
+        )
+        i = log_probas.sort().indices
+
+    suffix = "" if suffix is None else "_" + suffix
+
+    filename = f"culture_badness_{n_epoch:04d}{suffix}.png"
+
+    quiz_machine.problem.save_quizzes_as_image(
+        args.result_dir,
+        filename,
+        quizzes=c_quizzes[i[:128]],
+        # predicted_parts=predicted_parts,
+        # correct_parts=correct_parts,
+        # comments=comments,
+        delta=True,
+        nrow=8,
+    )
+
+    log_string(f"wrote {filename}")
+
+
+######################################################################
+
+
 def c_quiz_criterion_one_good_one_bad(probas):
     return (probas.max(dim=1).values >= 0.75) & (probas.min(dim=1).values <= 0.25)
 
@@ -1101,9 +1136,9 @@ def c_quiz_criterion_diff2(probas):
     return (v[:, -2] - v[:, 0]) >= 0.5
 
 
-def c_quiz_criterion_only_one(probas):
+def c_quiz_criterion_few_good_one_bad(probas):
     v = probas.sort(dim=1).values
-    return (v[:, -1] >= 0.75) & (v[:, -2] <= 0.25)
+    return (v[:, 0] <= 0.25) & (v[:, -3] >= 0.5)
 
 
 def c_quiz_criterion_two_good(probas):
@@ -1116,40 +1151,11 @@ def c_quiz_criterion_some(probas):
     )
 
 
-def save_badness_statistics(
-    n_epoch, models, c_quizzes, suffix=None, local_device=main_device
-):
-    for model in models:
-        model.eval().to(local_device)
-    c_quizzes = c_quizzes.to(local_device)
-    with torch.autograd.no_grad():
-        log_probas = sum(
-            [model_ae_proba_solutions(model, c_quizzes) for model in models]
-        )
-        i = log_probas.sort().indices
-
-    suffix = "" if suffix is None else "_" + suffix
-
-    filename = f"culture_badness_{n_epoch:04d}{suffix}.png"
-
-    quiz_machine.problem.save_quizzes_as_image(
-        args.result_dir,
-        filename,
-        quizzes=c_quizzes[i[:128]],
-        # predicted_parts=predicted_parts,
-        # correct_parts=correct_parts,
-        # comments=comments,
-        delta=True,
-        nrow=8,
-    )
-
-    log_string(f"wrote {filename}")
-
-
 def generate_ae_c_quizzes(models, nb, local_device=main_device):
     criteria = [
+        c_quiz_criterion_few_good_one_bad,
         # c_quiz_criterion_only_one,
-        c_quiz_criterion_one_good_one_bad,
+        c_quiz_criterion_one_good_one_bad,
         # c_quiz_criterion_one_good_no_very_bad,
         # c_quiz_criterion_diff,
         # c_quiz_criterion_diff2,
@@ -1186,22 +1192,22 @@ def generate_ae_c_quizzes(models, nb, local_device=main_device):
             and min([bag_len(bag) for bag in records]) < wanted_nb
         ):
             model = models[torch.randint(len(models), (1,)).item()]
-            result = ae_generate(model, template, mask_generate)
+            c_quizzes = ae_generate(model, template, mask_generate)
 
-            to_keep = quiz_machine.problem.trivial(result) == False
-            result = result[to_keep]
+            to_keep = quiz_machine.problem.trivial(c_quizzes) == False
+            c_quizzes = c_quizzes[to_keep]
 
-            if result.size(0) > 0:
+            if c_quizzes.size(0) > 0:
                 probas = torch.cat(
                     [
-                        model_ae_proba_solutions(model, result)[:, None]
+                        model_ae_proba_solutions(model, c_quizzes)[:, None]
                         for model in models
                     ],
                     dim=1,
                 )
 
                 for c, r in zip(criteria, records):
-                    q = result[c(probas)]
+                    q = c_quizzes[c(probas)]
                     if q.size(0) > 0:
                         r.append(q)
 
@@ -1234,7 +1240,7 @@ def generate_ae_c_quizzes(models, nb, local_device=main_device):
             quizzes = torch.cat(u, dim=0)[:nb_to_save]
             filename = f"culture_c_quiz_{n_epoch:04d}_{n:02d}.png"
 
-            # result, predicted_parts, correct_parts = bag_to_tensors(record)
+            # c_quizzes, predicted_parts, correct_parts = bag_to_tensors(record)
 
             l = [model_ae_proba_solutions(model, quizzes) for model in models]
             probas = torch.cat([x[:, None] for x in l], dim=1)
@@ -1351,9 +1357,7 @@ for n_epoch in range(current_epoch, args.nb_epochs):
 
         last_n_epoch_c_quizzes = n_epoch
         nb_gpus = len(gpus)
-        nb_c_quizzes_to_generate = (
-            args.nb_train_samples // args.c_quiz_multiplier + nb_gpus - 1
-        ) // nb_gpus
+        nb_c_quizzes_to_generate = (args.nb_c_quizzes + nb_gpus - 1) // nb_gpus
 
         # --------------------------------------------------------------------
 
@@ -1376,7 +1380,7 @@ for n_epoch in range(current_epoch, args.nb_epochs):
         for t in threads:
             t.join()
 
-        time_c_quizzes = time.perf_counter() - start_time
+        time_c_quizzes = int(time.perf_counter() - start_time)
 
         c_quizzes = torch.cat([q.to(main_device) for q in records], dim=0)
 
@@ -1420,7 +1424,7 @@ for n_epoch in range(current_epoch, args.nb_epochs):
     for t in threads:
         t.join()
 
-    time_train += time.perf_counter() - start_time
+    time_train += int(time.perf_counter() - start_time)
 
     # --------------------------------------------------------------------