From bf33ae9af88907ae2c5a4de6d8d90b286d931114 Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Sun, 25 Aug 2024 15:55:54 +0200
Subject: [PATCH] Update.

---
 main.py | 39 ++++++++++++++++++++++++++++-----------
 1 file changed, 28 insertions(+), 11 deletions(-)

diff --git a/main.py b/main.py
index 7bd25cf..ed36efb 100755
--- a/main.py
+++ b/main.py
@@ -894,19 +894,22 @@ def ae_generate(model, input, mask_generate, n_epoch, nb_iterations_max=50):
     )
     input = (1 - mask_generate) * input + mask_generate * noise
 
+    changed = True
     for it in range(nb_iterations_max):
         input_with_mask = NTC_channel_cat(input, mask_generate)
         logits = model(input_with_mask)
         dist = torch.distributions.categorical.Categorical(logits=logits)
-        pred_input = input.clone()
-        input = (1 - mask_generate) * input + mask_generate * dist.sample()
-        if (pred_input == input).min():
+        update = (1 - mask_generate) * input + mask_generate * dist.sample()
+        if update.equal(input):
             break
+        else:
+            changed = changed & (update != input).max(dim=1).values
+            input[changed] = update[changed]
 
     return input
 
 
-def degrade_input(input, mask_generate, *phis):
+def degrade_input(input, mask_generate, noise_levels):
     noise = torch.randint(
         quiz_machine.problem.nb_colors, input.size(), device=input.device
     )
@@ -915,7 +918,7 @@
 
     result = []
 
-    for phi in phis:
+    for phi in noise_levels:
         mask_diffusion_noise = mask_generate * (r <= phi).long()
         x = (1 - mask_diffusion_noise) * input + mask_diffusion_noise * noise
         result.append(x)
@@ -968,8 +971,15 @@ def test_ae(local_device=main_device):
             if nb_train_samples % args.batch_size == 0:
                 model.optimizer.zero_grad()
 
-            phi = torch.rand((input.size(0), 1), device=input.device).clamp(min=0.25)
-            targets, input = degrade_input(input, mask_generate, phi - 0.25, phi)
+            deterministic = (
+                mask_generate.sum(dim=1, keepdim=True) < mask_generate.size(1) // 2
+            ).long()
+
+            k = torch.randint(3, (input.size(0), 1), device=input.device)
+            phi0 = deterministic * 0 + (1 - deterministic) * (k / 3)
+            phi1 = deterministic * 1 + (1 - deterministic) * ((k + 1) / 3)
+
+            targets, input = degrade_input(input, mask_generate, (phi0, phi1))
             input_with_mask = NTC_channel_cat(input, mask_generate)
             logits = model(input_with_mask)
             loss = NTC_masked_cross_entropy(logits, targets, mask_loss)
@@ -1000,10 +1010,17 @@ def test_ae(local_device=main_device):
                 local_device,
                 "test",
             ):
-                phi = torch.rand((input.size(0), 1), device=input.device).clamp(
-                    min=0.25
-                )
-                targets, input = degrade_input(input, mask_generate, phi - 0.25, phi)
+                deterministic = (
+                    mask_generate.sum(dim=1, keepdim=True) < mask_generate.size(1) // 2
+                ).long()
+
+                k = torch.randint(3, (input.size(0), 1), device=input.device)
+                phi0 = deterministic * 0 + (1 - deterministic) * (k / 3)
+                phi1 = deterministic * 1 + (1 - deterministic) * ((k + 1) / 3)
+
+                phi = torch.rand((input.size(0), 1), device=input.device)
+                phi = deterministic + (1 - deterministic) * phi
+                targets, input = degrade_input(input, mask_generate, (phi0, phi1))
                 input_with_mask = NTC_channel_cat(input, mask_generate)
                 logits = model(input_with_mask)
                 loss = NTC_masked_cross_entropy(logits, targets, mask_loss)
-- 
2.39.5
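
Note for readers of the diff: the sketch below illustrates, outside the patch, the degradation scheme the updated degrade_input is now called with. It receives a tuple of noise levels (phi0, phi1) instead of variadic phis, and the training and test loops pick either the extreme pair (0, 1) for samples whose generation mask covers less than half of the sequence, or two adjacent levels k/3 and (k+1)/3 otherwise. This is a minimal, self-contained approximation, not the repository's code: nb_colors, the toy tensor shapes, and the name degrade_input_sketch are assumptions introduced for illustration.

import torch

nb_colors = 10  # hypothetical; the repository takes this from quiz_machine.problem.nb_colors


def degrade_input_sketch(input, mask_generate, noise_levels):
    # One random symbol per position; positions outside mask_generate are never altered.
    noise = torch.randint(nb_colors, input.size(), device=input.device)
    r = torch.rand(input.size(), device=input.device)

    result = []
    for phi in noise_levels:
        # A position is noised iff it is generatable and its uniform draw falls below phi.
        mask_diffusion_noise = mask_generate * (r <= phi).long()
        x = (1 - mask_diffusion_noise) * input + mask_diffusion_noise * noise
        result.append(x)

    return result


# Noise-level pairing as in the patched training loop: samples whose generation
# mask covers less than half of the sequence are treated as "deterministic" and
# get the extreme pair (0, 1); the others get two adjacent levels k/3, (k+1)/3.
N, T = 4, 12
input = torch.randint(nb_colors, (N, T))
mask_generate = (torch.rand(N, T) < 0.5).long()

deterministic = (mask_generate.sum(dim=1, keepdim=True) < T // 2).long()
k = torch.randint(3, (N, 1))
phi0 = deterministic * 0 + (1 - deterministic) * (k / 3)
phi1 = deterministic * 1 + (1 - deterministic) * ((k + 1) / 3)

# phi0 <= phi1, so targets is the less-degraded copy and input_degraded the more
# degraded one that the model is trained to map back toward targets.
targets, input_degraded = degrade_input_sketch(input, mask_generate, (phi0, phi1))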