From: François Fleuret Date: Sun, 15 Sep 2024 20:20:25 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=43c156af08ce88a262fd3be6f09498abf3e43bad;p=culture.git Update. --- diff --git a/diffusion.py b/diffusion.py index abe8986..629113a 100755 --- a/diffusion.py +++ b/diffusion.py @@ -52,9 +52,13 @@ class Diffuser: ###################################################################### - def make_mask_hints(mask_generate, nb_hints): + def make_mask_hints(self, mask_generate, nb_hints): if nb_hints is None: - mask_hints = None + mask_hints = torch.zeros( + mask_generate.size(), + device=mask_generate.device, + dtype=mask_generate.dtype, + ) else: u = ( torch.rand(mask_generate.size(), device=mask_generate.device) @@ -94,7 +98,7 @@ class Diffuser: t = dist.sample() + 1 - x_T_with_hints = mask_hints * x_0 + (1 - mask_hint) * noise + x_T_with_hints = mask_hints * x_0 + (1 - mask_hints) * noise x_t = self.sample_x_t_given_x_0(x_0, t) x_t = single_iteration * x_T_with_hints + (1 - single_iteration) * x_t x_t = (1 - mask_generate) * x_0 + mask_generate * x_t @@ -129,9 +133,8 @@ class Diffuser: mask_hints = self.make_mask_hints(mask_generate, nb_hints) - x_T_with_hints = mask_hints * x_0 + (1 - mask_hint) * noise - x_t = self.sample_x_t_given_x_0(x_0, t) - x_t = single_iteration * x_T_with_hints + (1 - single_iteration) * x_t + x_T_with_hints = mask_hints * x_0 + (1 - mask_hints) * noise + x_t = single_iteration * x_T_with_hints + (1 - single_iteration) * noise x_t = (1 - mask_generate) * x_0 + mask_generate * x_t changed = True