Update.
authorFrançois Fleuret <francois@fleuret.org>
Fri, 13 Sep 2024 09:45:11 +0000 (11:45 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Fri, 13 Sep 2024 09:45:11 +0000 (11:45 +0200)
main.py

diff --git a/main.py b/main.py
index 92a34f1..5c086cf 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -731,10 +731,12 @@ def ae_generate(model, x_0, mask_generate, nb_iterations_max=50, mask_hints=None
 
     single_iteration = deterministic(mask_generate)[:, None]
 
-    if mask_hints is not None:
-        mask_generate = mask_generate * (1 - mask_hints)
+    if mask_hints is None:
+        mask_start = mask_generate
+    else:
+        mask_start = mask_generate * (1 - mask_hints)
 
-    x_t = (1 - mask_generate) * x_0 + mask_generate * noise
+    x_t = (1 - mask_start) * x_0 + mask_start * noise
 
     changed = True