From 3dc7e89ca1e4ea2cba9c8aed412b1e90e6195841 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 12 Sep 2024 09:12:23 +0200 Subject: [PATCH] Update. --- main.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/main.py b/main.py index c60b3c6..3753e9b 100755 --- a/main.py +++ b/main.py @@ -609,15 +609,6 @@ def sample_x_t_given_x_0(x_0, t): return x_t -def ___sample_x_t_given_x_0(x_0, t): - D = diffusion_M[t.to("cpu")].permute(0, 2, 1).to(x_0.device) - mask = (x_0 < quiz_machine.problem.nb_colors).long() - probas = D.gather(dim=1, index=(mask * x_0)[:, :, None].expand(-1, -1, D.size(-1))) - dist = torch.distributions.categorical.Categorical(probs=probas) - x_t = (1 - mask) * x_0 + mask * dist.sample() - return x_t - - # This function returns a 2d tensor of same shape as low, full of # uniform random values in [0,1], such that, in every row, the values # corresponding to the True in low are all lesser than the values @@ -643,6 +634,18 @@ def sample_x_t_minus_1_given_x_0_x_t(x_0, x_t): return x_t_minus_1 +# Non-uniform transitions, to be fixed? + + +def ___sample_x_t_given_x_0(x_0, t): + D = diffusion_M[t.to("cpu")].permute(0, 2, 1).to(x_0.device) + mask = (x_0 < quiz_machine.problem.nb_colors).long() + probas = D.gather(dim=1, index=(mask * x_0)[:, :, None].expand(-1, -1, D.size(-1))) + dist = torch.distributions.categorical.Categorical(probs=probas) + x_t = (1 - mask) * x_0 + mask * dist.sample() + return x_t + + def ____sample_x_t_minus_1_given_x_0_x_t(x_0, x_t, t): mask = (x_0 < quiz_machine.problem.nb_colors).long() -- 2.39.5