Update.
author François Fleuret <francois@fleuret.org>
Thu, 22 Aug 2024 20:50:43 +0000 (22:50 +0200)
committer François Fleuret <francois@fleuret.org>
Thu, 22 Aug 2024 20:50:43 +0000 (22:50 +0200)
main.py

diff --git a/main.py b/main.py
index f28c10a..2a35209 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -872,12 +872,24 @@ def test_ae(local_device=main_device):
 
             targets = input
 
-            input = (mask_generate == 0).long() * input + (
-                1 - (mask_generate == 0).long()
-            ) * torch.randint(
+            mask_noise = (mask_generate != 0) & (
+                torch.rand(mask_generate.size(), device=mask_generate.device)
+                <= torch.rand((mask_generate.size(0), 1), device=mask_generate.device)
+            )
+
+            mask_noise = mask_noise.long()
+
+            input = (1 - mask_noise) * input + mask_noise * torch.randint(
                 quiz_machine.problem.nb_colors, input.size(), device=input.device
             )
 
+            L = input.size(1) // 4
+
+            input[:, 0 * L] = targets[:, 0 * L]
+            input[:, 1 * L] = targets[:, 1 * L]
+            input[:, 2 * L] = targets[:, 2 * L]
+            input[:, 3 * L] = targets[:, 3 * L]
+
             output = model(mygpt.BracketedSequence(input)).x
             loss = F.cross_entropy(output.transpose(1, 2), targets)
             acc_train_loss += loss.item() * input.size(0)
@@ -921,12 +933,26 @@ def test_ae(local_device=main_device):
 
                 targets = input
 
-                input = (mask_generate == 0).long() * input + (
-                    1 - (mask_generate == 0).long()
-                ) * torch.randint(
+                mask_noise = (mask_generate != 0) & (
+                    torch.rand(mask_generate.size(), device=mask_generate.device)
+                    <= torch.rand(
+                        (mask_generate.size(0), 1), device=mask_generate.device
+                    )
+                )
+
+                mask_noise = mask_noise.long()
+
+                input = (1 - mask_noise) * input + mask_noise * torch.randint(
                     quiz_machine.problem.nb_colors, input.size(), device=input.device
                 )
 
+                L = input.size(1) // 4
+
+                input[:, 0 * L] = targets[:, 0 * L]
+                input[:, 1 * L] = targets[:, 1 * L]
+                input[:, 2 * L] = targets[:, 2 * L]
+                input[:, 3 * L] = targets[:, 3 * L]
+
                 output = model(mygpt.BracketedSequence(input)).x
                 loss = F.cross_entropy(output.transpose(1, 2), targets)
                 acc_test_loss += loss.item() * input.size(0)
@@ -943,18 +969,28 @@ def test_ae(local_device=main_device):
             pred_result = None
             frozzen = None
 
-            result = (mask_generate == 0).long() * input + (
-                1 - (mask_generate == 0).long()
-            ) * torch.randint(
+            mask_noise = (mask_generate != 0) & (
+                torch.rand(mask_generate.size(), device=mask_generate.device)
+                <= torch.rand((mask_generate.size(0), 1), device=mask_generate.device)
+            )
+
+            mask_noise = mask_noise.long()
+
+            result = (1 - mask_noise) * input + mask_noise * torch.randint(
                 quiz_machine.problem.nb_colors, input.size(), device=input.device
             )
 
+            L = input.size(1) // 4
+
+            result[:, 0 * L] = input[:, 0 * L]
+            result[:, 1 * L] = input[:, 1 * L]
+            result[:, 2 * L] = input[:, 2 * L]
+            result[:, 3 * L] = input[:, 3 * L]
+
             i = torch.full((result.size(0),), True, device=result.device)
 
             nb_it = 0
 
-            L = input.size(1) // 4
-
             while True:
                 logits = model(mygpt.BracketedSequence(result)).x
                 dist = torch.distributions.categorical.Categorical(logits=logits)