Update.
author: François Fleuret <francois@fleuret.org>
Fri, 23 Aug 2024 12:57:41 +0000 (14:57 +0200)
committer: François Fleuret <francois@fleuret.org>
Fri, 23 Aug 2024 12:57:41 +0000 (14:57 +0200)
main.py

diff --git a/main.py b/main.py
index 289bae4..c6d76ee 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -750,6 +750,18 @@ from mygpt import (
 )
 
 
+class MultiEmbedding(nn.Module):
+    def __init__(self, nb_values, dim):
+        super().__init__()
+        self.embeddings = nn.ModuleList([nn.Embedding(n, dim) for n in nb_values])
+
+    def forward(self, x):
+        y = 0
+        for f, z in zip(self.embeddings, x.split(1, dim=2)):
+            y = y + f(z[:, :, 0])
+        return y
+
+
 class MyAttentionAE(nn.Module):
     def __init__(
         self,
@@ -766,11 +778,14 @@ class MyAttentionAE(nn.Module):
 
         assert dim_model % nb_heads == 0
 
-        self.embedding = nn.Sequential(
-            CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
+        self.embedding = CacheWrapper(
+            nn.Sequential(
+                MultiEmbedding((vocabulary_size, 2), dim_model), nn.Dropout(dropout)
+            ),
         )
 
-        self.positional_encoding = TrainablePositionalEncoding(dim_model, len_max)
+        # self.positional_encoding = TrainablePositionalEncoding(dim_model, len_max)
+        self.positional_encoding = VaswaniPositionalEncoding(len_max=1e5)
 
         trunk_blocks = []
 
@@ -859,12 +874,14 @@ def test_ae(local_device=main_device):
         dropout=args.dropout,
     ).to(main_device)
 
+    pure_noise = True
+
     data_structures = [
-        (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)),
-        (("A", "f_A", "B", "f_B"), (0, 0, 1, 0), (0, 0, 0, 1), (1, 1, 1, 1)),
-        (("A", "f_A", "B", "f_B"), (0, 1, 0, 0), (1, 0, 0, 0), (1, 1, 1, 1)),
-        (("A", "f_A", "B", "f_B"), (1, 0, 0, 0), (0, 1, 0, 0), (1, 1, 1, 1)),
-        (("A", "f_A", "B", "f_B"), (1, 1, 1, 0), (0, 0, 0, 0), (1, 1, 1, 1)),
+        (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (0, 0, 0, 1)),
+        (("A", "f_A", "B", "f_B"), (0, 0, 1, 0), (0, 0, 0, 1), (0, 0, 1, 0)),
+        (("A", "f_A", "B", "f_B"), (0, 1, 0, 0), (1, 0, 0, 0), (0, 1, 0, 0)),
+        (("A", "f_A", "B", "f_B"), (1, 0, 0, 0), (0, 1, 0, 0), (1, 0, 0, 0)),
+        (("A", "f_A", "B", "f_B"), (1, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
     ]
 
     model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
@@ -894,30 +911,37 @@ def test_ae(local_device=main_device):
 
             targets = input
 
-            # mask_diffusion_noise = (mask_generate == 1) & (
-            # torch.rand(mask_generate.size(), device=mask_generate.device)
-            # <= torch.rand((mask_generate.size(0), 1), device=mask_generate.device)
-            # )
+            input = (1 - mask_generate) * input  # PARANOIAAAAAAAAAA
 
-            # mask_diffusion_noise = mask_diffusion_noise.long()
+            if pure_noise:
+                mask_diffusion_noise = torch.rand(
+                    mask_generate.size(), device=mask_generate.device
+                ) <= torch.rand((mask_generate.size(0), 1), device=mask_generate.device)
 
-            # input = (
-            # 1 - mask_diffusion_noise
-            # ) * input + mask_diffusion_noise * torch.randint(
-            # quiz_machine.problem.nb_colors, input.size(), device=input.device
-            # )
+                mask_diffusion_noise = mask_diffusion_noise.long()
 
-            # ------------------------------
-            input = (1 - mask_generate) * input  # PARANOIAAAAAAAAAA
-            model.eval()
-            for it in range(torch.randint(5, (1,)).item()):
-                logits = model(mygpt.BracketedSequence(input)).x
-                dist = torch.distributions.categorical.Categorical(logits=logits)
-                input = (1 - mask_generate) * input + mask_generate * dist.sample()
-            model.train()
-            # -----------------------------
+                input = input + mask_generate * mask_diffusion_noise * torch.randint(
+                    quiz_machine.problem.nb_colors, input.size(), device=input.device
+                )
+            else:
+                model.eval()
+                for it in range(torch.randint(5, (1,)).item()):
+                    logits = model(
+                        mygpt.BracketedSequence(
+                            torch.cat(
+                                [input[:, :, None], mask_generate[:, :, None]], dim=2
+                            )
+                        )
+                    ).x
+                    dist = torch.distributions.categorical.Categorical(logits=logits)
+                    input = (1 - mask_generate) * input + mask_generate * dist.sample()
+                model.train()
 
-            output = model(mygpt.BracketedSequence(input)).x
+            output = model(
+                mygpt.BracketedSequence(
+                    torch.cat([input[:, :, None], mask_generate[:, :, None]], dim=2)
+                )
+            ).x
             loss = F.cross_entropy(output.transpose(1, 2), targets)
             acc_train_loss += loss.item() * input.size(0)
             nb_train_samples += input.size(0)
@@ -947,31 +971,48 @@ def test_ae(local_device=main_device):
             ):
                 targets = input
 
-                # mask_diffusion_noise = (mask_generate == 1) & (
-                # torch.rand(mask_generate.size(), device=mask_generate.device)
-                # <= torch.rand(
-                # (mask_generate.size(0), 1), device=mask_generate.device
-                # )
-                # )
-
-                # mask_diffusion_noise = mask_diffusion_noise.long()
-
-                # input = (
-                # 1 - mask_diffusion_noise
-                # ) * input + mask_diffusion_noise * torch.randint(
-                # quiz_machine.problem.nb_colors, input.size(), device=input.device
-                # )
-
-                # ------------------------------
                 input = (1 - mask_generate) * input  # PARANOIAAAAAAAAAA
 
-                for it in range(torch.randint(5, (1,)).item()):
-                    logits = model(mygpt.BracketedSequence(input)).x
-                    dist = torch.distributions.categorical.Categorical(logits=logits)
-                    input = (1 - mask_generate) * input + mask_generate * dist.sample()
-                # -----------------------------
-
-                output = model(mygpt.BracketedSequence(input)).x
+                if pure_noise:
+                    mask_diffusion_noise = torch.rand(
+                        mask_generate.size(), device=mask_generate.device
+                    ) <= torch.rand(
+                        (mask_generate.size(0), 1), device=mask_generate.device
+                    )
+
+                    mask_diffusion_noise = mask_diffusion_noise.long()
+
+                    input = (
+                        input
+                        + mask_generate
+                        * mask_diffusion_noise
+                        * torch.randint(
+                            quiz_machine.problem.nb_colors,
+                            input.size(),
+                            device=input.device,
+                        )
+                    )
+                else:
+                    for it in range(torch.randint(5, (1,)).item()):
+                        logits = model(
+                            mygpt.BracketedSequence(
+                                torch.cat(
+                                    [input[:, :, None], mask_generate[:, :, None]], dim=2
+                                )
+                            )
+                        ).x
+                        dist = torch.distributions.categorical.Categorical(
+                            logits=logits
+                        )
+                        input = (
+                            1 - mask_generate
+                        ) * input + mask_generate * dist.sample()
+
+                output = model(
+                    mygpt.BracketedSequence(
+                        torch.cat([input[:, :, None], mask_generate[:, :, None]], dim=2)
+                    )
+                ).x
                 loss = F.cross_entropy(output.transpose(1, 2), targets)
                 acc_test_loss += loss.item() * input.size(0)
                 nb_test_samples += input.size(0)
@@ -992,11 +1033,43 @@ def test_ae(local_device=main_device):
 
                 input = (1 - mask_generate) * input  # PARANOIAAAAAAAAAA
 
-                result = (1 - mask_generate) * input
+                if pure_noise:
+                    mask_diffusion_noise = torch.rand(
+                        mask_generate.size(), device=mask_generate.device
+                    ) <= torch.rand(
+                        (mask_generate.size(0), 1), device=mask_generate.device
+                    )
+
+                    mask_diffusion_noise = mask_diffusion_noise.long()
+
+                    input = (
+                        input
+                        + mask_generate
+                        * mask_diffusion_noise
+                        * torch.randint(
+                            quiz_machine.problem.nb_colors,
+                            input.size(),
+                            device=input.device,
+                        )
+                    )
+                else:
+                    for it in range(torch.randint(5, (1,)).item()):
+                        logits = model(
+                            mygpt.BracketedSequence(
+                                torch.cat(
+                                    [input[:, :, None], mask_generate[:, :, None]],
+                                    dim=2,
+                                )
+                            )
+                        ).x
+                        dist = torch.distributions.categorical.Categorical(
+                            logits=logits
+                        )
+                        input = (
+                            1 - mask_generate
+                        ) * input + mask_generate * dist.sample()
 
-                # + mask_generate * torch.randint(
-                # quiz_machine.problem.nb_colors, input.size(), device=input.device
-                # )
+                result = input
 
                 not_converged = torch.full(
                     (result.size(0),), True, device=result.device
@@ -1004,7 +1077,17 @@ def test_ae(local_device=main_device):
 
                 for it in range(100):
                     pred_result = result.clone()
-                    logits = model(mygpt.BracketedSequence(result[not_converged])).x
+                    logits = model(
+                        mygpt.BracketedSequence(
+                            torch.cat(
+                                [
+                                    result[not_converged, :, None],
+                                    mask_generate[not_converged, :, None],
+                                ],
+                                dim=2,
+                            )
+                        )
+                    ).x
                     dist = torch.distributions.categorical.Categorical(logits=logits)
                     update = (1 - mask_generate[not_converged]) * input[
                         not_converged