Update.
authorFrançois Fleuret <francois@fleuret.org>
Thu, 12 Sep 2024 07:12:23 +0000 (09:12 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Thu, 12 Sep 2024 07:12:23 +0000 (09:12 +0200)
main.py

diff --git a/main.py b/main.py
index c60b3c6..3753e9b 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -609,15 +609,6 @@ def sample_x_t_given_x_0(x_0, t):
     return x_t
 
 
-def ___sample_x_t_given_x_0(x_0, t):
-    D = diffusion_M[t.to("cpu")].permute(0, 2, 1).to(x_0.device)
-    mask = (x_0 < quiz_machine.problem.nb_colors).long()
-    probas = D.gather(dim=1, index=(mask * x_0)[:, :, None].expand(-1, -1, D.size(-1)))
-    dist = torch.distributions.categorical.Categorical(probs=probas)
-    x_t = (1 - mask) * x_0 + mask * dist.sample()
-    return x_t
-
-
 # This function returns a 2d tensor of same shape as low, full of
 # uniform random values in [0,1], such that, in every row, the values
 # corresponding to the True in low are all lesser than the values
@@ -643,6 +634,18 @@ def sample_x_t_minus_1_given_x_0_x_t(x_0, x_t):
     return x_t_minus_1
 
 
+# Non-uniform transitions, to be fixed?
+
+
+def ___sample_x_t_given_x_0(x_0, t):
+    D = diffusion_M[t.to("cpu")].permute(0, 2, 1).to(x_0.device)
+    mask = (x_0 < quiz_machine.problem.nb_colors).long()
+    probas = D.gather(dim=1, index=(mask * x_0)[:, :, None].expand(-1, -1, D.size(-1)))
+    dist = torch.distributions.categorical.Categorical(probs=probas)
+    x_t = (1 - mask) * x_0 + mask * dist.sample()
+    return x_t
+
+
 def ____sample_x_t_minus_1_given_x_0_x_t(x_0, x_t, t):
     mask = (x_0 < quiz_machine.problem.nb_colors).long()