From 26529b3ce8cbad8b7e8e7a1bf0d5fe778b37670c Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Fri, 6 Sep 2024 23:51:32 +0200 Subject: [PATCH] Update. --- main.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/main.py b/main.py index 00f9175..b926f8e 100755 --- a/main.py +++ b/main.py @@ -101,9 +101,9 @@ parser.add_argument("--nb_models", type=int, default=5) parser.add_argument("--nb_diffusion_iterations", type=int, default=25) -parser.add_argument("--diffusion_delta", type=float, default=0.1) +parser.add_argument("--diffusion_delta", type=float, default=0.05) -parser.add_argument("--diffusion_epsilon", type=float, default=0.01) +parser.add_argument("--diffusion_epsilon", type=float, default=0.05) parser.add_argument("--min_succeed_to_validate", type=int, default=2) @@ -802,8 +802,9 @@ diffusion_M[0] = torch.eye(N) diffusion_M[1, 0, 0] = 1 - args.diffusion_epsilon diffusion_M[1, 1:, 0] = args.diffusion_epsilon / (N - 1) -diffusion_M[1, 0, 1:] = args.diffusion_delta -diffusion_M[1, 1:, 1:] = args.diffusion_epsilon / (N - 2) +diffusion_M[1, 0, 1:] = args.diffusion_epsilon / (N - 1) + args.diffusion_delta +diffusion_M[1, 1:, 1:] = args.diffusion_epsilon / (N - 1) + for k in range(1, N): diffusion_M[1, k, k] = 1 - args.diffusion_delta - args.diffusion_epsilon @@ -835,7 +836,7 @@ for t in range(2, T): # -def sample_x_t_given_x_0_(x_0, t): +def sample_x_t_given_x_0(x_0, t): noise = torch.randint(quiz_machine.problem.nb_colors, x_0.size(), device=x_0.device) r = torch.rand(x_0.size(), device=x_0.device) proba_erased = 1 - (1 - args.diffusion_delta) ** t @@ -845,7 +846,7 @@ def sample_x_t_given_x_0_(x_0, t): return x_t -def sample_x_t_given_x_0(x_0, t): +def ___sample_x_t_given_x_0(x_0, t): D = diffusion_M[t.to("cpu")].permute(0, 2, 1).to(x_0.device) mask = (x_0 < quiz_machine.problem.nb_colors).long() probas = D.gather(dim=1, index=(mask * x_0)[:, :, None].expand(-1, -1, D.size(-1))) @@ -869,7 +870,7 @@ def prioritized_rand(low): return y -def sample_x_t_minus_1_given_x_0_x_t_(x_0, x_t): +def sample_x_t_minus_1_given_x_0_x_t(x_0, x_t): r = prioritized_rand(x_0 != x_t) mask_changes = (r <= args.diffusion_delta).long() @@ -879,7 +880,7 @@ def sample_x_t_minus_1_given_x_0_x_t_(x_0, x_t): return x_t_minus_1 -def sample_x_t_minus_1_given_x_0_x_t(x_0, x_t, t): +def ____sample_x_t_minus_1_given_x_0_x_t(x_0, x_t, t): mask = (x_0 < quiz_machine.problem.nb_colors).long() # i = x_0[n,s], j = x_t[n,s] @@ -972,9 +973,9 @@ def logits_hat_x_0_from_random_iteration(model, x_0, mask_generate, prompt_noise def ae_generate(model, x_0, mask_generate, nb_iterations_max=50): - # noise = torch.randint(quiz_machine.problem.nb_colors, x_0.size(), device=x_0.device) + noise = torch.randint(quiz_machine.problem.nb_colors, x_0.size(), device=x_0.device) - x_t = (1 - mask_generate) * x_0 # + mask_generate * noise + x_t = (1 - mask_generate) * x_0 + mask_generate * noise one_iteration_prediction = deterministic(mask_generate)[:, None] -- 2.39.5