From 10d35403599fd2cd11ed5a3f5f1452c79b2ee67d Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Fri, 6 Sep 2024 22:49:33 +0200 Subject: [PATCH] Update. --- main.py | 127 +++++++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 108 insertions(+), 19 deletions(-) diff --git a/main.py b/main.py index 150010f..00f9175 100755 --- a/main.py +++ b/main.py @@ -101,7 +101,7 @@ parser.add_argument("--nb_models", type=int, default=5) parser.add_argument("--nb_diffusion_iterations", type=int, default=25) -parser.add_argument("--diffusion_delta", type=float, default=0.05) +parser.add_argument("--diffusion_delta", type=float, default=0.1) parser.add_argument("--diffusion_epsilon", type=float, default=0.01) @@ -779,17 +779,54 @@ def deterministic(mask_generate): ###################################################################### +torch.set_printoptions( + precision=None, + threshold=None, + edgeitems=None, + linewidth=500, + profile=None, + sci_mode=None, +) + N = quiz_machine.problem.nb_colors -T = 50 -MP = torch.empty(T, N, N) -MP[0] = torch.eye(N) -MP[1, :, 0] = args.diffusion_epsilon / (N - 1) -MP[1, 0, 0] = 1 - args.diffusion_epsilon -MP[1, :, 1:] = args.diffusion_delta / (N - 1) +T = args.nb_diffusion_iterations + 1 +diffusion_M = torch.empty(T, N, N) +diffusion_M[0] = torch.eye(N) + +# i >0 j>0 +# P(X'=0 | X=0) = 1-epsilon +# P(X'=i | X=0) = epsilon/(N-1) +# P(X'=0 | X=i) = delta +# P(X'=X | X=i) = 1-epsilon-delta +# P(X'=j | X=i) = epsilon/(N-2) + +diffusion_M[1, 0, 0] = 1 - args.diffusion_epsilon +diffusion_M[1, 1:, 0] = args.diffusion_epsilon / (N - 1) +diffusion_M[1, 0, 1:] = args.diffusion_delta +diffusion_M[1, 1:, 1:] = args.diffusion_epsilon / (N - 2) for k in range(1, N): - MP[1, k, k] = 1 - args.diffusion_delta + diffusion_M[1, k, k] = 1 - args.diffusion_delta - args.diffusion_epsilon + +# m = diffusion_M[1] + +# print(m) +# print(m.sum(dim=0)) +# print(torch.linalg.matrix_power(m, 25)) + +# exit(0) + for t in range(2, T): - MP[t] = MP[1] @ MP[t] + # diffusion_M[t] = diffusion_M[1] @ diffusion_M[t - 1] + diffusion_M[t] = torch.linalg.matrix_power(diffusion_M[1], t) + +# p = torch.full((N,), 1 / N) + +# for t in range(diffusion_M.size(0)): +# print(diffusion_M[t] @ p) + +# print(diffusion_M[T-1]) + +# exit(0) # # Given x_0 and t_0, t_1, ..., returns @@ -798,11 +835,9 @@ for t in range(2, T): # -def sample_x_t_given_x_0(x_0, t): +def sample_x_t_given_x_0_(x_0, t): noise = torch.randint(quiz_machine.problem.nb_colors, x_0.size(), device=x_0.device) - r = torch.rand(x_0.size(), device=x_0.device) - proba_erased = 1 - (1 - args.diffusion_delta) ** t mask_erased = (r <= proba_erased[:, None]).long() x_t = (1 - mask_erased) * x_0 + mask_erased * noise @@ -810,6 +845,15 @@ def sample_x_t_given_x_0(x_0, t): return x_t +def sample_x_t_given_x_0(x_0, t): + D = diffusion_M[t.to("cpu")].permute(0, 2, 1).to(x_0.device) + mask = (x_0 < quiz_machine.problem.nb_colors).long() + probas = D.gather(dim=1, index=(mask * x_0)[:, :, None].expand(-1, -1, D.size(-1))) + dist = torch.distributions.categorical.Categorical(probs=probas) + x_t = (1 - mask) * x_0 + mask * dist.sample() + return x_t + + # This function returns a 2d tensor of same shape as low, full of # uniform random values in [0,1], such that, in every row, the values # corresponding to the True in low are all lesser than the values @@ -825,7 +869,7 @@ def prioritized_rand(low): return y -def sample_x_t_minus_1_given_x_0_x_t(x_0, x_t): +def sample_x_t_minus_1_given_x_0_x_t_(x_0, x_t): r = prioritized_rand(x_0 != x_t) mask_changes = (r <= args.diffusion_delta).long() @@ -835,6 +879,35 @@ def sample_x_t_minus_1_given_x_0_x_t(x_0, x_t): return x_t_minus_1 +def sample_x_t_minus_1_given_x_0_x_t(x_0, x_t, t): + mask = (x_0 < quiz_machine.problem.nb_colors).long() + + # i = x_0[n,s], j = x_t[n,s] + # probas[n,s,k] = M[1,x_t[n,s],k] M[t[n]-1,x_0[n,s],k] / M[t[n],x_0[n,s],x_t[n,s]] + + # A[n,s,k] = M[1,x_t[n,s],k] + # B[n,s,k] = M[t[n]-1,x_0[n,s],k] + # C[n,s,k] = M[t[n],x_0[n,s],x_t[n,s]] + # probas = A * B / C + + N, S, K = x_0.size(0), x_0.size(1), diffusion_M.size(1) + + _1 = x_0.new_full((N, S, K), 1) + _t = x_0.new_full((N, S, K), t) + _k = torch.arange(K, device=x_0.device)[None, None, :].expand(N, S, K) + _x_t = (mask * x_t)[:, :, None].expand(N, S, K) + _x_0 = (mask * x_0)[:, :, None].expand(N, S, K) + + M = diffusion_M.to(x_0.device) + + probas = M[_1, _x_t, _k] * M[_t - 1, _x_0, _k] / M[_t, _x_0, _x_t] + + dist = torch.distributions.categorical.Categorical(probs=probas) + x_t_minus_1 = (1 - mask) * x_0 + mask * dist.sample() + + return x_t_minus_1 + + ###################################################################### # This function gets a clean target x_0, and a mask indicating which @@ -862,6 +935,19 @@ def logits_hat_x_0_from_random_iteration(model, x_0, mask_generate, prompt_noise x_t = (1 - mask_generate) * x_0 + mask_generate * x_t + #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + # filename = f"debug.png" + + # quiz_machine.problem.save_quizzes_as_image( + # args.result_dir, + # filename, + # quizzes=x_t, + # ) + + # log_string(f"wrote {filename}") + # exit(0) + #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + # We may inject noise to prevent high-complexity non-structure # signal to be generated as a way of "increasing reasoning # complexity" @@ -886,9 +972,9 @@ def logits_hat_x_0_from_random_iteration(model, x_0, mask_generate, prompt_noise def ae_generate(model, x_0, mask_generate, nb_iterations_max=50): - noise = torch.randint(quiz_machine.problem.nb_colors, x_0.size(), device=x_0.device) + # noise = torch.randint(quiz_machine.problem.nb_colors, x_0.size(), device=x_0.device) - x_t = (1 - mask_generate) * x_0 + mask_generate * noise + x_t = (1 - mask_generate) * x_0 # + mask_generate * noise one_iteration_prediction = deterministic(mask_generate)[:, None] @@ -897,13 +983,16 @@ def ae_generate(model, x_0, mask_generate, nb_iterations_max=50): for it in range(nb_iterations_max): x_t_with_mask = NTC_channel_cat(x_t, mask_generate) logits = model(x_t_with_mask) + logits[:, :, quiz_machine.problem.nb_colors :] = float("-inf") dist = torch.distributions.categorical.Categorical(logits=logits) hat_x_0 = (1 - mask_generate) * x_0 + mask_generate * dist.sample() hat_x_t_minus_1 = one_iteration_prediction * hat_x_0 + ( 1 - one_iteration_prediction - ) * sample_x_t_minus_1_given_x_0_x_t(hat_x_0, x_t) + ) * sample_x_t_minus_1_given_x_0_x_t( + hat_x_0, x_t, max(1, args.nb_diffusion_iterations - it) + ) if hat_x_t_minus_1.equal(x_t): # log_string(f"exit after {it+1} iterations") @@ -1066,10 +1155,10 @@ def run_ae_test( # Save some images for f, record in [("prediction", record_d), ("generation", record_nd)]: - filename = f"{prefix}culture_{f}_{n_epoch:04d}_{model.id:02d}.png" - result, predicted_parts, correct_parts = bag_to_tensors(record) + filename = f"{prefix}culture_{f}_{n_epoch:04d}_{model.id:02d}.png" + quiz_machine.problem.save_quizzes_as_image( args.result_dir, filename, @@ -1447,7 +1536,7 @@ for n_epoch in range(current_epoch, args.nb_epochs): # exit(0) - # one_ae_epoch(models[0], quiz_machine, n_epoch, main_device) + # one_ae_epoch(models[0], quiz_machine, n_epoch, None, main_device) # exit(0) log_string(f"{time_train=} {time_c_quizzes=}") -- 2.39.5