Update. diverse
authorFrançois Fleuret <francois@fleuret.org>
Fri, 23 Aug 2024 20:52:01 +0000 (22:52 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Fri, 23 Aug 2024 20:52:01 +0000 (22:52 +0200)
main.py

diff --git a/main.py b/main.py
index 2f5e0dc..7ce9b03 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -872,10 +872,14 @@ def degrade_input_inplace(input, mask_generate, pure_noise=False):
         mask_diffusion_noise = mask_diffusion_noise.long()
 
         input[...] = (
-            1 - mask_generate
-        ) * input + mask_generate * mask_diffusion_noise * torch.randint(
-            quiz_machine.problem.nb_colors, input.size(), device=input.device
+            mask_generate
+            * mask_diffusion_noise
+            * torch.randint(
+                quiz_machine.problem.nb_colors, input.size(), device=input.device
+            )
+            + (1 - mask_generate * mask_diffusion_noise) * input
         )
+
     else:
         model.eval()
         for it in range(torch.randint(5, (1,)).item()):
@@ -944,6 +948,20 @@ def test_ae(local_device=main_device):
                     torch.cat([input[:, :, None], mask_generate[:, :, None]], dim=2)
                 )
             ).x
+
+            # for filename, quizzes in [
+            # ("targets.png", targets),
+            # ("input.png", input),
+            # ("mask_generate.png", mask_generate),
+            # ("mask_loss.png", mask_loss),
+            # ]:
+            # quiz_machine.problem.save_quizzes_as_image(
+            # args.result_dir,
+            # filename,
+            # quizzes=quizzes,
+            # )
+            # time.sleep(10000)
+
             loss_per_token = F.cross_entropy(
                 output.transpose(1, 2), targets, reduction="none"
             )