Update.
authorFrançois Fleuret <francois@fleuret.org>
Fri, 23 Aug 2024 05:16:55 +0000 (07:16 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Fri, 23 Aug 2024 05:16:55 +0000 (07:16 +0200)
main.py

diff --git a/main.py b/main.py
index a65d893..1999bac 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -905,13 +905,6 @@ def test_ae(local_device=main_device):
                 quiz_machine.problem.nb_colors, input.size(), device=input.device
             )
 
-            L = input.size(1) // 4
-
-            input[:, 0 * L] = targets[:, 0 * L]
-            input[:, 1 * L] = targets[:, 1 * L]
-            input[:, 2 * L] = targets[:, 2 * L]
-            input[:, 3 * L] = targets[:, 3 * L]
-
             output = model(mygpt.BracketedSequence(input)).x
             loss = F.cross_entropy(output.transpose(1, 2), targets)
             acc_train_loss += loss.item() * input.size(0)
@@ -955,13 +948,6 @@ def test_ae(local_device=main_device):
                     quiz_machine.problem.nb_colors, input.size(), device=input.device
                 )
 
-                L = input.size(1) // 4
-
-                input[:, 0 * L] = targets[:, 0 * L]
-                input[:, 1 * L] = targets[:, 1 * L]
-                input[:, 2 * L] = targets[:, 2 * L]
-                input[:, 3 * L] = targets[:, 3 * L]
-
                 output = model(mygpt.BracketedSequence(input)).x
                 loss = F.cross_entropy(output.transpose(1, 2), targets)
                 acc_test_loss += loss.item() * input.size(0)
@@ -975,8 +961,9 @@ def test_ae(local_device=main_device):
 
             targets = input
 
+            input = (1 - mask_generate) * input  # PARANOIAAAAAAAAAA
+
             pred_result = None
-            frozzen = None
 
             mask_noise = (mask_generate != 0) & (
                 torch.rand(mask_generate.size(), device=mask_generate.device)
@@ -989,13 +976,6 @@ def test_ae(local_device=main_device):
                 quiz_machine.problem.nb_colors, input.size(), device=input.device
             )
 
-            L = input.size(1) // 4
-
-            result[:, 0 * L] = input[:, 0 * L]
-            result[:, 1 * L] = input[:, 1 * L]
-            result[:, 2 * L] = input[:, 2 * L]
-            result[:, 3 * L] = input[:, 3 * L]
-
             i = torch.full((result.size(0),), True, device=result.device)
 
             nb_it = 0
@@ -1004,11 +984,9 @@ def test_ae(local_device=main_device):
                 logits = model(mygpt.BracketedSequence(result)).x
                 dist = torch.distributions.categorical.Categorical(logits=logits)
                 pred_result = result.clone()
-                result[i] = dist.sample()[i]
-                result[:, 0 * L] = input[:, 0 * L]
-                result[:, 1 * L] = input[:, 1 * L]
-                result[:, 2 * L] = input[:, 2 * L]
-                result[:, 3 * L] = input[:, 3 * L]
+                result[i] = (1 - mask_generate) * input + (
+                    mask_generate * dist.sample()[i]
+                )
                 changed = (pred_result == result).long().min(dim=1).values == 0
                 i = i & changed
                 nb_it += 1