Update.
author: François Fleuret <francois@fleuret.org>
Sun, 15 Sep 2024 12:26:49 +0000 (14:26 +0200)
committer: François Fleuret <francois@fleuret.org>
Sun, 15 Sep 2024 12:26:49 +0000 (14:26 +0200)
diffusion.py
main.py

index 8c6e08d..2dc5861 100755 (executable)
@@ -52,13 +52,40 @@ class Diffuser:
 
     ######################################################################
 
+    def make_mask_hints(self, mask_generate, nb_hints):
+        if nb_hints == 0:
+            mask_hints = None
+        else:
+            u = (
+                torch.rand(mask_generate.size(), device=mask_generate.device)
+                * mask_generate
+            )
+            mask_hints = (
+                u > u.sort(dim=1, descending=True).values[:, nb_hints, None]
+            ).long()
+
+        return mask_hints
+
     # This function gets a clean target x_0, and a mask indicating which
     # part to generate (conditionnaly to the others), and returns the
     # logits starting from a x_t|X_0=x_0 picked at random with t random
 
     def logits_hat_x_0_from_random_iteration(
-        self, model, x_0, mask_generate, prompt_noise=0.0
+        self, model, x_0, mask_generate, nb_hints=0, prompt_noise=0.0
     ):
+        noise = self.mu_T_sampler(x_0.size(), device=x_0.device)
+
+        single_iteration = (
+            mask_generate.sum(dim=1) < mask_generate.size(1) // 2
+        ).long()[:, None]
+
+        mask_hints = self.make_mask_hints(mask_generate, nb_hints)
+
+        if mask_hints is None:
+            mask_start = mask_generate
+        else:
+            mask_start = mask_generate * (1 - mask_hints)
+
         # We favor iterations near the clean signal
 
         probs_iterations = 0.1 ** torch.linspace(
@@ -71,7 +98,9 @@ class Diffuser:
 
         t = dist.sample() + 1
 
-        x_t = self.sample_x_t_given_x_0(x_0, t)
+        x_t = single_iteration * noise + (
+            1 - single_iteration
+        ) * self.sample_x_t_given_x_0(x_0, t)
 
         # Only the part to generate is degraded, the rest is a perfect
         # noise-free conditionning
@@ -99,13 +128,15 @@ class Diffuser:
 
     ######################################################################
 
-    def ae_generate(self, model, x_0, mask_generate, mask_hints=None):
+    def generate(self, model, x_0, mask_generate, nb_hints=0):
         noise = self.mu_T_sampler(x_0.size(), device=x_0.device)
 
         single_iteration = (
             mask_generate.sum(dim=1) < mask_generate.size(1) // 2
         ).long()[:, None]
 
+        mask_hints = self.make_mask_hints(mask_generate, nb_hints)
+
         if mask_hints is None:
             mask_start = mask_generate
         else:
diff --git a/main.py b/main.py
index 1461ab1..d508c97 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -79,6 +79,12 @@ parser.add_argument("--learning_rate", type=float, default=5e-4)
 
 parser.add_argument("--reboot", action="store_true", default=False)
 
+parser.add_argument("--nb_have_to_be_correct", type=int, default=3)
+
+parser.add_argument("--nb_have_to_be_wrong", type=int, default=1)
+
+parser.add_argument("--nb_mistakes_to_be_wrong", type=int, default=5)
+
 # ----------------------------------
 
 parser.add_argument("--model", type=str, default="37M")
@@ -388,7 +394,7 @@ data_structures = [
 ######################################################################
 
 
-def model_ae_proba_solutions(model, input, log_probas=False, reduce=True):
+def model_proba_solutions(model, input, log_probas=False, reduce=True):
     record = []
 
     for x_0 in input.split(args.batch_size):
@@ -422,7 +428,7 @@ def model_ae_proba_solutions(model, input, log_probas=False, reduce=True):
 ######################################################################
 
 
-def ae_batches(
+def batches(
     quiz_machine,
     nb,
     data_structures,
@@ -469,7 +475,7 @@ def NTC_masked_cross_entropy(output, targets, mask):
 ######################################################################
 
 
-def run_ae_test(
+def run_test(
     model, quiz_machine, n_epoch, c_quizzes=None, local_device=main_device, prefix=None
 ):
     if prefix is None:
@@ -484,7 +490,7 @@ def run_ae_test(
 
         nb_test_samples, acc_test_loss = 0, 0.0
 
-        for x_0, mask_generate in ae_batches(
+        for x_0, mask_generate in batches(
             quiz_machine,
             args.nb_test_samples,
             data_structures,
@@ -509,7 +515,7 @@ def run_ae_test(
 
         nb_correct, nb_total, record_d, record_nd = 0, 0, [], []
 
-        for x_0, mask_generate in ae_batches(
+        for x_0, mask_generate in batches(
             quiz_machine,
             args.nb_test_samples,
             data_structures,
@@ -517,9 +523,7 @@ def run_ae_test(
             c_quizzes=c_quizzes,
             desc="test",
         ):
-            result = diffuser.ae_generate(
-                model, (1 - mask_generate) * x_0, mask_generate
-            )
+            result = diffuser.generate(model, (1 - mask_generate) * x_0, mask_generate)
             correct = (result == x_0).min(dim=1).values.long()
             predicted_parts = mask_generate.reshape(mask_generate.size(0), 4, -1)[
                 :, :, 1
@@ -561,7 +565,7 @@ def run_ae_test(
 ######################################################################
 
 
-def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_device):
+def one_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_device):
     model.train().to(local_device)
     optimizer_to(model.optimizer, local_device)
 
@@ -569,7 +573,7 @@ def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_devi
 
     # scaler = torch.amp.GradScaler("cuda")
 
-    for x_0, mask_generate in ae_batches(
+    for x_0, mask_generate in batches(
         quiz_machine,
         args.nb_train_samples,
         data_structures,
@@ -611,12 +615,12 @@ def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_devi
         f"train_loss {n_epoch} model {model.id} {acc_train_loss/nb_train_samples}"
     )
 
-    model.test_accuracy = run_ae_test(
+    model.test_accuracy = run_test(
         model, quiz_machine, n_epoch, c_quizzes=None, local_device=local_device
     )
 
     if args.nb_test_alien_samples > 0:
-        run_ae_test(
+        run_test(
             model,
             alien_quiz_machine,
             n_epoch,
@@ -662,12 +666,15 @@ def quiz_validation(
     models,
     c_quizzes,
     local_device,
-    nb_have_to_be_correct=3,
-    nb_have_to_be_wrong=1,
-    nb_mistakes_to_be_wrong=5,
+    nb_have_to_be_correct,
+    nb_have_to_be_wrong,
+    nb_mistakes_to_be_wrong,
     nb_hints=0,
     nb_runs=1,
 ):
+    ######################################################################
+    # If too many with process per-batch
+
     if c_quizzes.size(0) > args.inference_batch_size:
         record = []
         for q in c_quizzes.split(args.inference_batch_size):
@@ -684,9 +691,12 @@ def quiz_validation(
                 )
             )
 
-        return (torch.cat([tk for tk, _ in record], dim=0)), (
-            torch.cat([w for _, w in record], dim=0)
-        )
+        r = []
+        for k in range(len(record[0])):
+            r.append(torch.cat([x[k] for x in record], dim=0))
+
+        return tuple(r)
+    ######################################################################
 
     record_wrong = []
     nb_correct, nb_wrong = 0, 0
@@ -704,22 +714,11 @@ def quiz_validation(
 
             sub_correct, sub_wrong = False, True
             for _ in range(nb_runs):
-                if nb_hints == 0:
-                    mask_hints = None
-                else:
-                    u = (
-                        torch.rand(mask_generate.size(), device=mask_generate.device)
-                        * mask_generate
-                    )
-                    mask_hints = (
-                        u > u.sort(dim=1, descending=True).values[:, nb_hints, None]
-                    ).long()
-
-                result = ae_generate(
+                result = diffuser.generate(
                     model=model,
                     x_0=c_quizzes,
                     mask_generate=mask_generate,
-                    mask_hints=mask_hints,
+                    nb_hints=nb_hints,
                 )
 
                 nb_mistakes = (result != c_quizzes).long().sum(dim=1)
@@ -746,7 +745,7 @@ def quiz_validation(
 ######################################################################
 
 
-def generate_ae_c_quizzes(models, nb, local_device=main_device):
+def generate_c_quizzes(models, nb, local_device=main_device):
     # To be thread-safe we must make copies
 
     def copy_for_inference(model):
@@ -776,7 +775,7 @@ def generate_ae_c_quizzes(models, nb, local_device=main_device):
                 quizzes=template, quad_order=quad_order, quad_mask=(1, 1, 1, 1)
             )
 
-            c_quizzes = ae_generate(model, template, mask_generate)
+            c_quizzes = diffuser.generate(model, template, mask_generate)
 
             to_keep = quiz_machine.problem.trivial(c_quizzes) == False
             c_quizzes = c_quizzes[to_keep]
@@ -786,6 +785,9 @@ def generate_ae_c_quizzes(models, nb, local_device=main_device):
                     models,
                     c_quizzes,
                     local_device,
+                    nb_have_to_be_correct=args.nb_have_to_be_correct,
+                    nb_have_to_be_wrong=args.nb_have_to_be_wrong,
+                    nb_mistakes_to_be_wrong=args.nb_mistakes_to_be_wrong,
                     nb_hints=args.nb_hints,
                     nb_runs=args.nb_runs,
                 )
@@ -839,21 +841,25 @@ def save_c_quizzes_with_scores(models, c_quizzes, filename, solvable_only=False)
     c_quizzes = c_quizzes.to(main_device)
 
     with torch.autograd.no_grad():
+        to_keep, nb_correct, nb_wrong, record_wrong = quiz_validation(
+            models,
+            c_quizzes,
+            main_device,
+            nb_have_to_be_correct=args.nb_have_to_be_correct,
+            nb_have_to_be_wrong=0,
+            nb_mistakes_to_be_wrong=args.nb_mistakes_to_be_wrong,
+            nb_hints=0,
+        )
+
         if solvable_only:
-            to_keep, nb_correct, nb_wrong, record_wrong = quiz_validation(
-                models,
-                c_quizzes,
-                main_device,
-                nb_have_to_be_correct=2,
-                nb_have_to_be_wrong=0,
-                nb_hints=0,
-            )
             c_quizzes = c_quizzes[to_keep]
+            nb_correct = nb_correct[to_keep]
+            nb_wrong = nb_wrong[to_keep]
 
-    comments = []
+        comments = []
 
-    for c, w in zip(nb_correct, nb_wrong):
-        comments.append("nb_correct {c} nb_wrong {w}")
+        for c, w in zip(nb_correct, nb_wrong):
+            comments.append(f"nb_correct {c} nb_wrong {w}")
 
     quiz_machine.problem.save_quizzes_as_image(
         args.result_dir,
@@ -922,7 +928,7 @@ if args.quizzes is not None:
             mask_generate = quiz_machine.make_quiz_mask(
                 quizzes=quizzes, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
             )
-            result = ae_generate(
+            result = generate(
                 model,
                 (1 - mask_generate) * quizzes,
                 mask_generate,
@@ -968,10 +974,9 @@ def multithread_execution(fun, arguments):
         records.append(fun(*args))
 
     for args in arguments:
-        t = threading.Thread(target=threadable_fun, daemon=True, args=args)
-
         # To get a different sequence between threads
         log_string(f"dummy_rand {torch.rand(1)}")
+        t = threading.Thread(target=threadable_fun, daemon=True, args=args)
         threads.append(t)
         t.start()
 
@@ -1039,7 +1044,7 @@ for n_epoch in range(current_epoch, args.nb_epochs):
         nb_c_quizzes_to_generate = (args.nb_c_quizzes + nb_gpus - 1) // nb_gpus
 
         c_quizzes, agreements = multithread_execution(
-            generate_ae_c_quizzes,
+            generate_c_quizzes,
             [(models, nb_c_quizzes_to_generate, gpu) for gpu in gpus],
         )
 
@@ -1057,7 +1062,7 @@ for n_epoch in range(current_epoch, args.nb_epochs):
             solvable_only=True,
         )
 
-        u = c_quizzes.reshape(c_quizzes.size(0), 4, -1)[:, 1:]
+        u = c_quizzes.reshape(c_quizzes.size(0), 4, -1)[:, :, 1:]
         i = (u[:, 2] != u[:, 3]).long().sum(dim=1).sort(descending=True).indices
 
         save_c_quizzes_with_scores(
@@ -1085,7 +1090,7 @@ for n_epoch in range(current_epoch, args.nb_epochs):
     # None if c_quizzes is None else c_quizzes[agreements[:, model.id]],
 
     multithread_execution(
-        one_ae_epoch,
+        one_epoch,
         [
             (model, quiz_machine, n_epoch, c_quizzes, gpu)
             for model, gpu in zip(weakest_models, gpus)