Update.
author François Fleuret <francois@fleuret.org>
Mon, 26 Aug 2024 18:11:50 +0000 (20:11 +0200)
committer François Fleuret <francois@fleuret.org>
Mon, 26 Aug 2024 18:11:50 +0000 (20:11 +0200)
main.py

diff --git a/main.py b/main.py
index 3c8b4b9..eb0f776 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -892,6 +892,15 @@ def deterministic(mask_generate):
     return (mask_generate.sum(dim=1) < mask_generate.size(1) // 2).long()
 
 
+def prioritized_rand(low):
+    x = torch.rand(low.size(), device=low.device).sort(dim=1, descending=True).values
+    k = torch.rand(low.size(), device=low.device) + low.long()
+    k = k.sort(dim=1).indices
+    y = x.new(x.size())
+    y.scatter_(dim=1, index=k, src=x)
+    return y
+
+
 def ae_generate(
     model, input, mask_generate, n_epoch, noise_proba, nb_iterations_max=50
 ):
@@ -909,20 +918,23 @@ def ae_generate(
         input_with_mask = NTC_channel_cat(input, mask_generate)
         logits = model(input_with_mask)
         dist = torch.distributions.categorical.Categorical(logits=logits)
+        final = dist.sample()
+
+        r = prioritized_rand(final != input)
 
-        r = torch.rand(mask_generate.size(), device=mask_generate.device)
         mask_erased = mask_generate * (r <= proba_erased).long()
         mask_to_change = d * mask_generate + (1 - d) * mask_erased
 
-        update = (1 - mask_to_change) * input + mask_to_change * dist.sample()
+        update = (1 - mask_to_change) * input + mask_to_change * final
 
         if update.equal(input):
-            log_string(f"converged at iteration {it}")
             break
         else:
             changed = changed & (update != input).max(dim=1).values
             input[changed] = update[changed]
 
+    log_string(f"remains {changed.long().sum()}")
+
     return input