Update.
author: François Fleuret <francois@fleuret.org>
Fri, 23 Aug 2024 16:35:53 +0000 (18:35 +0200)
committer: François Fleuret <francois@fleuret.org>
Fri, 23 Aug 2024 16:35:53 +0000 (18:35 +0200)
main.py

diff --git a/main.py b/main.py
index c6d76ee..2f5e0dc 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -863,6 +863,32 @@ def ae_batches(quiz_machine, nb, data_structures, local_device, desc=None):
         )
 
 
+def degrade_input_inplace(input, mask_generate, pure_noise=False):
+    if pure_noise:
+        mask_diffusion_noise = torch.rand(
+            mask_generate.size(), device=mask_generate.device
+        ) <= torch.rand((mask_generate.size(0), 1), device=mask_generate.device)
+
+        mask_diffusion_noise = mask_diffusion_noise.long()
+
+        input[...] = (
+            1 - mask_generate
+        ) * input + mask_generate * mask_diffusion_noise * torch.randint(
+            quiz_machine.problem.nb_colors, input.size(), device=input.device
+        )
+    else:
+        model.eval()
+        for it in range(torch.randint(5, (1,)).item()):
+            logits = model(
+                mygpt.BracketedSequence(
+                    torch.cat([input[:, :, None], mask_generate[:, :, None]], dim=2)
+                )
+            ).x
+            dist = torch.distributions.categorical.Categorical(logits=logits)
+            input[...] = (1 - mask_generate) * input + mask_generate * dist.sample()
+        model.train()
+
+
 def test_ae(local_device=main_device):
     model = MyAttentionAE(
         vocabulary_size=vocabulary_size,
@@ -876,6 +902,7 @@ def test_ae(local_device=main_device):
 
     pure_noise = True
 
+    # quad_order, quad_generate, quad_noise, quad_loss
     data_structures = [
         (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (0, 0, 0, 1)),
         (("A", "f_A", "B", "f_B"), (0, 0, 1, 0), (0, 0, 0, 1), (0, 0, 1, 0)),
@@ -909,40 +936,18 @@ def test_ae(local_device=main_device):
             if nb_train_samples % args.batch_size == 0:
                 model.optimizer.zero_grad()
 
-            targets = input
-
-            input = (1 - mask_generate) * input  # PARANOIAAAAAAAAAA
-
-            if pure_noise:
-                mask_diffusion_noise = torch.rand(
-                    mask_generate.size(), device=mask_generate.device
-                ) <= torch.rand((mask_generate.size(0), 1), device=mask_generate.device)
-
-                mask_diffusion_noise = mask_diffusion_noise.long()
-
-                input = input + mask_generate * mask_diffusion_noise * torch.randint(
-                    quiz_machine.problem.nb_colors, input.size(), device=input.device
-                )
-            else:
-                model.eval()
-                for it in range(torch.randint(5, (1,)).item()):
-                    logits = model(
-                        mygpt.BracketedSequence(
-                            torch.cat(
-                                [input[:, :, None], mask_generate[:, :, None]], dim=2
-                            )
-                        )
-                    ).x
-                    dist = torch.distributions.categorical.Categorical(logits=logits)
-                    input = (1 - mask_generate) * input + mask_generate * dist.sample()
-                model.train()
+            targets = input.clone()
+            degrade_input_inplace(input, mask_generate, pure_noise=pure_noise)
 
             output = model(
                 mygpt.BracketedSequence(
                     torch.cat([input[:, :, None], mask_generate[:, :, None]], dim=2)
                 )
             ).x
-            loss = F.cross_entropy(output.transpose(1, 2), targets)
+            loss_per_token = F.cross_entropy(
+                output.transpose(1, 2), targets, reduction="none"
+            )
+            loss = (loss_per_token * mask_loss).mean()
             acc_train_loss += loss.item() * input.size(0)
             nb_train_samples += input.size(0)
             loss.backward()
@@ -969,51 +974,17 @@ def test_ae(local_device=main_device):
                 local_device,
                 "test",
             ):
-                targets = input
-
-                input = (1 - mask_generate) * input  # PARANOIAAAAAAAAAA
-
-                if pure_noise:
-                    mask_diffusion_noise = torch.rand(
-                        mask_generate.size(), device=mask_generate.device
-                    ) <= torch.rand(
-                        (mask_generate.size(0), 1), device=mask_generate.device
-                    )
-
-                    mask_diffusion_noise = mask_diffusion_noise.long()
-
-                    input = (
-                        input
-                        + mask_generate
-                        * mask_diffusion_noise
-                        * torch.randint(
-                            quiz_machine.problem.nb_colors,
-                            input.size(),
-                            device=input.device,
-                        )
-                    )
-                else:
-                    for it in range(torch.randint(5, (1,)).item()):
-                        logits = model(
-                            mygpt.BracketedSequence(
-                                torch.cat(
-                                    [input[:, None], mask_generate[:, None]], dim=1
-                                )
-                            )
-                        ).x
-                        dist = torch.distributions.categorical.Categorical(
-                            logits=logits
-                        )
-                        input = (
-                            1 - mask_generate
-                        ) * input + mask_generate * dist.sample()
-
+                targets = input.clone()
+                degrade_input_inplace(input, mask_generate, pure_noise=pure_noise)
                 output = model(
                     mygpt.BracketedSequence(
                         torch.cat([input[:, :, None], mask_generate[:, :, None]], dim=2)
                     )
                 ).x
-                loss = F.cross_entropy(output.transpose(1, 2), targets)
+                loss_per_token = F.cross_entropy(
+                    output.transpose(1, 2), targets, reduction="none"
+                )
+                loss = (loss_per_token * mask_loss).mean()
                 acc_test_loss += loss.item() * input.size(0)
                 nb_test_samples += input.size(0)
 
@@ -1029,46 +1000,8 @@ def test_ae(local_device=main_device):
                     ae_batches(quiz_machine, 128, [s], local_device)
                 )
 
-                targets = input
-
-                input = (1 - mask_generate) * input  # PARANOIAAAAAAAAAA
-
-                if pure_noise:
-                    mask_diffusion_noise = torch.rand(
-                        mask_generate.size(), device=mask_generate.device
-                    ) <= torch.rand(
-                        (mask_generate.size(0), 1), device=mask_generate.device
-                    )
-
-                    mask_diffusion_noise = mask_diffusion_noise.long()
-
-                    input = (
-                        input
-                        + mask_generate
-                        * mask_diffusion_noise
-                        * torch.randint(
-                            quiz_machine.problem.nb_colors,
-                            input.size(),
-                            device=input.device,
-                        )
-                    )
-                else:
-                    for it in range(torch.randint(5, (1,)).item()):
-                        logits = model(
-                            mygpt.BracketedSequence(
-                                torch.cat(
-                                    [input[:, :, None], mask_generate[:, :, None]],
-                                    dim=2,
-                                )
-                            )
-                        ).x
-                        dist = torch.distributions.categorical.Categorical(
-                            logits=logits
-                        )
-                        input = (
-                            1 - mask_generate
-                        ) * input + mask_generate * dist.sample()
-
+                targets = input.clone()
+                degrade_input_inplace(input, mask_generate, pure_noise=pure_noise)
                 result = input
 
                 not_converged = torch.full(
@@ -1082,7 +1015,7 @@ def test_ae(local_device=main_device):
                             torch.cat(
                                 [
                                     result[not_converged, :, None],
-                                    mask_generate[:, :, None],
+                                    mask_generate[not_converged, :, None],
                                 ],
                                 dim=2,
                             )