From: François Fleuret
Date: Mon, 26 Aug 2024 07:12:39 +0000 (+0200)
Subject: Update.
X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=cd2fb834583a5576021689e05bb14e1bb4c0a727;p=culture.git

Update.
---

diff --git a/main.py b/main.py
index ed36efb..3374a5b 100755
--- a/main.py
+++ b/main.py
@@ -909,7 +909,7 @@ def ae_generate(model, input, mask_generate, n_epoch, nb_iterations_max=50):
     return input
 
 
-def degrade_input(input, mask_generate, noise_levels):
+def degrade_input(input, mask_generate, nb_iterations, noise_proba=0.35):
     noise = torch.randint(
         quiz_machine.problem.nb_colors, input.size(), device=input.device
     )
@@ -918,9 +918,10 @@ def degrade_input(input, mask_generate, noise_levels):
 
     result = []
 
-    for phi in noise_levels:
-        mask_diffusion_noise = mask_generate * (r <= phi).long()
-        x = (1 - mask_diffusion_noise) * input + mask_diffusion_noise * noise
+    for n in nb_iterations:
+        proba_erased = 1 - (1 - noise_proba) ** n
+        mask_erased = mask_generate * (r <= proba_erased[:, None]).long()
+        x = (1 - mask_erased) * input + mask_erased * noise
         result.append(x)
 
     return result
@@ -972,14 +973,18 @@ def test_ae(local_device=main_device):
             model.optimizer.zero_grad()
 
             deterministic = (
-                mask_generate.sum(dim=1, keepdim=True) < mask_generate.size(1) // 2
+                mask_generate.sum(dim=1) < mask_generate.size(1) // 4
             ).long()
 
-            k = torch.randint(3, (input.size(0), 1), device=input.device)
-            phi0 = deterministic * 0 + (1 - deterministic) * (k / 3)
-            phi1 = deterministic * 1 + (1 - deterministic) * ((k + 1) / 3)
+            N0 = torch.randint(nb_iterations, (input.size(0),), device=input.device)
+            N1 = N0 + 1
 
-            targets, input = degrade_input(input, mask_generate, (phi0, phi1))
+            N0 = (1 - deterministic) * N0
+            N1 = deterministic * nb_iterations + (1 - deterministic) * N1
+
+            # print(f"{N0.size()=} {N1.size()=} {deterministic.size()=}")
+
+            targets, input = degrade_input(input, mask_generate, (N0, N1))
             input_with_mask = NTC_channel_cat(input, mask_generate)
             logits = model(input_with_mask)
             loss = NTC_masked_cross_entropy(logits, targets, mask_loss)
@@ -1011,16 +1016,16 @@ def test_ae(local_device=main_device):
                 "test",
             ):
                 deterministic = (
-                    mask_generate.sum(dim=1, keepdim=True) < mask_generate.size(1) // 2
+                    mask_generate.sum(dim=1) < mask_generate.size(1) // 4
                 ).long()
 
-                k = torch.randint(3, (input.size(0), 1), device=input.device)
-                phi0 = deterministic * 0 + (1 - deterministic) * (k / 3)
-                phi1 = deterministic * 1 + (1 - deterministic) * ((k + 1) / 3)
+                N0 = torch.randint(nb_iterations, (input.size(0),), device=input.device)
+                N1 = N0 + 1
+
+                N0 = (1 - deterministic) * N0
+                N1 = deterministic * nb_iterations + (1 - deterministic) * N1
 
-                phi = torch.rand((input.size(0), 1), device=input.device)
-                phi = deterministic + (1 - deterministic) * phi
-                targets, input = degrade_input(input, mask_generate, (phi0, phi1))
+                targets, input = degrade_input(input, mask_generate, (N0, N1))
                 input_with_mask = NTC_channel_cat(input, mask_generate)
                 logits = model(input_with_mask)
                 loss = NTC_masked_cross_entropy(logits, targets, mask_loss)
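
A minimal standalone sketch of the erasure schedule that the new degrade_input implements: after n degradation steps with per-step probability noise_proba, a masked token has been replaced by noise with probability 1 - (1 - noise_proba) ** n. noise_proba = 0.35 matches the new default argument in the diff; nb_iterations = 25 is an assumed illustrative value, not taken from the rest of main.py.

    # Sketch only: plots the erasure probability as a function of the
    # number of degradation iterations n.
    noise_proba = 0.35   # per-iteration erasure probability (diff default)
    nb_iterations = 25   # assumed illustrative total number of steps

    for n in range(nb_iterations + 1):
        # A token survives n independent steps with probability
        # (1 - noise_proba) ** n, hence is erased with the complement.
        proba_erased = 1 - (1 - noise_proba) ** n
        print(f"n={n:2d}  proba_erased={proba_erased:.4f}")

    # e.g. n=0 -> 0.0000 (targets in the deterministic case, N0 = 0, stay clean),
    #      n=1 -> 0.3500, n=5 -> 0.8840, n=nb_iterations -> ~1.0 (inputs in the
    #      deterministic case, N1 = nb_iterations, are almost entirely noise
    #      on the generated part).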