From 46f9c0ca66b3cb2a63e15edbb180ff635d014150 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Fri, 6 Sep 2024 12:27:13 +0200 Subject: [PATCH] Update. --- main.py | 57 ++++++++++++++++++++++++++++++++++----------------------- 1 file changed, 34 insertions(+), 23 deletions(-) diff --git a/main.py b/main.py index 1e398e8..4c30771 100755 --- a/main.py +++ b/main.py @@ -781,7 +781,7 @@ def deterministic(mask_generate): # -def degrade_input_to_generate(x_0, steps_nb_iterations): +def sample_x_t_given_x_0(x_0, steps_nb_iterations): noise = torch.randint(quiz_machine.problem.nb_colors, x_0.size(), device=x_0.device) r = torch.rand(x_0.size(), device=x_0.device) @@ -797,8 +797,37 @@ def degrade_input_to_generate(x_0, steps_nb_iterations): return result +# This function returns a 2d tensor of same shape as low, full of +# uniform random values in [0,1], such that, in every row, the values +# corresponding to the True in low are all lesser than the values +# corresponding to the False. + + +def prioritized_rand(low): + x = torch.rand(low.size(), device=low.device).sort(dim=1, descending=True).values + k = torch.rand(low.size(), device=low.device) + low.long() + k = k.sort(dim=1).indices + y = x.new(x.size()) + y.scatter_(dim=1, index=k, src=x) + return y + + +def sample_x_t_minus_1_given_x_0_x_t(x_0, x_t): + r = prioritized_rand(x_0 != x_t) + + mask_changes = (r <= args.diffusion_noise_proba).long() + + x_t_minus_1 = (1 - mask_changes) * x_t + mask_changes * x_0 + + return result + + ###################################################################### +# This function gets a clean target x_0, and a mask indicating which +# part to generate (conditionnaly to the others), and returns the +# logits starting from a x_t|X_0=x_0 picked at random with t random + def logits_hat_x_0_from_random_iteration(model, x_0, mask_generate, prompt_noise=0.0): # We favor iterations near the clean signal @@ -811,9 +840,9 @@ def logits_hat_x_0_from_random_iteration(model, x_0, mask_generate, prompt_noise probs_iterations = probs_iterations.expand(x_0.size(0), -1) dist = torch.distributions.categorical.Categorical(probs=probs_iterations) - N1 = dist.sample() + 1 + t_1 = dist.sample() + 1 - (x_t,) = degrade_input_to_generate(x_0, (N1,)) + (x_t,) = sample_x_t_given_x_0(x_0, (t_1,)) # Only the part to generate is degraded, the rest is a perfect # noise-free conditionning @@ -842,20 +871,6 @@ def logits_hat_x_0_from_random_iteration(model, x_0, mask_generate, prompt_noise ###################################################################### -# This function returns a 2d tensor of same shape as low, full of -# uniform random values in [0,1], such that, in every row, the values -# corresponding to the True in low are all lesser than the values -# corresponding to the False. - - -def prioritized_rand(low): - x = torch.rand(low.size(), device=low.device).sort(dim=1, descending=True).values - k = torch.rand(low.size(), device=low.device) + low.long() - k = k.sort(dim=1).indices - y = x.new(x.size()) - y.scatter_(dim=1, index=k, src=x) - return y - def ae_generate(model, x_0, mask_generate, nb_iterations_max=50): noise = torch.randint(quiz_machine.problem.nb_colors, x_0.size(), device=x_0.device) @@ -873,13 +888,9 @@ def ae_generate(model, x_0, mask_generate, nb_iterations_max=50): hat_x_0 = (1 - mask_generate) * x_0 + mask_generate * dist.sample() - r = prioritized_rand(hat_x_0 != x_t) - - mask_changes = (r <= args.diffusion_noise_proba).long() - - hat_x_t_minus_1 = one_iteration_prediction * hat_x_0 + ( + hat_x_t_minus_1 = one_iteration_prediction * x_0 + ( 1 - one_iteration_prediction - ) * ((1 - mask_changes) * x_t + mask_changes * hat_x_0) + ) * sample_x_t_minus_1_given_x_0_x_t(hat_x_0, x_t) if hat_x_t_minus_1.equal(x_t): # log_string(f"exit after {it+1} iterations") -- 2.39.5