Update.
authorFrançois Fleuret <francois@fleuret.org>
Fri, 6 Sep 2024 20:49:33 +0000 (22:49 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Fri, 6 Sep 2024 20:49:33 +0000 (22:49 +0200)
main.py

diff --git a/main.py b/main.py
index 150010f..00f9175 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -101,7 +101,7 @@ parser.add_argument("--nb_models", type=int, default=5)
 
 parser.add_argument("--nb_diffusion_iterations", type=int, default=25)
 
-parser.add_argument("--diffusion_delta", type=float, default=0.05)
+parser.add_argument("--diffusion_delta", type=float, default=0.1)
 
 parser.add_argument("--diffusion_epsilon", type=float, default=0.01)
 
@@ -779,17 +779,54 @@ def deterministic(mask_generate):
 
 ######################################################################
 
+torch.set_printoptions(
+    precision=None,
+    threshold=None,
+    edgeitems=None,
+    linewidth=500,
+    profile=None,
+    sci_mode=None,
+)
+
 N = quiz_machine.problem.nb_colors
-T = 50
-MP = torch.empty(T, N, N)
-MP[0] = torch.eye(N)
-MP[1, :, 0] = args.diffusion_epsilon / (N - 1)
-MP[1, 0, 0] = 1 - args.diffusion_epsilon
-MP[1, :, 1:] = args.diffusion_delta / (N - 1)
+T = args.nb_diffusion_iterations + 1
+diffusion_M = torch.empty(T, N, N)
+diffusion_M[0] = torch.eye(N)
+
+# i >0 j>0
+# P(X'=0 | X=0) = 1-epsilon
+# P(X'=i | X=0) = epsilon/(N-1)
+# P(X'=0 | X=i) = delta
+# P(X'=X | X=i) = 1-epsilon-delta
+# P(X'=j | X=i) = epsilon/(N-2)
+
+diffusion_M[1, 0, 0] = 1 - args.diffusion_epsilon
+diffusion_M[1, 1:, 0] = args.diffusion_epsilon / (N - 1)
+diffusion_M[1, 0, 1:] = args.diffusion_delta
+diffusion_M[1, 1:, 1:] = args.diffusion_epsilon / (N - 2)
 for k in range(1, N):
-    MP[1, k, k] = 1 - args.diffusion_delta
+    diffusion_M[1, k, k] = 1 - args.diffusion_delta - args.diffusion_epsilon
+
+# m = diffusion_M[1]
+
+# print(m)
+# print(m.sum(dim=0))
+# print(torch.linalg.matrix_power(m, 25))
+
+# exit(0)
+
 for t in range(2, T):
-    MP[t] = MP[1] @ MP[t]
+    # diffusion_M[t] = diffusion_M[1] @ diffusion_M[t - 1]
+    diffusion_M[t] = torch.linalg.matrix_power(diffusion_M[1], t)
+
+# p = torch.full((N,), 1 / N)
+
+# for t in range(diffusion_M.size(0)):
+# print(diffusion_M[t] @ p)
+
+# print(diffusion_M[T-1])
+
+# exit(0)
 
 #
 # Given x_0 and t_0, t_1, ..., returns
@@ -798,11 +835,9 @@ for t in range(2, T):
 #
 
 
-def sample_x_t_given_x_0(x_0, t):
+def sample_x_t_given_x_0_(x_0, t):
     noise = torch.randint(quiz_machine.problem.nb_colors, x_0.size(), device=x_0.device)
-
     r = torch.rand(x_0.size(), device=x_0.device)
-
     proba_erased = 1 - (1 - args.diffusion_delta) ** t
     mask_erased = (r <= proba_erased[:, None]).long()
     x_t = (1 - mask_erased) * x_0 + mask_erased * noise
@@ -810,6 +845,15 @@ def sample_x_t_given_x_0(x_0, t):
     return x_t
 
 
+def sample_x_t_given_x_0(x_0, t):
+    D = diffusion_M[t.to("cpu")].permute(0, 2, 1).to(x_0.device)
+    mask = (x_0 < quiz_machine.problem.nb_colors).long()
+    probas = D.gather(dim=1, index=(mask * x_0)[:, :, None].expand(-1, -1, D.size(-1)))
+    dist = torch.distributions.categorical.Categorical(probs=probas)
+    x_t = (1 - mask) * x_0 + mask * dist.sample()
+    return x_t
+
+
 # This function returns a 2d tensor of same shape as low, full of
 # uniform random values in [0,1], such that, in every row, the values
 # corresponding to the True in low are all lesser than the values
@@ -825,7 +869,7 @@ def prioritized_rand(low):
     return y
 
 
-def sample_x_t_minus_1_given_x_0_x_t(x_0, x_t):
+def sample_x_t_minus_1_given_x_0_x_t_(x_0, x_t):
     r = prioritized_rand(x_0 != x_t)
 
     mask_changes = (r <= args.diffusion_delta).long()
@@ -835,6 +879,35 @@ def sample_x_t_minus_1_given_x_0_x_t(x_0, x_t):
     return x_t_minus_1
 
 
+def sample_x_t_minus_1_given_x_0_x_t(x_0, x_t, t):
+    mask = (x_0 < quiz_machine.problem.nb_colors).long()
+
+    # i = x_0[n,s], j = x_t[n,s]
+    # probas[n,s,k] = M[1,x_t[n,s],k] M[t[n]-1,x_0[n,s],k] / M[t[n],x_0[n,s],x_t[n,s]]
+
+    # A[n,s,k] = M[1,x_t[n,s],k]
+    # B[n,s,k] = M[t[n]-1,x_0[n,s],k]
+    # C[n,s,k] = M[t[n],x_0[n,s],x_t[n,s]]
+    # probas = A * B / C
+
+    N, S, K = x_0.size(0), x_0.size(1), diffusion_M.size(1)
+
+    _1 = x_0.new_full((N, S, K), 1)
+    _t = x_0.new_full((N, S, K), t)
+    _k = torch.arange(K, device=x_0.device)[None, None, :].expand(N, S, K)
+    _x_t = (mask * x_t)[:, :, None].expand(N, S, K)
+    _x_0 = (mask * x_0)[:, :, None].expand(N, S, K)
+
+    M = diffusion_M.to(x_0.device)
+
+    probas = M[_1, _x_t, _k] * M[_t - 1, _x_0, _k] / M[_t, _x_0, _x_t]
+
+    dist = torch.distributions.categorical.Categorical(probs=probas)
+    x_t_minus_1 = (1 - mask) * x_0 + mask * dist.sample()
+
+    return x_t_minus_1
+
+
 ######################################################################
 
 # This function gets a clean target x_0, and a mask indicating which
@@ -862,6 +935,19 @@ def logits_hat_x_0_from_random_iteration(model, x_0, mask_generate, prompt_noise
 
     x_t = (1 - mask_generate) * x_0 + mask_generate * x_t
 
+    #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+    # filename = f"debug.png"
+
+    # quiz_machine.problem.save_quizzes_as_image(
+    # args.result_dir,
+    # filename,
+    # quizzes=x_t,
+    # )
+
+    # log_string(f"wrote {filename}")
+    # exit(0)
+    #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+
     # We may inject noise to prevent high-complexity non-structure
     # signal to be generated as a way of "increasing reasoning
     # complexity"
@@ -886,9 +972,9 @@ def logits_hat_x_0_from_random_iteration(model, x_0, mask_generate, prompt_noise
 
 
 def ae_generate(model, x_0, mask_generate, nb_iterations_max=50):
-    noise = torch.randint(quiz_machine.problem.nb_colors, x_0.size(), device=x_0.device)
+    noise = torch.randint(quiz_machine.problem.nb_colors, x_0.size(), device=x_0.device)
 
-    x_t = (1 - mask_generate) * x_0 + mask_generate * noise
+    x_t = (1 - mask_generate) * x_0  # + mask_generate * noise
 
     one_iteration_prediction = deterministic(mask_generate)[:, None]
 
@@ -897,13 +983,16 @@ def ae_generate(model, x_0, mask_generate, nb_iterations_max=50):
     for it in range(nb_iterations_max):
         x_t_with_mask = NTC_channel_cat(x_t, mask_generate)
         logits = model(x_t_with_mask)
+        logits[:, :, quiz_machine.problem.nb_colors :] = float("-inf")
         dist = torch.distributions.categorical.Categorical(logits=logits)
 
         hat_x_0 = (1 - mask_generate) * x_0 + mask_generate * dist.sample()
 
         hat_x_t_minus_1 = one_iteration_prediction * hat_x_0 + (
             1 - one_iteration_prediction
-        ) * sample_x_t_minus_1_given_x_0_x_t(hat_x_0, x_t)
+        ) * sample_x_t_minus_1_given_x_0_x_t(
+            hat_x_0, x_t, max(1, args.nb_diffusion_iterations - it)
+        )
 
         if hat_x_t_minus_1.equal(x_t):
             # log_string(f"exit after {it+1} iterations")
@@ -1066,10 +1155,10 @@ def run_ae_test(
         # Save some images
 
         for f, record in [("prediction", record_d), ("generation", record_nd)]:
-            filename = f"{prefix}culture_{f}_{n_epoch:04d}_{model.id:02d}.png"
-
             result, predicted_parts, correct_parts = bag_to_tensors(record)
 
+            filename = f"{prefix}culture_{f}_{n_epoch:04d}_{model.id:02d}.png"
+
             quiz_machine.problem.save_quizzes_as_image(
                 args.result_dir,
                 filename,
@@ -1447,7 +1536,7 @@ for n_epoch in range(current_epoch, args.nb_epochs):
 
     # exit(0)
 
-    # one_ae_epoch(models[0], quiz_machine, n_epoch, main_device)
+    # one_ae_epoch(models[0], quiz_machine, n_epoch, None, main_device)
     # exit(0)
 
     log_string(f"{time_train=} {time_c_quizzes=}")