Update.
authorFrançois Fleuret <francois@fleuret.org>
Fri, 23 Aug 2024 06:56:35 +0000 (08:56 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Fri, 23 Aug 2024 06:56:35 +0000 (08:56 +0200)
main.py

diff --git a/main.py b/main.py
index e7dd337..289bae4 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -894,18 +894,28 @@ def test_ae(local_device=main_device):
 
             targets = input
 
-            mask_diffusion_noise = (mask_generate == 1) & (
-                torch.rand(mask_generate.size(), device=mask_generate.device)
-                <= torch.rand((mask_generate.size(0), 1), device=mask_generate.device)
-            )
+            mask_diffusion_noise = (mask_generate == 1) & (
+            # torch.rand(mask_generate.size(), device=mask_generate.device)
+            # <= torch.rand((mask_generate.size(0), 1), device=mask_generate.device)
+            )
 
-            mask_diffusion_noise = mask_diffusion_noise.long()
+            mask_diffusion_noise = mask_diffusion_noise.long()
 
-            input = (
-                1 - mask_diffusion_noise
-            ) * input + mask_diffusion_noise * torch.randint(
-                quiz_machine.problem.nb_colors, input.size(), device=input.device
-            )
+            # input = (
+            # 1 - mask_diffusion_noise
+            # ) * input + mask_diffusion_noise * torch.randint(
+            # quiz_machine.problem.nb_colors, input.size(), device=input.device
+            # )
+
+            # ------------------------------
+            input = (1 - mask_generate) * input  # PARANOIAAAAAAAAAA
+            model.eval()
+            for it in range(torch.randint(5, (1,)).item()):
+                logits = model(mygpt.BracketedSequence(input)).x
+                dist = torch.distributions.categorical.Categorical(logits=logits)
+                input = (1 - mask_generate) * input + mask_generate * dist.sample()
+            model.train()
+            # -----------------------------
 
             output = model(mygpt.BracketedSequence(input)).x
             loss = F.cross_entropy(output.transpose(1, 2), targets)
@@ -937,20 +947,29 @@ def test_ae(local_device=main_device):
             ):
                 targets = input
 
-                mask_diffusion_noise = (mask_generate == 1) & (
-                    torch.rand(mask_generate.size(), device=mask_generate.device)
-                    <= torch.rand(
-                        (mask_generate.size(0), 1), device=mask_generate.device
-                    )
-                )
+                mask_diffusion_noise = (mask_generate == 1) & (
+                # torch.rand(mask_generate.size(), device=mask_generate.device)
+                # <= torch.rand(
+                # (mask_generate.size(0), 1), device=mask_generate.device
+                # )
+                )
 
-                mask_diffusion_noise = mask_diffusion_noise.long()
+                mask_diffusion_noise = mask_diffusion_noise.long()
 
-                input = (
-                    1 - mask_diffusion_noise
-                ) * input + mask_diffusion_noise * torch.randint(
-                    quiz_machine.problem.nb_colors, input.size(), device=input.device
-                )
+                # input = (
+                # 1 - mask_diffusion_noise
+                # ) * input + mask_diffusion_noise * torch.randint(
+                # quiz_machine.problem.nb_colors, input.size(), device=input.device
+                # )
+
+                # ------------------------------
+                input = (1 - mask_generate) * input  # PARANOIAAAAAAAAAA
+
+                for it in range(torch.randint(5, (1,)).item()):
+                    logits = model(mygpt.BracketedSequence(input)).x
+                    dist = torch.distributions.categorical.Categorical(logits=logits)
+                    input = (1 - mask_generate) * input + mask_generate * dist.sample()
+                # -----------------------------
 
                 output = model(mygpt.BracketedSequence(input)).x
                 loss = F.cross_entropy(output.transpose(1, 2), targets)
@@ -973,26 +992,27 @@ def test_ae(local_device=main_device):
 
                 input = (1 - mask_generate) * input  # PARANOIAAAAAAAAAA
 
-                result = (1 - mask_generate) * input + mask_generate * torch.randint(
-                    quiz_machine.problem.nb_colors, input.size(), device=input.device
-                )
+                result = (1 - mask_generate) * input
+
+                # + mask_generate * torch.randint(
+                # quiz_machine.problem.nb_colors, input.size(), device=input.device
+                # )
 
                 not_converged = torch.full(
                     (result.size(0),), True, device=result.device
                 )
 
-                nb_it = 0
-
-                while True:
-                    logits = model(mygpt.BracketedSequence(result)).x
-                    dist = torch.distributions.categorical.Categorical(logits=logits)
+                for it in range(100):
                     pred_result = result.clone()
-                    update = (1 - mask_generate) * input + mask_generate * dist.sample()
-                    result[not_converged] = update[not_converged]
+                    logits = model(mygpt.BracketedSequence(result[not_converged])).x
+                    dist = torch.distributions.categorical.Categorical(logits=logits)
+                    update = (1 - mask_generate[not_converged]) * input[
+                        not_converged
+                    ] + mask_generate[not_converged] * dist.sample()
+                    result[not_converged] = update
                     not_converged = (pred_result != result).max(dim=1).values
-                    nb_it += 1
-                    print("DEBUG", nb_it, not_converged.long().sum().item())
-                    if not not_converged.any() or nb_it > 100:
+                    if not not_converged.any():
+                        log_string(f"diffusion_converged {it=}")
                         break
 
                 correct = (result == targets).min(dim=1).values.long()