Update.
authorFrançois Fleuret <francois@fleuret.org>
Fri, 6 Sep 2024 10:27:13 +0000 (12:27 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Fri, 6 Sep 2024 10:27:13 +0000 (12:27 +0200)
main.py

diff --git a/main.py b/main.py
index 1e398e8..4c30771 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -781,7 +781,7 @@ def deterministic(mask_generate):
 #
 
 
-def degrade_input_to_generate(x_0, steps_nb_iterations):
+def sample_x_t_given_x_0(x_0, steps_nb_iterations):
     noise = torch.randint(quiz_machine.problem.nb_colors, x_0.size(), device=x_0.device)
 
     r = torch.rand(x_0.size(), device=x_0.device)
@@ -797,8 +797,37 @@ def degrade_input_to_generate(x_0, steps_nb_iterations):
     return result
 
 
+# This function returns a 2d tensor of same shape as low, full of
+# uniform random values in [0,1], such that, in every row, the values
+# corresponding to the True in low are all lesser than the values
+# corresponding to the False.
+
+
+def prioritized_rand(low):
+    x = torch.rand(low.size(), device=low.device).sort(dim=1, descending=True).values
+    k = torch.rand(low.size(), device=low.device) + low.long()
+    k = k.sort(dim=1).indices
+    y = x.new(x.size())
+    y.scatter_(dim=1, index=k, src=x)
+    return y
+
+
+def sample_x_t_minus_1_given_x_0_x_t(x_0, x_t):
+    r = prioritized_rand(x_0 != x_t)
+
+    mask_changes = (r <= args.diffusion_noise_proba).long()
+
+    x_t_minus_1 = (1 - mask_changes) * x_t + mask_changes * x_0
+
+    return result
+
+
 ######################################################################
 
+# This function gets a clean target x_0, and a mask indicating which
+# part to generate (conditionnaly to the others), and returns the
+# logits starting from a x_t|X_0=x_0 picked at random with t random
+
 
 def logits_hat_x_0_from_random_iteration(model, x_0, mask_generate, prompt_noise=0.0):
     # We favor iterations near the clean signal
@@ -811,9 +840,9 @@ def logits_hat_x_0_from_random_iteration(model, x_0, mask_generate, prompt_noise
     probs_iterations = probs_iterations.expand(x_0.size(0), -1)
     dist = torch.distributions.categorical.Categorical(probs=probs_iterations)
 
-    N1 = dist.sample() + 1
+    t_1 = dist.sample() + 1
 
-    (x_t,) = degrade_input_to_generate(x_0, (N1,))
+    (x_t,) = sample_x_t_given_x_0(x_0, (t_1,))
 
     # Only the part to generate is degraded, the rest is a perfect
     # noise-free conditionning
@@ -842,20 +871,6 @@ def logits_hat_x_0_from_random_iteration(model, x_0, mask_generate, prompt_noise
 
 ######################################################################
 
-# This function returns a 2d tensor of same shape as low, full of
-# uniform random values in [0,1], such that, in every row, the values
-# corresponding to the True in low are all lesser than the values
-# corresponding to the False.
-
-
-def prioritized_rand(low):
-    x = torch.rand(low.size(), device=low.device).sort(dim=1, descending=True).values
-    k = torch.rand(low.size(), device=low.device) + low.long()
-    k = k.sort(dim=1).indices
-    y = x.new(x.size())
-    y.scatter_(dim=1, index=k, src=x)
-    return y
-
 
 def ae_generate(model, x_0, mask_generate, nb_iterations_max=50):
     noise = torch.randint(quiz_machine.problem.nb_colors, x_0.size(), device=x_0.device)
@@ -873,13 +888,9 @@ def ae_generate(model, x_0, mask_generate, nb_iterations_max=50):
 
         hat_x_0 = (1 - mask_generate) * x_0 + mask_generate * dist.sample()
 
-        r = prioritized_rand(hat_x_0 != x_t)
-
-        mask_changes = (r <= args.diffusion_noise_proba).long()
-
-        hat_x_t_minus_1 = one_iteration_prediction * hat_x_0 + (
+        hat_x_t_minus_1 = one_iteration_prediction * x_0 + (
             1 - one_iteration_prediction
-        ) * ((1 - mask_changes) * x_t + mask_changes * hat_x_0)
+        ) * sample_x_t_minus_1_given_x_0_x_t(hat_x_0, x_t)
 
         if hat_x_t_minus_1.equal(x_t):
             # log_string(f"exit after {it+1} iterations")