Make the attention function a constructor argument in attae.py, replace MaskedAttentionAE with a work-token FunctionalAttentionAE, expose the grid-rendering constants as Grids class attributes, and log the weakest models' accuracies.
author François Fleuret <francois@fleuret.org>
Wed, 18 Sep 2024 21:01:03 +0000 (23:01 +0200)
committer François Fleuret <francois@fleuret.org>
Wed, 18 Sep 2024 21:01:03 +0000 (23:01 +0200)
attae.py
grids.py
main.py

diff --git a/attae.py b/attae.py
index 06deed2..b4db3ab 100755
--- a/attae.py
+++ b/attae.py
@@ -10,7 +10,7 @@ import torch
 from torch import nn
 from torch.nn import functional as F
 
-# from torch.nn.attention.flex_attention import flex_attention
+# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
 
 ######################################################################
 
@@ -44,7 +44,7 @@ class WithResidual(nn.Module):
 ######################################################################
 
 
-def attention(q, k, v):
+def vanilla_attention(q, k, v):
     a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3))
     a = a.softmax(dim=3)
     y = torch.einsum("nhts,nhsd->nhtd", a, v)
@@ -61,6 +61,7 @@ class MHAttention(nn.Module):
         dim_qk,
         dim_v,
         nb_heads=1,
+        attention=vanilla_attention,
         attention_dropout=0.0,
     ):
         super().__init__()
@@ -68,6 +69,7 @@ class MHAttention(nn.Module):
         def randw(*d):
             return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
 
+        self.attention = attention
         self.attention_dropout = attention_dropout
         self.w_q = randw(nb_heads, dim_qk, dim_model)
         self.w_k = randw(nb_heads, dim_qk, dim_model)
@@ -81,7 +83,7 @@ class MHAttention(nn.Module):
         q = torch.einsum("ntc,hdc->nhtd", x_q, self.w_q)
         k = torch.einsum("nsc,hdc->nhsd", x_kv, self.w_k)
         v = torch.einsum("nsc,hdc->nhsd", x_kv, self.w_v)
-        y = attention(q, k, v)
+        y = self.attention(q, k, v)
         y = torch.einsum("nhtd,hdc->ntc", y, self.w_o)
 
         return y
@@ -99,6 +101,7 @@ class AttentionAE(nn.Module):
         dim_hidden,
         nb_heads,
         nb_blocks,
+        attention=vanilla_attention,
         dropout=0.0,
         len_max=1e5,
     ):
@@ -159,7 +162,7 @@ class AttentionAE(nn.Module):
 ######################################################################
 
 
-class MaskedAttentionAE(nn.Module):
+class FunctionalAttentionAE(AttentionAE):
     def __init__(
         self,
         vocabulary_size,
@@ -167,13 +170,32 @@ class MaskedAttentionAE(nn.Module):
         dim_keys,
         dim_hidden,
         nb_heads,
+        nb_work_tokens,
         nb_blocks,
         dropout=0.0,
         len_max=1e5,
     ):
-        super().__init__()
-        self.core = AttentionAE(
-            vocabulary_size * 2,
+        # def functional_mask(b, h, q_idx, kv_idx):
+        # return (
+        # (q_idx < nb_work_tokens)
+        # | (kv_idx < nb_work_tokens)
+        # | ((q_idx - nb_work_tokens) // 200 == (kv_idx - nb_work_tokens) // 200)
+        # )
+
+        # block_mask = create_block_mask(
+        # functional_mask,
+        # B=None,
+        # H=None,
+        # Q_LEN=400 + nb_work_tokens,
+        # KV_LEN=400 + nb_work_tokens,
+        # )
+
+        # def functional_attention(q, k, v):
+        # return flex_attention(q, k, v, block_mask=block_mask)
+
+        AttentionAE.__init__(
+            self,
+            vocabulary_size,
             dim_model,
             dim_keys,
             dim_hidden,
@@ -182,22 +204,24 @@ class MaskedAttentionAE(nn.Module):
             dropout=0.0,
             len_max=1e5,
         )
+        self.nb_work_tokens = nb_work_tokens
 
     def forward(self, x):
-        x = x[:, :, 0] * 2 + x[:, :, 1]
-        return self.core(x)
+        x = torch.cat([x.new_zeros(x.size(0), self.nb_work_tokens), x], dim=1)
+        return AttentionAE.forward(self, x)[:, self.nb_work_tokens :]
 
 
 ######################################################################
 
 
 if __name__ == "__main__":
-    model = AttentionAE(
+    model = FunctionalAttentionAE(
         vocabulary_size=100,
         dim_model=16,
         dim_keys=64,
         dim_hidden=32,
         nb_heads=4,
+        nb_work_tokens=10,
         nb_blocks=4,
         dropout=0.1,
     )
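
Below is a minimal sketch, not part of the patch, of what the new attention= hook enables: any callable with the same (q, k, v) signature and (batch, heads, tokens, dim) tensor layout as vanilla_attention can be passed to MHAttention or AttentionAE. The causal_attention function is a hypothetical example, not code from this repository.

import math
import torch

# Hypothetical drop-in for vanilla_attention: same signature and einsum
# shapes, with an additional causal mask applied before the softmax.
def causal_attention(q, k, v):
    a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3))
    t, s = a.size(2), a.size(3)
    future = torch.arange(t, device=a.device)[:, None] < torch.arange(s, device=a.device)[None, :]
    a = a.masked_fill(future, float("-inf")).softmax(dim=3)
    return torch.einsum("nhts,nhsd->nhtd", a, v)

# e.g. model = AttentionAE(..., attention=causal_attention, ...)
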
diff --git a/grids.py b/grids.py
index 6b2ea23..23a3d12 100755
--- a/grids.py
+++ b/grids.py
@@ -134,8 +134,16 @@ def grow_islands(nb, height, width, nb_seeds, nb_iterations):
 
 
 class Grids(problem.Problem):
+    # grid_gray=64
+    # thickness=1
+    # background_gray=255
+
+    grid_gray = 255
+    thickness = 0
+    background_gray = grid_gray
+
     named_colors = [
-        ("white", [255, 255, 255]),
+        ("white", [background_gray, background_gray, background_gray]),
         # ("white", [224, 224, 224]),
         ("red", [255, 0, 0]),
         ("green", [0, 192, 0]),
@@ -380,8 +388,9 @@ class Grids(problem.Problem):
         y = y.reshape(s[0], s[1], s[2] * scale, s[3] * scale)
 
         if grids:
-            y[:, :, :, torch.arange(0, y.size(3), scale)] = 64
-            y[:, :, torch.arange(0, y.size(2), scale), :] = 64
+            for t in range(self.thickness):
+                y[:, :, :, torch.arange(t, y.size(3), scale)] = self.grid_gray
+                y[:, :, torch.arange(t, y.size(2), scale), :] = self.grid_gray
 
         for n in range(m.size(0)):
             for i in range(m.size(1)):
@@ -463,11 +472,17 @@ class Grids(problem.Problem):
         )
 
         frame, white, gray, green, red = torch.tensor(
-            [[64, 64, 64], [255, 255, 255], [200, 200, 200], [0, 255, 0], [255, 0, 0]],
+            [
+                [self.grid_gray, self.grid_gray, self.grid_gray],
+                [255, 255, 255],
+                [200, 200, 200],
+                [0, 255, 0],
+                [255, 0, 0],
+            ],
             device=quizzes.device,
         )
 
-        thickness = 1 if grids else 0
+        thickness = self.thickness
 
         if delta:
             u = (A != f_A).long()
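
A sketch, not part of the patch, of how the new Grids class attributes can be used: grid_gray and thickness are read at draw time (self.grid_gray, self.thickness), so visible grid lines can be restored without editing the drawing code, as in the commented-out setting at the top of the class. background_gray, however, is baked into named_colors when the class body runs, so it has to be changed in the class definition itself.

import grids

# Restore dark, one-pixel grid lines (values taken from the commented-out
# alternative in the class body).
grids.Grids.grid_gray = 64
grids.Grids.thickness = 1
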
diff --git a/main.py b/main.py
index c7131c3..16edcdc 100755
--- a/main.py
+++ b/main.py
@@ -512,9 +512,9 @@ def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True):
     # Half of the samples train the prediction, and we inject noise in
     # all, and hints in half
     b_p = batch_for_prediction_imt(q_p)
-    i = torch.rand(b_p.size(0)) < 0.5
     b_p = add_noise_imt(b_p)
-    b_p[i] = add_hints_imt(b_p[i])
+    half = torch.rand(b_p.size(0)) < 0.5
+    b_p[half] = add_hints_imt(b_p[half])
 
     # The other half are denoising examples for the generation
     b_g = batch_for_generation_imt(q_g)
@@ -610,7 +610,7 @@ def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device):
 
     # Compute the test accuracy
 
-    nb_correct, nb_total = correct.sum(), quizzes.size(0)
+    nb_correct, nb_total = correct.sum().item(), quizzes.size(0)
     model.test_accuracy = nb_correct / nb_total
 
     log_string(
@@ -634,6 +634,17 @@ import attae
 models = []
 
 for i in range(args.nb_models):
+    # model = attae.FunctionalAttentionAE(
+    # vocabulary_size=vocabulary_size * 2,
+    # dim_model=args.dim_model,
+    # dim_keys=args.dim_keys,
+    # dim_hidden=args.dim_hidden,
+    # nb_heads=args.nb_heads,
+    # nb_blocks=args.nb_blocks,
+    # nb_work_tokens=10,
+    # dropout=args.dropout,
+    # )
+
     model = attae.AttentionAE(
         vocabulary_size=vocabulary_size * 2,
         dim_model=args.dim_model,
@@ -975,6 +986,10 @@ for n_epoch in range(current_epoch, args.nb_epochs):
     ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
     weakest_models = ranked_models[: len(gpus)]
 
+    log_string(
+        f"weakest_accuracies {[model.test_accuracy for model in weakest_models]}"
+    )
+
     multithread_execution(
         one_complete_epoch,
         [(model, n_epoch, c_quizzes, gpu) for model, gpu in zip(weakest_models, gpus)],
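
A small illustration, assumed rather than taken from the patch, of why the .item() change in one_complete_epoch matters: correct.sum() is a 0-dim tensor, so without .item() the test accuracy is itself a tensor and shows up as tensor(...) in the log line.

import torch

correct = torch.tensor([1, 0, 1, 1])

print(correct.sum() / 4)         # tensor(0.7500), a 0-dim tensor leaks into the logs
print(correct.sum().item() / 4)  # 0.75, a plain Python float
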