Update.
authorFrançois Fleuret <francois@fleuret.org>
Fri, 23 Aug 2024 05:32:07 +0000 (07:32 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Fri, 23 Aug 2024 05:32:07 +0000 (07:32 +0200)
main.py

diff --git a/main.py b/main.py
index 2f867db..00a6cd1 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -981,10 +981,9 @@ def test_ae(local_device=main_device):
                 logits = model(mygpt.BracketedSequence(result)).x
                 dist = torch.distributions.categorical.Categorical(logits=logits)
                 pred_result = result.clone()
-                result[not_converged] = (
-                    (1 - mask_generate) * input + mask_generate * dist.sample()
-                )[not_converged]
-                not_converged = (pred_result == result).long().min(dim=1).values == 0
+                update = (1 - mask_generate) * input + mask_generate * dist.sample()
+                result[not_converged] = update[not_converged]
+                not_converged = (pred_result != result).max(dim=1).values
                 nb_it += 1
                 print("DEBUG", nb_it, i.long().sum().item())
                 if not i.any() or nb_it > 100: