Update.
authorFrançois Fleuret <francois@fleuret.org>
Fri, 6 Sep 2024 12:55:45 +0000 (14:55 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Fri, 6 Sep 2024 12:55:45 +0000 (14:55 +0200)
main.py

diff --git a/main.py b/main.py
index 21c609c..150010f 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -101,7 +101,9 @@ parser.add_argument("--nb_models", type=int, default=5)
 
 parser.add_argument("--nb_diffusion_iterations", type=int, default=25)
 
-parser.add_argument("--diffusion_noise_proba", type=float, default=0.05)
+parser.add_argument("--diffusion_delta", type=float, default=0.05)
+
+parser.add_argument("--diffusion_epsilon", type=float, default=0.01)
 
 parser.add_argument("--min_succeed_to_validate", type=int, default=2)
 
@@ -284,6 +286,9 @@ else:
 assert args.nb_train_samples % args.batch_size == 0
 assert args.nb_test_samples % args.batch_size == 0
 
+######################################################################
+
+
 # ------------------------------------------------------
 alien_problem = grids.Grids(
     max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
@@ -774,6 +779,18 @@ def deterministic(mask_generate):
 
 ######################################################################
 
+N = quiz_machine.problem.nb_colors
+T = 50
+MP = torch.empty(T, N, N)
+MP[0] = torch.eye(N)
+MP[1, :, 0] = args.diffusion_epsilon / (N - 1)
+MP[1, 0, 0] = 1 - args.diffusion_epsilon
+MP[1, :, 1:] = args.diffusion_delta / (N - 1)
+for k in range(1, N):
+    MP[1, k, k] = 1 - args.diffusion_delta
+for t in range(2, T):
+    MP[t] = MP[1] @ MP[t]
+
 #
 # Given x_0 and t_0, t_1, ..., returns
 #
@@ -781,20 +798,16 @@ def deterministic(mask_generate):
 #
 
 
-def sample_x_t_given_x_0(x_0, steps_nb_iterations):
+def sample_x_t_given_x_0(x_0, t):
     noise = torch.randint(quiz_machine.problem.nb_colors, x_0.size(), device=x_0.device)
 
     r = torch.rand(x_0.size(), device=x_0.device)
 
-    result = []
-
-    for n in steps_nb_iterations:
-        proba_erased = 1 - (1 - args.diffusion_noise_proba) ** n
-        mask_erased = (r <= proba_erased[:, None]).long()
-        x = (1 - mask_erased) * x_0 + mask_erased * noise
-        result.append(x)
+    proba_erased = 1 - (1 - args.diffusion_delta) ** t
+    mask_erased = (r <= proba_erased[:, None]).long()
+    x_t = (1 - mask_erased) * x_0 + mask_erased * noise
 
-    return result
+    return x_t
 
 
 # This function returns a 2d tensor of same shape as low, full of
@@ -815,7 +828,7 @@ def prioritized_rand(low):
 def sample_x_t_minus_1_given_x_0_x_t(x_0, x_t):
     r = prioritized_rand(x_0 != x_t)
 
-    mask_changes = (r <= args.diffusion_noise_proba).long()
+    mask_changes = (r <= args.diffusion_delta).long()
 
     x_t_minus_1 = (1 - mask_changes) * x_t + mask_changes * x_0
 
@@ -840,9 +853,9 @@ def logits_hat_x_0_from_random_iteration(model, x_0, mask_generate, prompt_noise
     probs_iterations = probs_iterations.expand(x_0.size(0), -1)
     dist = torch.distributions.categorical.Categorical(probs=probs_iterations)
 
-    t_1 = dist.sample() + 1
+    t = dist.sample() + 1
 
-    (x_t,) = sample_x_t_given_x_0(x_0, (t_1,))
+    x_t = sample_x_t_given_x_0(x_0, t)
 
     # Only the part to generate is degraded, the rest is a perfect
     # noise-free conditionning