Update.
authorFrançois Fleuret <francois@fleuret.org>
Sun, 15 Sep 2024 20:20:25 +0000 (22:20 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sun, 15 Sep 2024 20:20:25 +0000 (22:20 +0200)
diffusion.py

index abe8986..629113a 100755 (executable)
@@ -52,9 +52,13 @@ class Diffuser:
 
     ######################################################################
 
-    def make_mask_hints(mask_generate, nb_hints):
+    def make_mask_hints(self, mask_generate, nb_hints):
         if nb_hints is None:
-            mask_hints = None
+            mask_hints = torch.zeros(
+                mask_generate.size(),
+                device=mask_generate.device,
+                dtype=mask_generate.dtype,
+            )
         else:
             u = (
                 torch.rand(mask_generate.size(), device=mask_generate.device)
@@ -94,7 +98,7 @@ class Diffuser:
 
         t = dist.sample() + 1
 
-        x_T_with_hints = mask_hints * x_0 + (1 - mask_hint) * noise
+        x_T_with_hints = mask_hints * x_0 + (1 - mask_hints) * noise
         x_t = self.sample_x_t_given_x_0(x_0, t)
         x_t = single_iteration * x_T_with_hints + (1 - single_iteration) * x_t
         x_t = (1 - mask_generate) * x_0 + mask_generate * x_t
@@ -129,9 +133,8 @@ class Diffuser:
 
         mask_hints = self.make_mask_hints(mask_generate, nb_hints)
 
-        x_T_with_hints = mask_hints * x_0 + (1 - mask_hint) * noise
-        x_t = self.sample_x_t_given_x_0(x_0, t)
-        x_t = single_iteration * x_T_with_hints + (1 - single_iteration) * x_t
+        x_T_with_hints = mask_hints * x_0 + (1 - mask_hints) * noise
+        x_t = single_iteration * x_T_with_hints + (1 - single_iteration) * noise
         x_t = (1 - mask_generate) * x_0 + mask_generate * x_t
 
         changed = True