Update.
author François Fleuret <francois@fleuret.org>
Sun, 8 Sep 2024 07:44:29 +0000 (09:44 +0200)
committer François Fleuret <francois@fleuret.org>
Sun, 8 Sep 2024 07:44:29 +0000 (09:44 +0200)
grids.py
main.py
mygpt.py

diff --git a/grids.py b/grids.py
index 9e80f62..73e722e 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -709,20 +709,22 @@ class Grids(problem.Problem):
         nb_rec = 3
         c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
         for X, f_X in [(A, f_A), (B, f_B)]:
-            r = self.rec_coo(nb_rec, prevent_overlap=True)
+            while True:
+                r = self.rec_coo(nb_rec, prevent_overlap=True)
+                if min([x[2] for x in r]) > self.height // 2 + 1:
+                    break
             for n in range(nb_rec):
                 i1, j1, i2, j2 = r[n]
                 X[i1:i2, j1:j2] = c[n]
                 f_X[i1:i2, j1:j2] = c[n]
-            X[: self.height // 2] = c[-1]
+            X[: self.height // 2] = 0
             f_X[: self.height // 2] = f_X.flip([0])[: self.height // 2]
             if a == 1:
+                X[...] = X.flip((0,))
+                f_X[...] = f_X.flip((0,))
+            if b == 1:
                 X[...] = X.clone().t()
                 f_X[...] = f_X.clone().t()
-            if b == 1:
-                Z = X.clone()
-                X[...] = f_X
-                f_X[...] = Z
 
     # @torch.compile
     def task_translate(self, A, f_A, B, f_B):
diff --git a/main.py b/main.py
index 264b5c7..a4030ff 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -57,6 +57,10 @@ parser.add_argument("--nb_train_samples", type=int, default=25000)
 
 parser.add_argument("--nb_test_samples", type=int, default=1000)
 
+parser.add_argument("--nb_train_alien_samples", type=int, default=0)
+
+parser.add_argument("--nb_test_alien_samples", type=int, default=0)
+
 parser.add_argument("--nb_c_quizzes", type=int, default=2500)
 
 parser.add_argument("--nb_new_c_quizzes_for_train", type=int, default=None)
@@ -304,7 +308,6 @@ alien_quiz_machine = quiz_machine.QuizMachine(
     logger=log_string,
     device=main_device,
 )
-
 # ------------------------------------------------------
 
 ######################################################################
@@ -366,121 +369,13 @@ def optimizer_to(optim, device):
                         subparam._grad.data = subparam._grad.data.to(device)
 
 
-######################################################################
-
-from mygpt import (
-    WithResidual,
-    CacheWrapper,
-    VaswaniPositionalEncoding,
-    TrainablePositionalEncoding,
-    QKVAttention,
-    BracketedSequence,
-)
-
-
-class Thinker(nn.Module):
-    def __init__(
-        self,
-        vocabulary_size,
-        dim_model,
-        dim_keys,
-        dim_hidden,
-        nb_heads,
-        nb_blocks,
-        f_len,
-        dropout=0.0,
-        len_max=1e5,
-    ):
-        super().__init__()
-
-        assert dim_model % nb_heads == 0
-
-        self.embedding = nn.Sequential(
-            CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
-            VaswaniPositionalEncoding(len_max),
-        )
-
-        def trunk(depth):
-            trunk_blocks = []
-
-            for b in range(nb_blocks):
-                trunk_blocks += [
-                    WithResidual(
-                        CacheWrapper(
-                            nn.LayerNorm((dim_model,)),
-                        ),
-                        QKVAttention(
-                            dim_in=dim_model,
-                            dim_qk=dim_keys,
-                            dim_v=dim_model // nb_heads,
-                            nb_heads=nb_heads,
-                            attention_dropout=dropout,
-                        ),
-                    ),
-                    WithResidual(
-                        CacheWrapper(
-                            nn.LayerNorm((dim_model,)),
-                            nn.Linear(in_features=dim_model, out_features=dim_hidden),
-                            nn.ReLU(),
-                            nn.Linear(in_features=dim_hidden, out_features=dim_model),
-                            nn.Dropout(dropout),
-                        ),
-                    ),
-                ]
-
-            return nn.Sequential(*trunk_blocks)
-
-        self.bottom_trunk = trunk(nb_blocks // 2)
-
-        self.top_trunk = trunk(nb_blocks // 2)
-
-        self.readout = CacheWrapper(
-            nn.Linear(in_features=dim_model, out_features=vocabulary_size)
-        )
-
-        self.fun_embedding = nn.Parameter(torch.randn(1, f_len, dim_model))
-
-        with torch.no_grad():
-            for m in self.modules():
-                if isinstance(m, nn.Embedding):
-                    m.weight.normal_(mean=0, std=2e-2)
-                elif isinstance(m, nn.LayerNorm):
-                    m.bias.zero_()
-                    m.weight.fill_(1.0)
-
-    def forward(self, bs):
-        for m in self.modules():
-            m.loss = 0
-
-        L = bs.x.size(1) // 3
-
-        bs = self.embedding(bs)
-        A_fA = BracketedSequence(bs.x[:, : 2 * L])
-        B = BracketedSequence(bs.x[:, -L:])
-
-        bs = BracketedSequence(
-            torch.cat([A_fA.x, self.fun_embedding.expand(bs.x.size(0), -1, -1)], dim=1)
-        )
-        bs = self.bottom_trunk(bs)
-        bs = BracketedSequence(torch.cat([bs.x[:, -f_len:, :], B.x], dim=1))
-        bs = self.top_trunk(bs)
-        bs = BracketedSequence(bs.x[:, f_len:, :])
-        bs = self.readout(bs)
-
-        for m in self.modules():
-            if m is not self:
-                self.loss += m.loss
-
-        return bs
-
-
 ######################################################################
 
 
 from mygpt import (
     WithResidual,
     CacheWrapper,
-    VaswaniPositionalEncoding,
+    CachedVaswaniPositionalEncoding,
     QKVAttention,
     BracketedSequence,
 )
@@ -548,7 +443,7 @@ class MyAttentionAE(nn.Module):
         )
 
         # self.positional_encoding = TrainablePositionalEncoding(dim_model, len_max)
-        self.positional_encoding = VaswaniPositionalEncoding(len_max=1e5)
+        self.positional_encoding = CachedVaswaniPositionalEncoding(len_max=1e5)
 
         trunk_blocks = []
 
@@ -582,137 +477,6 @@ class MyAttentionAE(nn.Module):
         return bs
 
 
-######################################################################
-
-# f = phi(A, f(A)) + phi(B, f(B))
-# \hat{f(A)} = psi(A, f)
-# \hat{A} = psi_inv(f(A), f)
-# \hat{f(B)} = psi(B, f)
-# \hat{B} = psi_inv(f(B), f)
-
-
-def attention_layer(dim_model, dim_keys, nb_heads, dropout):
-    return WithResidual(
-        CacheWrapper(
-            nn.LayerNorm((dim_model,)),
-        ),
-        QKVAttention(
-            dim_in=dim_model,
-            dim_qk=dim_keys,
-            dim_v=dim_model // nb_heads,
-            nb_heads=nb_heads,
-            attention_dropout=dropout,
-        ),
-    )
-
-
-class FunctionalAE(nn.Module):
-    def __init__(
-        self,
-        vocabulary_size,
-        dim_model,
-        dim_keys,
-        dim_hidden,
-        nb_heads,
-        nb_blocks,
-        dropout=0.0,
-        len_max=1024,
-    ):
-        super().__init__()
-
-        assert dim_model % nb_heads == 0
-
-        self.embedding = CacheWrapper(
-            nn.Sequential(
-                MultiEmbedding((vocabulary_size, 2), dim_model), nn.Dropout(dropout)
-            ),
-        )
-
-        # self.positional_encoding = TrainablePositionalEncoding(dim_model, len_max)
-        self.positional_encoding = VaswaniPositionalEncoding(len_max=1e5)
-
-        def trunk(nb, bottom=True):
-            trunk_blocks = [VaswaniPositionalEncoding(len_max=1e5)]
-
-            la = [
-                QKVAttention(
-                    dim_in=dim_model,
-                    dim_qk=dim_keys,
-                    dim_v=dim_model // nb_heads,
-                    nb_heads=nb_heads,
-                    attention_dropout=dropout,
-                ),
-            ]
-
-            # if not bottom:
-            # trunk_blocks += la
-
-            for b in range(nb):
-                trunk_blocks += [
-                    attention_block(dim_model, dim_keys, nb_heads, dropout),
-                    ffw_block(dim_model, dim_hidden, nb_heads, dropout),
-                ]
-
-            # if bottom:
-            # trunk_blocks += la
-
-            return nn.Sequential(*trunk_blocks)
-
-        self.phi = trunk(nb_blocks // 2, bottom=True)
-        nb_f_tokens = 200
-        self.f_tokens = nn.Parameter(
-            torch.randn(1, nb_f_tokens, dim_model) / math.sqrt(nb_f_tokens)
-        )
-        self.psi = trunk(nb_blocks // 2, bottom=False)
-        self.psi_inv = trunk(nb_blocks // 2, bottom=False)
-        self.internal_pe = VaswaniPositionalEncoding(len_max=1e5)
-
-        self.readout = CacheWrapper(
-            nn.Linear(in_features=dim_model, out_features=vocabulary_size)
-        )
-
-        with torch.no_grad():
-            for m in self.modules():
-                if isinstance(m, nn.Embedding):
-                    m.weight.normal_(mean=0, std=2e-2)
-                elif isinstance(m, nn.LayerNorm):
-                    m.bias.zero_()
-                    m.weight.fill_(1.0)
-
-    def forward(self, bs):
-        def cat(*x):
-            return BracketedSequence(torch.cat(x, dim=1))
-
-        if torch.is_tensor(bs):
-            return self.forward(BracketedSequence(bs)).x
-        bs = self.embedding(bs)
-        bs = self.positional_encoding(bs)
-
-        x_A, x_f_A, x_B, x_f_B = bs.x.chunk(4, dim=1)
-
-        K = self.f_tokens.size(1)
-        N, L = x_A.size()[:2]
-
-        ft = self.f_tokens.expand(N, -1, -1)
-
-        theta_A = self.phi(cat(ft, x_A, x_f_A)).x[:, :K, :]
-        theta_B = self.phi(cat(ft, x_B, x_f_B)).x[:, :K, :]
-
-        # if self.hook_theta is not None:
-        # self.hook_theta(theta_A, theta_B)
-
-        hat_f_A = self.psi(cat(x_A, theta_B)).x[:, :L]
-        hat_f_B = self.psi(cat(x_B, theta_A)).x[:, :L]
-
-        hat_A = self.psi_inv(cat(x_f_A, theta_B)).x[:, :L]
-        hat_B = self.psi_inv(cat(x_f_B, theta_A)).x[:, :L]
-
-        bs = cat(hat_A, hat_f_A, hat_B, hat_f_B)
-
-        bs = self.readout(bs)
-        return bs
-
-
 ######################################################################
 
 # quad_order, quad_generate, quad_noise, quad_loss
@@ -732,6 +496,8 @@ def ae_batches(
     data_structures,
     local_device,
     c_quizzes=None,
+    alien_quiz_machine=None,
+    nb_aliens=None,
     desc=None,
     batch_size=args.batch_size,
 ):
@@ -1149,24 +915,25 @@ def run_ae_test(
             f"{prefix}test_accuracy {n_epoch} model {model.id} nb_correct {nb_correct} / {nb_total} ({(nb_correct*100)/nb_total:.02f}%)"
         )
 
-        model.test_accuracy = nb_correct / nb_total
-
         # Save some images
 
-        for f, record in [("prediction", record_d), ("generation", record_nd)]:
-            result, predicted_parts, correct_parts = bag_to_tensors(record)
+        if n_epoch < 50:
+            for f, record in [("prediction", record_d), ("generation", record_nd)]:
+                result, predicted_parts, correct_parts = bag_to_tensors(record)
 
-            filename = f"{prefix}culture_{f}_{n_epoch:04d}_{model.id:02d}.png"
+                filename = f"{prefix}culture_{f}_{n_epoch:04d}_{model.id:02d}.png"
 
-            quiz_machine.problem.save_quizzes_as_image(
-                args.result_dir,
-                filename,
-                quizzes=result[:128],
-                predicted_parts=predicted_parts[:128],
-                correct_parts=correct_parts[:128],
-            )
+                quiz_machine.problem.save_quizzes_as_image(
+                    args.result_dir,
+                    filename,
+                    quizzes=result[:128],
+                    predicted_parts=predicted_parts[:128],
+                    correct_parts=correct_parts[:128],
+                )
 
-            log_string(f"wrote {filename}")
+                log_string(f"wrote {filename}")
+
+        return nb_correct / nb_total
 
 
 ######################################################################
@@ -1209,7 +976,19 @@ def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_devi
         f"train_loss {n_epoch} model {model.id} {acc_train_loss/nb_train_samples}"
     )
 
-    run_ae_test(model, quiz_machine, n_epoch, c_quizzes=None, local_device=local_device)
+    model.test_accuracy = run_ae_test(
+        model, quiz_machine, n_epoch, c_quizzes=None, local_device=local_device
+    )
+
+    if args.nb_test_alien_samples > 0:
+        run_ae_test(
+            model,
+            alien_quiz_machine,
+            n_epoch,
+            c_quizzes=None,
+            local_device=local_device,
+            prefix="alien",
+        )
 
 
 ######################################################################
@@ -1308,6 +1087,10 @@ def quiz_validation(models, c_quizzes, local_device):
 
 def generate_ae_c_quizzes(models, nb, local_device=main_device):
     # To be thread-safe we must make copies
+
+    def copy_for_inference(model):
+        return copy.deepcopy(model).to(local_device).eval()
+
     quad_order = ("A", "f_A", "B", "f_B")
 
     template = quiz_machine.problem.create_empty_quizzes(
@@ -1318,9 +1101,6 @@ def generate_ae_c_quizzes(models, nb, local_device=main_device):
         quizzes=template, quad_order=quad_order, quad_mask=(1, 1, 1, 1)
     )
 
-    def copy_for_inference(model):
-        return copy.deepcopy(model).to(local_device).eval()
-
     wanted_nb = nb
     nb_to_save = 256
     nb_c_quizzes_per_model = torch.zeros(len(models), device=local_device)
diff --git a/mygpt.py b/mygpt.py
index a744224..5b56264 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -110,7 +110,7 @@ class CacheWrapper(nn.Module):
 ##############################
 
 
-class WithResidual(nn.Module):
+class CachedWithResidual(nn.Module):
     def __init__(self, *f):
         super().__init__()
         self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
@@ -122,7 +122,7 @@ class WithResidual(nn.Module):
 ##############################
 
 
-class VaswaniPositionalEncoding(nn.Module):
+class CachedVaswaniPositionalEncoding(nn.Module):
     def __init__(self, len_max):
         super().__init__()
         self.len_max = len_max
@@ -358,13 +358,13 @@ class MyGPT(nn.Module):
             CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
         )
 
-        self.positional_encoding = VaswaniPositionalEncoding(len_max)
+        self.positional_encoding = CachedVaswaniPositionalEncoding(len_max)
 
         trunk_blocks = []
 
         for b in range(nb_blocks):
             trunk_blocks += [
-                WithResidual(
+                CachedWithResidual(
                     CacheWrapper(
                         nn.LayerNorm((dim_model,)),
                         NoiseInjector(identifier=("attention", b)),
@@ -378,7 +378,7 @@ class MyGPT(nn.Module):
                         attention_dropout=dropout,
                     ),
                 ),
-                WithResidual(
+                CachedWithResidual(
                     CacheWrapper(
                         nn.LayerNorm((dim_model,)),
                         NoiseInjector(identifier=("ffw", b)),