Update.
authorFrançois Fleuret <francois@fleuret.org>
Fri, 23 Aug 2024 05:29:48 +0000 (07:29 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Fri, 23 Aug 2024 05:29:48 +0000 (07:29 +0200)
main.py

diff --git a/main.py b/main.py
index 0fe33f6..2f867db 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -969,12 +969,11 @@ def test_ae(local_device=main_device):
             targets = input
 
             input = (1 - mask_generate) * input  # PARANOIAAAAAAAAAA
-            pred_result = None
             result = (1 - mask_generate) * input + mask_generate * torch.randint(
                 quiz_machine.problem.nb_colors, input.size(), device=input.device
             )
 
-            i = torch.full((result.size(0),), True, device=result.device)
+            not_converged = torch.full((result.size(0),), True, device=result.device)
 
             nb_it = 0
 
@@ -982,11 +981,10 @@ def test_ae(local_device=main_device):
                 logits = model(mygpt.BracketedSequence(result)).x
                 dist = torch.distributions.categorical.Categorical(logits=logits)
                 pred_result = result.clone()
-                result[i] = (1 - mask_generate[i]) * input + (
-                    mask_generate * dist.sample()[i]
-                )
-                changed = (pred_result == result).long().min(dim=1).values == 0
-                i = i & changed
+                result[not_converged] = (
+                    (1 - mask_generate) * input + mask_generate * dist.sample()
+                )[not_converged]
+                not_converged = (pred_result == result).long().min(dim=1).values == 0
                 nb_it += 1
                 print("DEBUG", nb_it, i.long().sum().item())
                 if not i.any() or nb_it > 100: