Update.
authorFrançois Fleuret <francois@fleuret.org>
Fri, 6 Sep 2024 10:16:21 +0000 (12:16 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Fri, 6 Sep 2024 10:16:21 +0000 (12:16 +0200)
main.py

diff --git a/main.py b/main.py
index d1a1c8f..1e398e8 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -732,7 +732,7 @@ def ae_batches(
 ):
     c_quiz_bags = [] if c_quizzes is None else [c_quizzes.to("cpu")]
 
-    full_input, full_mask_generate, full_mask_loss = quiz_machine.data_input(
+    full_input, full_mask_generate, _ = quiz_machine.data_input(
         nb,
         c_quiz_bags,
         data_structures=data_structures,
@@ -742,7 +742,6 @@ def ae_batches(
     src = zip(
         full_input.split(batch_size),
         full_mask_generate.split(batch_size),
-        full_mask_loss.split(batch_size),
     )
 
     if desc is not None:
@@ -753,11 +752,10 @@ def ae_batches(
             total=full_input.size(0) // batch_size,
         )
 
-    for input, mask_generate, mask_loss in src:
+    for input, mask_generate in src:
         yield (
             input.to(local_device),
             mask_generate.to(local_device),
-            mask_loss.to(local_device),
         )
 
 
@@ -777,23 +775,23 @@ def deterministic(mask_generate):
 ######################################################################
 
 #
-# Given x_0 and t_0, t_1, ..., returns x_{t_0}, x_{t_1}, with
+# Given x_0 and t_0, t_1, ..., returns
 #
-#    x_{t_k} ~ P(X_{t_k} | X_0=x_0)
+#    x_{t_0}, ..., x_{t_K} ~ P(X_{t_0}, ..., X_{t_K} | X_0=x_0)
 #
 
 
-def degrade_input_to_generate(x0, mask_generate, steps_nb_iterations):
-    noise = torch.randint(quiz_machine.problem.nb_colors, x0.size(), device=x0.device)
+def degrade_input_to_generate(x_0, steps_nb_iterations):
+    noise = torch.randint(quiz_machine.problem.nb_colors, x_0.size(), device=x_0.device)
 
-    r = torch.rand(mask_generate.size(), device=mask_generate.device)
+    r = torch.rand(x_0.size(), device=x_0.device)
 
     result = []
 
     for n in steps_nb_iterations:
         proba_erased = 1 - (1 - args.diffusion_noise_proba) ** n
-        mask_erased = mask_generate * (r <= proba_erased[:, None]).long()
-        x = (1 - mask_erased) * x0 + mask_erased * noise
+        mask_erased = (r <= proba_erased[:, None]).long()
+        x = (1 - mask_erased) * x_0 + mask_erased * noise
         result.append(x)
 
     return result
@@ -801,44 +799,45 @@ def degrade_input_to_generate(x0, mask_generate, steps_nb_iterations):
 
 ######################################################################
 
-# Given x_t and a mas
 
-
-def targets_and_logits(model, input, mask_generate, prompt_noise=0.0):
-    d = deterministic(mask_generate)
+def logits_hat_x_0_from_random_iteration(model, x_0, mask_generate, prompt_noise=0.0):
+    # We favor iterations near the clean signal
 
     probs_iterations = 0.1 ** torch.linspace(
-        0, 1, args.nb_diffusion_iterations, device=input.device
+        0, 1, args.nb_diffusion_iterations, device=x_0.device
     )
 
     probs_iterations = probs_iterations[None, :] / probs_iterations.sum()
-    probs_iterations = probs_iterations.expand(input.size(0), -1)
+    probs_iterations = probs_iterations.expand(x_0.size(0), -1)
     dist = torch.distributions.categorical.Categorical(probs=probs_iterations)
 
-    # N0 = dist.sample()
-    # N1 = N0 + 1
-    # N0 = (1 - d) * N0
-    # N1 = (1 - d) * N1 + d * args.nb_diffusion_iterations
-
-    N0 = input.new_zeros(input.size(0))
     N1 = dist.sample() + 1
 
-    targets, input = degrade_input_to_generate(input, mask_generate, (N0, N1))
+    (x_t,) = degrade_input_to_generate(x_0, (N1,))
+
+    # Only the part to generate is degraded, the rest is a perfect
+    # noise-free conditionning
+
+    x_t = (1 - mask_generate) * x_0 + mask_generate * x_t
+
+    # We may inject noise to prevent high-complexity non-structure
+    # signal to be generated as a way of "increasing reasoning
+    # complexity"
 
     if prompt_noise > 0:
         mask_prompt_noise = (
-            torch.rand(input.size(), device=input.device) <= prompt_noise
+            torch.rand(x_t.size(), device=x_t.device) <= prompt_noise
         ).long()
         noise = torch.randint(
-            quiz_machine.problem.nb_colors, input.size(), device=input.device
+            quiz_machine.problem.nb_colors, x_t.size(), device=x_t.device
         )
-        noisy_input = (1 - mask_prompt_noise) * input + mask_prompt_noise * noise
-        input = (1 - mask_generate) * noisy_input + mask_generate * input
+        noisy_x_t = (1 - mask_prompt_noise) * x_t + mask_prompt_noise * noise
+        x_t = (1 - mask_generate) * noisy_x_t + mask_generate * x_t
 
-    input_with_mask = NTC_channel_cat(input, mask_generate)
-    logits = model(input_with_mask)
+    x_t_with_mask = NTC_channel_cat(x_t, mask_generate)
+    logits_hat_x_0 = model(x_t_with_mask)
 
-    return targets, logits
+    return logits_hat_x_0
 
 
 ######################################################################
@@ -858,42 +857,38 @@ def prioritized_rand(low):
     return y
 
 
-def ae_generate(model, input, mask_generate, nb_iterations_max=50):
-    noise = torch.randint(
-        quiz_machine.problem.nb_colors, input.size(), device=input.device
-    )
+def ae_generate(model, x_0, mask_generate, nb_iterations_max=50):
+    noise = torch.randint(quiz_machine.problem.nb_colors, x_0.size(), device=x_0.device)
 
-    input = (1 - mask_generate) * input + mask_generate * noise
+    x_t = (1 - mask_generate) * x_0 + mask_generate * noise
 
-    d = deterministic(mask_generate)[:, None]
+    one_iteration_prediction = deterministic(mask_generate)[:, None]
 
     changed = True
 
     for it in range(nb_iterations_max):
-        input_with_mask = NTC_channel_cat(input, mask_generate)
-        logits = model(input_with_mask)
+        x_t_with_mask = NTC_channel_cat(x_t, mask_generate)
+        logits = model(x_t_with_mask)
         dist = torch.distributions.categorical.Categorical(logits=logits)
-        final = dist.sample()
 
-        r = prioritized_rand(final != input)
+        hat_x_0 = (1 - mask_generate) * x_0 + mask_generate * dist.sample()
 
-        mask_erased = mask_generate * (r <= args.diffusion_noise_proba).long()
+        r = prioritized_rand(hat_x_0 != x_t)
 
-        mask_to_change = d * mask_generate + (1 - d) * mask_erased
+        mask_changes = (r <= args.diffusion_noise_proba).long()
 
-        update = (1 - mask_to_change) * input + mask_to_change * final
+        hat_x_t_minus_1 = one_iteration_prediction * hat_x_0 + (
+            1 - one_iteration_prediction
+        ) * ((1 - mask_changes) * x_t + mask_changes * hat_x_0)
 
-        if update.equal(input):
+        if hat_x_t_minus_1.equal(x_t):
             # log_string(f"exit after {it+1} iterations")
             break
         else:
-            changed = changed & (update != input).max(dim=1).values
-            input[changed] = update[changed]
-
-    # if it == nb_iterations_max:
-    # log_string(f"remains {changed.long().sum()}")
+            changed = changed & (hat_x_t_minus_1 != x_t).max(dim=1).values
+            x_t[changed] = hat_x_t_minus_1[changed]
 
-    return input
+    return x_t
 
 
 ######################################################################
@@ -902,18 +897,18 @@ def ae_generate(model, input, mask_generate, nb_iterations_max=50):
 def model_ae_proba_solutions(model, input, log_proba=False):
     record = []
 
-    for q in input.split(args.batch_size):
+    for x_0 in input.split(args.batch_size):
         loss = 0
 
         for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
             mask_generate = quiz_machine.make_quiz_mask(
                 quizzes=q, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
             )
-            targets, logits = targets_and_logits(
-                model, q, mask_generate, prompt_noise=args.prompt_noise
+            logits = logits_hat_x_0_from_random_iteration(
+                model, x_0, mask_generate, prompt_noise=args.prompt_noise
             )
             loss_per_token = F.cross_entropy(
-                logits.transpose(1, 2), targets, reduction="none"
+                logits.transpose(1, 2), x_0, reduction="none"
             )
             loss += (loss_per_token * mask_generate).sum(dim=1)
         record.append(loss)
@@ -929,20 +924,20 @@ def model_ae_proba_solutions(model, input, log_proba=False):
 def model_ae_argmax_nb_disagreements(model, input):
     record = []
 
-    for q in input.split(args.batch_size):
+    for x_0 in input.split(args.batch_size):
         nb_disagreements = 0
         for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
             mask_generate = quiz_machine.make_quiz_mask(
                 quizzes=q, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
             )
-            targets, logits = targets_and_logits(
-                model, q, mask_generate, prompt_noise=args.prompt_noise
+            logits = logits_hat_x_0_from_random_iteration(
+                model, x_0, mask_generate, prompt_noise=args.prompt_noise
             )
 
             predicted = logits.argmax(dim=-1)
 
             nb_disagreements = nb_disagreements + (
-                mask_generate * predicted != mask_generate * targets
+                mask_generate * predicted != mask_generate * x_0
             ).long().sum(dim=1)
 
         record.append(nb_disagreements)
@@ -957,18 +952,18 @@ def model_ae_argmax_predictions(model, input):
     result = input.clone()
     # result[...] = 0
 
-    for r, q in zip(result.split(args.batch_size), input.split(args.batch_size)):
+    for r, x_0 in zip(result.split(args.batch_size), input.split(args.batch_size)):
         for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
             mask_generate = quiz_machine.make_quiz_mask(
                 quizzes=q, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
             )
-            targets, logits = targets_and_logits(
-                model, q, mask_generate, prompt_noise=args.prompt_noise
+            logits = logits_hat_x_0_from_random_iteration(
+                model, x_0, mask_generate, prompt_noise=args.prompt_noise
             )
 
-            predicted = logits.argmax(dim=-1)
+            hat_x_0 = logits.argmax(dim=-1)
 
-            r[...] = (1 - mask_generate) * r + mask_generate * predicted
+            r[...] = (1 - mask_generate) * r + mask_generate * hat_x_0
 
     return result
 
@@ -991,7 +986,7 @@ def run_ae_test(
 
         nb_test_samples, acc_test_loss = 0, 0.0
 
-        for input, mask_generate, mask_loss in ae_batches(
+        for x_0, mask_generate in ae_batches(
             quiz_machine,
             args.nb_test_samples,
             data_structures,
@@ -999,10 +994,10 @@ def run_ae_test(
             c_quizzes=c_quizzes,
             desc="test",
         ):
-            targets, logits = targets_and_logits(model, input, mask_generate)
-            loss = NTC_masked_cross_entropy(logits, targets, mask_loss)
-            acc_test_loss += loss.item() * input.size(0)
-            nb_test_samples += input.size(0)
+            logits = logits_hat_x_0_from_random_iteration(model, x_0, mask_generate)
+            loss = NTC_masked_cross_entropy(logits, x_0, mask_generate)
+            acc_test_loss += loss.item() * x_0.size(0)
+            nb_test_samples += x_0.size(0)
 
         log_string(
             f"{prefix}test_loss {n_epoch} model {model.id} {acc_test_loss/nb_test_samples}"
@@ -1012,7 +1007,7 @@ def run_ae_test(
 
         nb_correct, nb_total, record_d, record_nd = 0, 0, [], []
 
-        for input, mask_generate, mask_loss in ae_batches(
+        for x_0, mask_generate in ae_batches(
             quiz_machine,
             args.nb_test_samples,
             data_structures,
@@ -1020,13 +1015,12 @@ def run_ae_test(
             c_quizzes=c_quizzes,
             desc="test",
         ):
-            targets = input.clone()
             result = ae_generate(
                 model,
-                (1 - mask_generate) * input,
+                (1 - mask_generate) * x_0,
                 mask_generate,
             )
-            correct = (result == targets).min(dim=1).values.long()
+            correct = (result == x_0).min(dim=1).values.long()
             predicted_parts = mask_generate.reshape(mask_generate.size(0), 4, -1)[
                 :, :, 1
             ]
@@ -1052,49 +1046,16 @@ def run_ae_test(
 
             result, predicted_parts, correct_parts = bag_to_tensors(record)
 
-            # l = [model_ae_proba_solutions(model, result) for model in other_models]
-            # probas = torch.cat([x[:, None] for x in l], dim=1)
-            # comments = []
-
-            # for l in probas:
-            # comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l]))
-
             quiz_machine.problem.save_quizzes_as_image(
                 args.result_dir,
                 filename,
                 quizzes=result[:128],
                 predicted_parts=predicted_parts[:128],
                 correct_parts=correct_parts[:128],
-                # comments=comments,
             )
 
             log_string(f"wrote {filename}")
 
-        # Prediction with functional perturbations
-
-        # input, mask_generate, mask_loss = next(
-        # ae_batches(
-        # quiz_machine,
-        # [
-        # (
-        # ("A", "f_A", "B", "f_B"),
-        # (0, 0, 0, 1),
-        # (0, 0, 1, 0),
-        # (0, 0, 0, 1),
-        # ),
-        # ],
-        # local_device,
-        # desc=None,
-        # )
-        # )
-        # targets = input.clone()
-        # p = torch.rand(4,model.f_tokens.size(1)).sort(dim=1).indices
-        # def change_theta(theta_A, theta_B):
-        # theta
-        # result = ae_generate(
-        # model, (1 - mask_generate) * input, mask_generate
-        # )
-
 
 ######################################################################
 
@@ -1105,7 +1066,7 @@ def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_devi
 
     nb_train_samples, acc_train_loss = 0, 0.0
 
-    for input, mask_generate, mask_loss in ae_batches(
+    for x_0, mask_generate in ae_batches(
         quiz_machine,
         args.nb_train_samples,
         data_structures,
@@ -1113,20 +1074,19 @@ def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_devi
         c_quizzes,
         "training",
     ):
-        input = input.to(local_device)
+        x_0 = x_0.to(local_device)
         mask_generate = mask_generate.to(local_device)
-        mask_loss = mask_loss.to(local_device)
 
         if nb_train_samples % args.batch_size == 0:
             model.optimizer.zero_grad()
 
-        targets, logits = targets_and_logits(
-            model, input, mask_generate, prompt_noise=args.prompt_noise
+        logits = logits_hat_x_0_from_random_iteration(
+            model, x_0, mask_generate, prompt_noise=args.prompt_noise
         )
 
-        loss = NTC_masked_cross_entropy(logits, targets, mask_loss)
-        acc_train_loss += loss.item() * input.size(0)
-        nb_train_samples += input.size(0)
+        loss = NTC_masked_cross_entropy(logits, x_0, mask_generate)
+        acc_train_loss += loss.item() * x_0.size(0)
+        nb_train_samples += x_0.size(0)
 
         loss.backward()