Update.
author François Fleuret <francois@fleuret.org>
Mon, 26 Aug 2024 07:12:39 +0000 (09:12 +0200)
committer François Fleuret <francois@fleuret.org>
Mon, 26 Aug 2024 07:12:39 +0000 (09:12 +0200)
main.py

diff --git a/main.py b/main.py
index ed36efb..3374a5b 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -909,7 +909,7 @@ def ae_generate(model, input, mask_generate, n_epoch, nb_iterations_max=50):
     return input
 
 
-def degrade_input(input, mask_generate, noise_levels):
+def degrade_input(input, mask_generate, nb_iterations, noise_proba=0.35):
     noise = torch.randint(
         quiz_machine.problem.nb_colors, input.size(), device=input.device
     )
@@ -918,9 +918,10 @@ def degrade_input(input, mask_generate, noise_levels):
 
     result = []
 
-    for phi in noise_levels:
-        mask_diffusion_noise = mask_generate * (r <= phi).long()
-        x = (1 - mask_diffusion_noise) * input + mask_diffusion_noise * noise
+    for n in nb_iterations:
+        proba_erased = 1 - (1 - noise_proba) ** n
+        mask_erased = mask_generate * (r <= proba_erased[:, None]).long()
+        x = (1 - mask_erased) * input + mask_erased * noise
         result.append(x)
 
     return result
@@ -972,14 +973,18 @@ def test_ae(local_device=main_device):
                 model.optimizer.zero_grad()
 
             deterministic = (
-                mask_generate.sum(dim=1, keepdim=True) < mask_generate.size(1) // 2
+                mask_generate.sum(dim=1) < mask_generate.size(1) // 4
             ).long()
 
-            k = torch.randint(3, (input.size(0), 1), device=input.device)
-            phi0 = deterministic * 0 + (1 - deterministic) * (k / 3)
-            phi1 = deterministic * 1 + (1 - deterministic) * ((k + 1) / 3)
+            N0 = torch.randint(nb_iterations, (input.size(0),), device=input.device)
+            N1 = N0 + 1
 
-            targets, input = degrade_input(input, mask_generate, (phi0, phi1))
+            N0 = (1 - deterministic) * N0
+            N1 = deterministic * nb_iterations + (1 - deterministic) * N1
+
+            # print(f"{N0.size()=} {N1.size()=} {deterministic.size()=}")
+
+            targets, input = degrade_input(input, mask_generate, (N0, N1))
             input_with_mask = NTC_channel_cat(input, mask_generate)
             logits = model(input_with_mask)
             loss = NTC_masked_cross_entropy(logits, targets, mask_loss)
@@ -1011,16 +1016,16 @@ def test_ae(local_device=main_device):
                 "test",
             ):
                 deterministic = (
-                    mask_generate.sum(dim=1, keepdim=True) < mask_generate.size(1) // 2
+                    mask_generate.sum(dim=1) < mask_generate.size(1) // 4
                 ).long()
 
-                k = torch.randint(3, (input.size(0), 1), device=input.device)
-                phi0 = deterministic * 0 + (1 - deterministic) * (k / 3)
-                phi1 = deterministic * 1 + (1 - deterministic) * ((k + 1) / 3)
+                N0 = torch.randint(nb_iterations, (input.size(0),), device=input.device)
+                N1 = N0 + 1
+
+                N0 = (1 - deterministic) * N0
+                N1 = deterministic * nb_iterations + (1 - deterministic) * N1
 
-                phi = torch.rand((input.size(0), 1), device=input.device)
-                phi = deterministic + (1 - deterministic) * phi
-                targets, input = degrade_input(input, mask_generate, (phi0, phi1))
+                targets, input = degrade_input(input, mask_generate, (N0, N1))
                 input_with_mask = NTC_channel_cat(input, mask_generate)
                 logits = model(input_with_mask)
                 loss = NTC_masked_cross_entropy(logits, targets, mask_loss)