From 4d54b8a610b40a1d34ad0d67a1aa49f831668bf8 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 24 Aug 2024 18:15:32 +0200 Subject: [PATCH] Update. --- main.py | 47 ++++++++++++++++++++++++++--------------------- 1 file changed, 26 insertions(+), 21 deletions(-) diff --git a/main.py b/main.py index e13c148..9fe01ab 100755 --- a/main.py +++ b/main.py @@ -864,23 +864,6 @@ def ae_batches( mask_loss.to(local_device), ) - -def degrade_input(input, mask_generate, *ts): - noise = torch.randint( - quiz_machine.problem.nb_colors, input.size(), device=input.device - ) - - r = torch.rand(mask_generate.size(), device=mask_generate.device) - - result = [] - - for t in ts: - mask_diffusion_noise = mask_generate * (r <= t).long() - x = (1 - mask_diffusion_noise) * input + mask_diffusion_noise * noise - result.append(x) - - return result - # quiz_machine.problem.save_quizzes_as_image( # args.result_dir, # filename="a.png", @@ -921,6 +904,23 @@ def ae_generate(model, input, mask_generate, n_epoch, nb_iterations): return input +def degrade_input(input, mask_generate, *phis): + noise = torch.randint( + quiz_machine.problem.nb_colors, input.size(), device=input.device + ) + + r = torch.rand(mask_generate.size(), device=mask_generate.device) + + result = [] + + for phi in phis: + mask_diffusion_noise = mask_generate * (r <= phi).long() + x = (1 - mask_diffusion_noise) * input + mask_diffusion_noise * noise + result.append(x) + + return result + + def test_ae(local_device=main_device): model = MyAttentionAE( vocabulary_size=vocabulary_size, @@ -949,6 +949,10 @@ def test_ae(local_device=main_device): nb_iterations = 10 + def phi(rho): + # return (rho / nb_iterations)**2 + return rho / nb_iterations + for n_epoch in range(args.nb_epochs): # ---------------------- # Train @@ -967,9 +971,8 @@ def test_ae(local_device=main_device): model.optimizer.zero_grad() rho = torch.randint(nb_iterations, (input.size(0), 1), device=input.device) - targets, input = degrade_input( - input, mask_generate, rho / nb_iterations, (rho + 1) / nb_iterations - ) + + targets, input = degrade_input(input, mask_generate, phi(rho), phi(rho + 1)) input_with_mask = NTC_channel_cat(input, mask_generate, rho) output = model(input_with_mask) @@ -1004,9 +1007,11 @@ def test_ae(local_device=main_device): rho = torch.randint( nb_iterations, (input.size(0), 1), device=input.device ) + targets, input = degrade_input( - input, mask_generate, rho / nb_iterations, (rho + 1) / nb_iterations + input, mask_generate, phi(rho), phi(rho + 1) ) + input_with_mask = NTC_channel_cat(input, mask_generate, rho) output = model(input_with_mask) loss = NTC_masked_cross_entropy(output, targets, mask_loss) -- 2.39.5