Update.
author François Fleuret <francois@fleuret.org>
Sat, 24 Aug 2024 16:15:32 +0000 (18:15 +0200)
committer François Fleuret <francois@fleuret.org>
Sat, 24 Aug 2024 16:15:32 +0000 (18:15 +0200)
main.py

diff --git a/main.py b/main.py
index e13c148..9fe01ab 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -864,23 +864,6 @@ def ae_batches(
             mask_loss.to(local_device),
         )
 
-
-def degrade_input(input, mask_generate, *ts):
-    noise = torch.randint(
-        quiz_machine.problem.nb_colors, input.size(), device=input.device
-    )
-
-    r = torch.rand(mask_generate.size(), device=mask_generate.device)
-
-    result = []
-
-    for t in ts:
-        mask_diffusion_noise = mask_generate * (r <= t).long()
-        x = (1 - mask_diffusion_noise) * input + mask_diffusion_noise * noise
-        result.append(x)
-
-    return result
-
     # quiz_machine.problem.save_quizzes_as_image(
     # args.result_dir,
     # filename="a.png",
@@ -921,6 +904,23 @@ def ae_generate(model, input, mask_generate, n_epoch, nb_iterations):
     return input
 
 
+def degrade_input(input, mask_generate, *phis):  # returns a list with one noised copy of `input` per threshold in `phis`
+    noise = torch.randint(  # random integers in [0, nb_colors), same shape and device as `input`
+        quiz_machine.problem.nb_colors, input.size(), device=input.device
+    )
+
+    r = torch.rand(mask_generate.size(), device=mask_generate.device)  # drawn once and shared by every phi, so a larger phi noises a superset of positions
+
+    result = []
+
+    for phi in phis:
+        mask_diffusion_noise = mask_generate * (r <= phi).long()  # nonzero only where mask_generate is nonzero and r <= phi
+        x = (1 - mask_diffusion_noise) * input + mask_diffusion_noise * noise  # keep `input` where the mask is 0, take `noise` where it is 1 (assumes a 0/1 mask_generate -- TODO confirm)
+        result.append(x)
+
+    return result
+
+
 def test_ae(local_device=main_device):
     model = MyAttentionAE(
         vocabulary_size=vocabulary_size,
@@ -949,6 +949,10 @@ def test_ae(local_device=main_device):
 
     nb_iterations = 10
 
+    def phi(rho):  # maps rho to the threshold passed to degrade_input; currently linear in rho
+        # alternative quadratic form, kept for reference: return (rho / nb_iterations)**2
+        return rho / nb_iterations
+
     for n_epoch in range(args.nb_epochs):
         # ----------------------
         # Train
@@ -967,9 +971,8 @@ def test_ae(local_device=main_device):
                 model.optimizer.zero_grad()
 
             rho = torch.randint(nb_iterations, (input.size(0), 1), device=input.device)
-            targets, input = degrade_input(
-                input, mask_generate, rho / nb_iterations, (rho + 1) / nb_iterations
-            )
+
+            targets, input = degrade_input(input, mask_generate, phi(rho), phi(rho + 1))
 
             input_with_mask = NTC_channel_cat(input, mask_generate, rho)
             output = model(input_with_mask)
@@ -1004,9 +1007,11 @@ def test_ae(local_device=main_device):
                 rho = torch.randint(
                     nb_iterations, (input.size(0), 1), device=input.device
                 )
+
                 targets, input = degrade_input(
-                    input, mask_generate, rho / nb_iterations, (rho + 1) / nb_iterations
+                    input, mask_generate, phi(rho), phi(rho + 1)
                 )
+
                 input_with_mask = NTC_channel_cat(input, mask_generate, rho)
                 output = model(input_with_mask)
                 loss = NTC_masked_cross_entropy(output, targets, mask_loss)