Update.
authorFrançois Fleuret <francois@fleuret.org>
Sun, 15 Sep 2024 20:02:09 +0000 (22:02 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sun, 15 Sep 2024 20:02:09 +0000 (22:02 +0200)
diffusion.py
main.py

index 2dc5861..abe8986 100755 (executable)
@@ -52,17 +52,18 @@ class Diffuser:
 
     ######################################################################
 
-    def make_mask_hints(self, mask_generate, nb_hints):
-        if nb_hints == 0:
+    def make_mask_hints(mask_generate, nb_hints):
+        if nb_hints is None:
             mask_hints = None
         else:
             u = (
                 torch.rand(mask_generate.size(), device=mask_generate.device)
                 * mask_generate
             )
-            mask_hints = (
-                u > u.sort(dim=1, descending=True).values[:, nb_hints, None]
-            ).long()
+            v = u.sort(dim=1, descending=True).values.gather(
+                dim=1, index=nb_hints[:, None]
+            )
+            mask_hints = (u > v).long()
 
         return mask_hints
 
@@ -71,7 +72,7 @@ class Diffuser:
     # logits starting from a x_t|X_0=x_0 picked at random with t random
 
     def logits_hat_x_0_from_random_iteration(
-        self, model, x_0, mask_generate, nb_hints=0, prompt_noise=0.0
+        self, model, x_0, mask_generate, nb_hints=None, prompt_noise=0.0
     ):
         noise = self.mu_T_sampler(x_0.size(), device=x_0.device)
 
@@ -79,12 +80,7 @@ class Diffuser:
             mask_generate.sum(dim=1) < mask_generate.size(1) // 2
         ).long()[:, None]
 
-        mask_hints = self.make_mask_hints(mask_generate, nb_hints)
-
-        if mask_hints is None:
-            mask_start = mask_generate
-        else:
-            mask_start = mask_generate * (1 - mask_hints)
+        mask_hints = self.make_mask_hints(mask_generate, nb_hints) * single_iteration
 
         # We favor iterations near the clean signal
 
@@ -98,13 +94,9 @@ class Diffuser:
 
         t = dist.sample() + 1
 
-        x_t = single_iteration * noise + (
-            1 - single_iteration
-        ) * self.sample_x_t_given_x_0(x_0, t)
-
-        # Only the part to generate is degraded, the rest is a perfect
-        # noise-free conditionning
-
+        x_T_with_hints = mask_hints * x_0 + (1 - mask_hint) * noise
+        x_t = self.sample_x_t_given_x_0(x_0, t)
+        x_t = single_iteration * x_T_with_hints + (1 - single_iteration) * x_t
         x_t = (1 - mask_generate) * x_0 + mask_generate * x_t
 
         # We may inject noise to prevent high-complexity non-structure
@@ -128,7 +120,7 @@ class Diffuser:
 
     ######################################################################
 
-    def generate(self, model, x_0, mask_generate, nb_hints=0):
+    def generate(self, model, x_0, mask_generate, nb_hints=None):
         noise = self.mu_T_sampler(x_0.size(), device=x_0.device)
 
         single_iteration = (
@@ -137,12 +129,10 @@ class Diffuser:
 
         mask_hints = self.make_mask_hints(mask_generate, nb_hints)
 
-        if mask_hints is None:
-            mask_start = mask_generate
-        else:
-            mask_start = mask_generate * (1 - mask_hints)
-
-        x_t = (1 - mask_start) * x_0 + mask_start * noise
+        x_T_with_hints = mask_hints * x_0 + (1 - mask_hint) * noise
+        x_t = self.sample_x_t_given_x_0(x_0, t)
+        x_t = single_iteration * x_T_with_hints + (1 - single_iteration) * x_t
+        x_t = (1 - mask_generate) * x_0 + mask_generate * x_t
 
         changed = True
 
@@ -150,7 +140,6 @@ class Diffuser:
             x_t_with_mask = NTC_channel_cat(x_t, mask_generate)
             with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
                 logits = model(x_t_with_mask)
-            # logits[:, :, quiz_machine.problem.nb_colors :] = float("-inf")
             dist = torch.distributions.categorical.Categorical(logits=logits)
 
             hat_x_0 = (1 - mask_generate) * x_0 + mask_generate * dist.sample()
diff --git a/main.py b/main.py
index d508c97..0d46aa2 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -405,7 +405,10 @@ def model_proba_solutions(model, input, log_probas=False, reduce=True):
                 quizzes=x_0, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
             )
             logits = logits_hat_x_0_from_random_iteration(
-                model, x_0, mask_generate, prompt_noise=args.prompt_noise
+                model=model,
+                x_0=x_0,
+                mask_generate=mask_generate,
+                prompt_noise=args.prompt_noise,
             )
             loss_per_token = F.cross_entropy(
                 logits.transpose(1, 2), x_0, reduction="none"
@@ -543,21 +546,20 @@ def run_test(
 
         # Save some images
 
-        if n_epoch < 100:
-            for f, record in [("prediction", record_d), ("generation", record_nd)]:
-                result, predicted_parts, correct_parts = bag_to_tensors(record)
+        for f, record in [("prediction", record_d), ("generation", record_nd)]:
+            result, predicted_parts, correct_parts = bag_to_tensors(record)
 
-                filename = f"{prefix}culture_{f}_{n_epoch:04d}_{model.id:02d}.png"
+            filename = f"{prefix}culture_{f}_{n_epoch:04d}_{model.id:02d}.png"
 
-                quiz_machine.problem.save_quizzes_as_image(
-                    args.result_dir,
-                    filename,
-                    quizzes=result[:128],
-                    predicted_parts=predicted_parts[:128],
-                    correct_parts=correct_parts[:128],
-                )
+            quiz_machine.problem.save_quizzes_as_image(
+                args.result_dir,
+                filename,
+                quizzes=result[:128],
+                predicted_parts=predicted_parts[:128],
+                correct_parts=correct_parts[:128],
+            )
 
-                log_string(f"wrote {filename}")
+            log_string(f"wrote {filename}")
 
         return nb_correct / nb_total
 
@@ -587,12 +589,15 @@ def one_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_device)
         if nb_train_samples % args.batch_size == 0:
             model.optimizer.zero_grad()
 
+        nb_hints = torch.randint(2, (x_0.size(0),), device=x_0.device) * args.nb_hints
+
         with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
             logits = diffuser.logits_hat_x_0_from_random_iteration(
                 model=model,
                 x_0=x_0,
                 mask_generate=mask_generate,
                 prompt_noise=args.prompt_noise,
+                nb_hints=nb_hints,
             )
 
         loss = NTC_masked_cross_entropy(logits, x_0, mask_generate)
@@ -669,7 +674,7 @@ def quiz_validation(
     nb_have_to_be_correct,
     nb_have_to_be_wrong,
     nb_mistakes_to_be_wrong,
-    nb_hints=0,
+    nb_hints,
     nb_runs=1,
 ):
     ######################################################################
@@ -677,7 +682,10 @@ def quiz_validation(
 
     if c_quizzes.size(0) > args.inference_batch_size:
         record = []
-        for q in c_quizzes.split(args.inference_batch_size):
+        for q, nh in zip(
+            c_quizzes.split(args.inference_batch_size),
+            nb_hints.split(args.inference_batch_size),
+        ):
             record.append(
                 quiz_validation(
                     models=models,
@@ -686,7 +694,7 @@ def quiz_validation(
                     nb_have_to_be_correct=nb_have_to_be_correct,
                     nb_have_to_be_wrong=nb_have_to_be_wrong,
                     nb_mistakes_to_be_wrong=nb_mistakes_to_be_wrong,
-                    nb_hints=nb_hints,
+                    nb_hints=nh,
                     nb_runs=nb_runs,
                 )
             )
@@ -732,9 +740,6 @@ def quiz_validation(
         nb_correct += correct.long()
         nb_wrong += wrong.long()
 
-    # log_string(f"{nb_hints=} {nb_correct=}")
-    # log_string(f"{nb_hints=} {nb_wrong=}")
-
     to_keep = (nb_correct >= nb_have_to_be_correct) & (nb_wrong >= nb_have_to_be_wrong)
 
     wrong = torch.cat(record_wrong, dim=1)
@@ -780,6 +785,10 @@ def generate_c_quizzes(models, nb, local_device=main_device):
             to_keep = quiz_machine.problem.trivial(c_quizzes) == False
             c_quizzes = c_quizzes[to_keep]
 
+            nb_hints = torch.full(
+                (c_quizzes.size(0),), args.nb_hints, device=c_quizzes.device
+            )
+
             if c_quizzes.size(0) > 0:
                 to_keep, nb_correct, nb_wrong, record_wrong = quiz_validation(
                     models,
@@ -788,7 +797,7 @@ def generate_c_quizzes(models, nb, local_device=main_device):
                     nb_have_to_be_correct=args.nb_have_to_be_correct,
                     nb_have_to_be_wrong=args.nb_have_to_be_wrong,
                     nb_mistakes_to_be_wrong=args.nb_mistakes_to_be_wrong,
-                    nb_hints=args.nb_hints,
+                    nb_hints=nb_hints,
                     nb_runs=args.nb_runs,
                 )
 
@@ -848,7 +857,7 @@ def save_c_quizzes_with_scores(models, c_quizzes, filename, solvable_only=False)
             nb_have_to_be_correct=args.nb_have_to_be_correct,
             nb_have_to_be_wrong=0,
             nb_mistakes_to_be_wrong=args.nb_mistakes_to_be_wrong,
-            nb_hints=0,
+            nb_hints=None,
         )
 
         if solvable_only: