Update: use the nb argument in ae_batches, rename mask_noise to mask_diffusion_noise, and simplify test-generation initialization.
author François Fleuret <francois@fleuret.org>
Fri, 23 Aug 2024 05:23:50 +0000 (07:23 +0200)
committer François Fleuret <francois@fleuret.org>
Fri, 23 Aug 2024 05:23:50 +0000 (07:23 +0200)
main.py
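
Editor's note: the per-sample diffusion noise built in test_ae can be read in
isolation. Below is a minimal, self-contained sketch of that masking step; the
sizes and the names B, T and nb_colors are illustrative assumptions, not
identifiers from main.py.

    import torch

    B, T, nb_colors = 4, 16, 10  # illustrative sizes

    input = torch.randint(nb_colors, (B, T))          # one token grid per sample
    mask_generate = (torch.rand(B, T) < 0.5).long()   # positions to generate

    # Each sample draws its own noise rate r ~ U(0, 1); a position under
    # mask_generate is corrupted with probability r.
    mask_diffusion_noise = (mask_generate == 1) & (
        torch.rand(B, T) <= torch.rand(B, 1)
    )
    mask_diffusion_noise = mask_diffusion_noise.long()

    # Keep clean tokens elsewhere, replace corrupted positions with random colors.
    noised = (1 - mask_diffusion_noise) * input + mask_diffusion_noise * torch.randint(
        nb_colors, (B, T)
    )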

diff --git a/main.py b/main.py
index 1999bac..0fe33f6 100755
--- a/main.py
+++ b/main.py
@@ -823,7 +823,7 @@ class MyAttentionVAE(nn.Module):
 
 def ae_batches(quiz_machine, nb, data_structures, local_device, desc=None):
     full_input, full_mask_generate, full_mask_loss = quiz_machine.data_input(
-        args.nb_train_samples, data_structures=data_structures
+        nb, data_structures=data_structures
     )
 
     src = zip(
@@ -894,14 +894,16 @@ def test_ae(local_device=main_device):
 
             targets = input
 
-            mask_noise = (mask_generate != 0) & (
+            mask_diffusion_noise = (mask_generate == 1) & (
                 torch.rand(mask_generate.size(), device=mask_generate.device)
                 <= torch.rand((mask_generate.size(0), 1), device=mask_generate.device)
             )
 
-            mask_noise = mask_noise.long()
+            mask_diffusion_noise = mask_diffusion_noise.long()
 
-            input = (1 - mask_noise) * input + mask_noise * torch.randint(
+            input = (
+                1 - mask_diffusion_noise
+            ) * input + mask_diffusion_noise * torch.randint(
                 quiz_machine.problem.nb_colors, input.size(), device=input.device
             )
 
@@ -935,16 +937,18 @@ def test_ae(local_device=main_device):
             ):
                 targets = input
 
-                mask_noise = (mask_generate != 0) & (
+                mask_diffusion_noise = (mask_generate == 1) & (
                     torch.rand(mask_generate.size(), device=mask_generate.device)
                     <= torch.rand(
                         (mask_generate.size(0), 1), device=mask_generate.device
                     )
                 )
 
-                mask_noise = mask_noise.long()
+                mask_diffusion_noise = mask_diffusion_noise.long()
 
-                input = (1 - mask_noise) * input + mask_noise * torch.randint(
+                input = (
+                    1 - mask_diffusion_noise
+                ) * input + mask_diffusion_noise * torch.randint(
                     quiz_machine.problem.nb_colors, input.size(), device=input.device
                 )
 
@@ -955,6 +959,9 @@ def test_ae(local_device=main_device):
 
             log_string(f"test_loss {n_epoch} model AE {acc_test_loss/nb_test_samples}")
 
+            # -------------------------------------------
+            # Test generation
+
             input, mask_generate, mask_loss = next(
                 ae_batches(quiz_machine, 128, data_structures, local_device)
             )
@@ -962,17 +969,8 @@ def test_ae(local_device=main_device):
             targets = input
 
             input = (1 - mask_generate) * input  # PARANOIAAAAAAAAAA
-
             pred_result = None
-
-            mask_noise = (mask_generate != 0) & (
-                torch.rand(mask_generate.size(), device=mask_generate.device)
-                <= torch.rand((mask_generate.size(0), 1), device=mask_generate.device)
-            )
-
-            mask_noise = mask_noise.long()
-
-            result = (1 - mask_noise) * input + mask_noise * torch.randint(
+            result = (1 - mask_generate) * input + mask_generate * torch.randint(
                 quiz_machine.problem.nb_colors, input.size(), device=input.device
             )
 
@@ -984,7 +982,7 @@ def test_ae(local_device=main_device):
                 logits = model(mygpt.BracketedSequence(result)).x
                 dist = torch.distributions.categorical.Categorical(logits=logits)
                 pred_result = result.clone()
-                result[i] = (1 - mask_generate) * input + (
+                result[i] = (1 - mask_generate[i]) * input + (
                     mask_generate * dist.sample()[i]
                 )
                 changed = (pred_result == result).long().min(dim=1).values == 0
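
Editor's note: for completeness, the resampling loop of the last hunk can be
sketched as a standalone routine. In main.py the model call goes through
mygpt.BracketedSequence and the loop keeps per-sample bookkeeping through an
index i that is not visible in this excerpt; the sketch below drops that,
assumes a plain callable returning logits, and resamples the whole batch each
pass until no sample changes. Function and argument names are assumptions for
illustration only.

    import torch

    def generate(model, input, mask_generate, nb_colors, max_it=100):
        # Initialize every position to generate with a uniform random color,
        # keep the rest of the input unchanged (mirrors the hunk above).
        result = (1 - mask_generate) * input + mask_generate * torch.randint(
            nb_colors, input.size(), device=input.device
        )
        for _ in range(max_it):
            logits = model(result)  # assumed to return (B, T, nb_colors) logits
            dist = torch.distributions.categorical.Categorical(logits=logits)
            pred_result = result.clone()
            result = (1 - mask_generate) * result + mask_generate * dist.sample()
            # A sample "changed" if any position differs from the previous pass.
            changed = (pred_result == result).long().min(dim=1).values == 0
            if not changed.any():
                break
        return result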