From 0bc2b958fa73a84f2eef43e3242c8c3ef93d0207 Mon Sep 17 00:00:00 2001
From: François Fleuret
Date: Fri, 6 Sep 2024 08:50:05 +0200
Subject: [PATCH] Update.

---
 main.py | 140 ++++++++++++++++++++++++++++++--------------------------
 1 file changed, 75 insertions(+), 65 deletions(-)

diff --git a/main.py b/main.py
index f609fd8..d1a1c8f 100755
--- a/main.py
+++ b/main.py
@@ -774,6 +774,75 @@ def deterministic(mask_generate):
     return (mask_generate.sum(dim=1) < mask_generate.size(1) // 2).long()
 
 
+######################################################################
+
+#
+# Given x_0 and t_0, t_1, ..., returns x_{t_0}, x_{t_1}, ..., with
+#
+# x_{t_k} ~ P(X_{t_k} | X_0=x_0)
+#
+
+
+def degrade_input_to_generate(x0, mask_generate, steps_nb_iterations):
+    noise = torch.randint(quiz_machine.problem.nb_colors, x0.size(), device=x0.device)
+
+    r = torch.rand(mask_generate.size(), device=mask_generate.device)
+
+    result = []
+
+    for n in steps_nb_iterations:
+        proba_erased = 1 - (1 - args.diffusion_noise_proba) ** n
+        mask_erased = mask_generate * (r <= proba_erased[:, None]).long()
+        x = (1 - mask_erased) * x0 + mask_erased * noise
+        result.append(x)
+
+    return result
+
+
+######################################################################
+
+# Given x_t and a mask, returns the clean targets and the logits of
+# the model's prediction.
+
+
+def targets_and_logits(model, input, mask_generate, prompt_noise=0.0):
+    d = deterministic(mask_generate)
+
+    probs_iterations = 0.1 ** torch.linspace(
+        0, 1, args.nb_diffusion_iterations, device=input.device
+    )
+
+    probs_iterations = probs_iterations[None, :] / probs_iterations.sum()
+    probs_iterations = probs_iterations.expand(input.size(0), -1)
+    dist = torch.distributions.categorical.Categorical(probs=probs_iterations)
+
+    # N0 = dist.sample()
+    # N1 = N0 + 1
+    # N0 = (1 - d) * N0
+    # N1 = (1 - d) * N1 + d * args.nb_diffusion_iterations
+
+    N0 = input.new_zeros(input.size(0))
+    N1 = dist.sample() + 1
+
+    targets, input = degrade_input_to_generate(input, mask_generate, (N0, N1))
+
+    if prompt_noise > 0:
+        mask_prompt_noise = (
+            torch.rand(input.size(), device=input.device) <= prompt_noise
+        ).long()
+        noise = torch.randint(
+            quiz_machine.problem.nb_colors, input.size(), device=input.device
+        )
+        noisy_input = (1 - mask_prompt_noise) * input + mask_prompt_noise * noise
+        input = (1 - mask_generate) * noisy_input + mask_generate * input
+
+    input_with_mask = NTC_channel_cat(input, mask_generate)
+    logits = model(input_with_mask)
+
+    return targets, logits
+
+
+######################################################################
+
 # This function returns a 2d tensor of same shape as low, full of
 # uniform random values in [0,1], such that, in every row, the values
 # corresponding to the True in low are all lesser than the values
@@ -840,7 +909,7 @@ def model_ae_proba_solutions(model, input, log_proba=False):
         mask_generate = quiz_machine.make_quiz_mask(
             quizzes=q, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
         )
-        targets, logits = targets_and_prediction(
+        targets, logits = targets_and_logits(
             model, q, mask_generate, prompt_noise=args.prompt_noise
         )
         loss_per_token = F.cross_entropy(
@@ -866,7 +935,7 @@ def model_ae_argmax_nb_disagreements(model, input):
         mask_generate = quiz_machine.make_quiz_mask(
             quizzes=q, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
         )
-        targets, logits = targets_and_prediction(
+        targets, logits = targets_and_logits(
             model, q, mask_generate, prompt_noise=args.prompt_noise
         )
 
@@ -893,7 +962,7 @@ def model_ae_argmax_predictions(model, input):
         mask_generate = quiz_machine.make_quiz_mask(
             quizzes=q, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
         )
-        targets, logits = targets_and_prediction(
+        targets, logits = targets_and_logits(
             model, q, mask_generate, prompt_noise=args.prompt_noise
         )
 
@@ -907,64 +976,6 @@ def model_ae_argmax_predictions(model, input):
 
 ######################################################################
 
-
-def degrade_input_to_generate(input, mask_generate, steps_nb_iterations):
-    noise = torch.randint(
-        quiz_machine.problem.nb_colors, input.size(), device=input.device
-    )
-
-    r = torch.rand(mask_generate.size(), device=mask_generate.device)
-
-    result = []
-
-    for n in steps_nb_iterations:
-        proba_erased = 1 - (1 - args.diffusion_noise_proba) ** n
-        mask_erased = mask_generate * (r <= proba_erased[:, None]).long()
-        x = (1 - mask_erased) * input + mask_erased * noise
-        result.append(x)
-
-    return result
-
-
-def targets_and_prediction(model, input, mask_generate, prompt_noise=0.0):
-    d = deterministic(mask_generate)
-
-    probs_iterations = 0.1 ** torch.linspace(
-        0, 1, args.nb_diffusion_iterations, device=input.device
-    )
-
-    probs_iterations = probs_iterations[None, :] / probs_iterations.sum()
-    probs_iterations = probs_iterations.expand(input.size(0), -1)
-    dist = torch.distributions.categorical.Categorical(probs=probs_iterations)
-
-    # N0 = dist.sample()
-    # N1 = N0 + 1
-    # N0 = (1 - d) * N0
-    # N1 = (1 - d) * N1 + d * args.nb_diffusion_iterations
-
-    N0 = input.new_zeros(input.size(0))
-    N1 = dist.sample() + 1
-
-    targets, input = degrade_input_to_generate(input, mask_generate, (N0, N1))
-
-    if prompt_noise > 0:
-        mask_prompt_noise = (
-            torch.rand(input.size(), device=input.device) <= prompt_noise
-        ).long()
-        noise = torch.randint(
-            quiz_machine.problem.nb_colors, input.size(), device=input.device
-        )
-        noisy_input = (1 - mask_prompt_noise) * input + mask_prompt_noise * noise
-        input = (1 - mask_generate) * noisy_input + mask_generate * input
-
-    input_with_mask = NTC_channel_cat(input, mask_generate)
-    logits = model(input_with_mask)
-
-    return targets, logits
-
-
-######################################################################
-
-
 def run_ae_test(
     model, quiz_machine, n_epoch, c_quizzes=None, local_device=main_device, prefix=None
 ):
@@ -988,7 +999,7 @@ def run_ae_test(
             c_quizzes=c_quizzes,
             desc="test",
         ):
-            targets, logits = targets_and_prediction(model, input, mask_generate)
+            targets, logits = targets_and_logits(model, input, mask_generate)
             loss = NTC_masked_cross_entropy(logits, targets, mask_loss)
             acc_test_loss += loss.item() * input.size(0)
             nb_test_samples += input.size(0)
@@ -1032,8 +1043,7 @@ def run_ae_test(
         f"{prefix}test_accuracy {n_epoch} model {model.id} nb_correct {nb_correct} / {nb_total} ({(nb_correct*100)/nb_total:.02f}%)"
     )
 
-    if prefix is None:
-        model.test_accuracy = nb_correct / nb_total
+    model.test_accuracy = nb_correct / nb_total
 
     # Save some images
 
@@ -1110,7 +1120,7 @@ def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_devi
 
         if nb_train_samples % args.batch_size == 0:
             model.optimizer.zero_grad()
 
-        targets, logits = targets_and_prediction(
+        targets, logits = targets_and_logits(
             model, input, mask_generate, prompt_noise=args.prompt_noise
         )
-- 
2.39.5
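Note on the degradation schedule the patch moves above: because `degrade_input_to_generate` draws the uniform tensor `r` once and compares it against increasing erasure probabilities, the erased sets are nested across iteration counts, i.e. every position erased after n iterations is still erased after n + 1. Below is a minimal standalone sketch of that behaviour; `nb_colors` and `diffusion_noise_proba` are stand-ins for `quiz_machine.problem.nb_colors` and `args.diffusion_noise_proba`, not values taken from the repository.

```python
import torch

nb_colors = 10                # stand-in for quiz_machine.problem.nb_colors
diffusion_noise_proba = 0.05  # stand-in for args.diffusion_noise_proba


def degrade(x0, mask_generate, steps_nb_iterations):
    # Replacement tokens and the per-position uniform draw r are sampled
    # once, so the same positions get erased first at every step count.
    noise = torch.randint(nb_colors, x0.size())
    r = torch.rand(mask_generate.size())

    result = []
    for n in steps_nb_iterations:
        # A token survives one iteration with probability 1 - p, hence
        # P(erased after n iterations) = 1 - (1 - p) ** n.
        proba_erased = 1 - (1 - diffusion_noise_proba) ** n
        mask_erased = mask_generate * (r <= proba_erased[:, None]).long()
        result.append((1 - mask_erased) * x0 + mask_erased * noise)
    return result


x0 = torch.randint(nb_colors, (4, 100))
mask = torch.ones(4, 100, dtype=torch.long)
N0 = torch.zeros(4)          # n = 0  -> proba_erased = 0, x_{t_0} = x_0
N1 = torch.full((4,), 25.0)  # n = 25 -> ~72% of masked tokens erased
targets, degraded = degrade(x0, mask, (N0, N1))
print((degraded != targets).float().mean())  # ~0.72 * 9/10 (collisions)
```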
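Likewise, the iteration count N1 sampled in `targets_and_logits` follows a distribution whose unnormalized weights decay geometrically from 1.0 down to 0.1 across the iteration indices, so short degradations are drawn more often than long ones. A standalone sketch, with `nb_diffusion_iterations = 25` and `batch_size = 8` as illustrative stand-ins for the script's arguments:

```python
import torch

nb_diffusion_iterations = 25  # stand-in for args.nb_diffusion_iterations
batch_size = 8

# Weights 0.1 ** linspace(0, 1, K) run geometrically from 1.0 (index 0)
# to 0.1 (index K - 1); normalizing yields a per-row distribution biased
# toward small iteration counts.
probs_iterations = 0.1 ** torch.linspace(0, 1, nb_diffusion_iterations)
probs_iterations = probs_iterations[None, :] / probs_iterations.sum()
probs_iterations = probs_iterations.expand(batch_size, -1)
dist = torch.distributions.categorical.Categorical(probs=probs_iterations)

# One draw per batch element; + 1 maps indices 0..K-1 to counts 1..K,
# matching N1 in targets_and_logits (N0 stays 0 for the clean targets).
N1 = dist.sample() + 1
print(N1)  # values in 1..25, skewed toward 1
```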