Update.
author    François Fleuret <francois@fleuret.org>
          Fri, 13 Sep 2024 09:27:03 +0000 (11:27 +0200)
committer François Fleuret <francois@fleuret.org>
          Fri, 13 Sep 2024 09:27:03 +0000 (11:27 +0200)
attae.py
main.py

diff --git a/attae.py b/attae.py
index bc90ed0..e201f60 100755
--- a/attae.py
+++ b/attae.py
@@ -16,17 +16,12 @@ class VaswaniPositionalEncoding(nn.Module):
         super().__init__()
         self.len_max = len_max
 
-    # [Vaswani et al 2018] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
-
     def forward(self, x):
         t = torch.arange(x.size(1), dtype=x.dtype, device=x.device)[:, None]
         j = torch.arange(x.size(2), dtype=x.dtype, device=x.device)[None, :]
-        k = j % 2
-
+        k = j % 2  # torch's % operator works on float tensors too
         pe = torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi / 2 * k)
-
         y = x + pe
-
         return y
 
 
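For reference, the single-sine form kept above is equivalent to the sin/cos pairs of Vaswani et al. (2017): the pi/2 phase shift on odd dimensions turns the sine into a cosine. A minimal standalone check (sizes are illustrative, not from the patch):

    import math
    import torch

    T, D, len_max = 8, 16, 1e5
    t = torch.arange(T, dtype=torch.float32)[:, None]
    j = torch.arange(D, dtype=torch.float32)[None, :]
    k = j % 2  # 0 on even dims, 1 on odd dims
    pe = torch.sin(t / (len_max ** ((j - k) / D)) + math.pi / 2 * k)

    # Classic formulation: PE_{t,2i} = sin(t/L^{2i/D}), PE_{t,2i+1} = cos(t/L^{2i/D})
    i = torch.arange(D // 2, dtype=torch.float32)
    ref = torch.empty(T, D)
    ref[:, 0::2] = torch.sin(t / (len_max ** (2 * i / D)))
    ref[:, 1::2] = torch.cos(t / (len_max ** (2 * i / D)))
    assert torch.allclose(pe, ref, atol=1e-5)
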
@@ -45,23 +40,22 @@ class WithResidual(nn.Module):
 ######################################################################
 
 
-def vanilla_attention(q, k, v):
+def attention(q, k, v):
     a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3))
     a = a.softmax(dim=3)
     y = torch.einsum("nhts,nhsd->nhtd", a, v)
-    y = torch.einsum("nhtd,hdc->ntc", y, self.w_o)
     return y
 
 
-vanilla_attention = torch.compile(vanilla_attention)
+attention = torch.compile(attention)
 
-# y = flex_attention(q, k, v, score_mod=noop)
+######################################################################
 
 
 class MHAttention(nn.Module):
     def __init__(
         self,
-        dim_in,
+        dim_model,
         dim_qk,
         dim_v,
         nb_heads=1,
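The renamed attention() is plain scaled dot-product attention; without masking or dropout it should agree with PyTorch's fused kernel. A quick self-contained check (shapes are made up):

    import math
    import torch
    import torch.nn.functional as F

    q, k, v = (torch.randn(2, 4, 10, 8) for _ in range(3))  # (batch, heads, time, dim)
    a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3))
    y = torch.einsum("nhts,nhsd->nhtd", a.softmax(dim=3), v)
    assert torch.allclose(y, F.scaled_dot_product_attention(q, k, v), atol=1e-5)
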
@@ -73,12 +67,10 @@ class MHAttention(nn.Module):
             return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
 
         self.attention_dropout = attention_dropout
-        self.record_attention = False
-
-        self.w_q = randw(nb_heads, dim_qk, dim_in)
-        self.w_k = randw(nb_heads, dim_qk, dim_in)
-        self.w_v = randw(nb_heads, dim_v, dim_in)
-        self.w_o = randw(nb_heads, dim_v, dim_in)
+        self.w_q = randw(nb_heads, dim_qk, dim_model)
+        self.w_k = randw(nb_heads, dim_qk, dim_model)
+        self.w_v = randw(nb_heads, dim_v, dim_model)
+        self.w_o = randw(nb_heads, dim_v, dim_model)
 
     def forward(self, x_q, x_kv=None):
         if x_kv is None:
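randw() divides by the square root of the last dimension, i.e. the fan-in of the einsum projections below, which keeps activations at roughly unit scale. A minimal sketch of that property (sizes are illustrative):

    import math
    import torch

    h, d, c = 4, 16, 64
    w = torch.randn(h, d, c) / math.sqrt(c)   # same scaling as randw
    x = torch.randn(100, 10, c)               # unit-variance input
    q = torch.einsum("ntc,hdc->nhtd", x, w)
    print(x.std().item(), q.std().item())     # both close to 1.0
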
@@ -87,13 +79,7 @@ class MHAttention(nn.Module):
         q = torch.einsum("ntc,hdc->nhtd", x_q, self.w_q)
         k = torch.einsum("nsc,hdc->nhsd", x_kv, self.w_k)
         v = torch.einsum("nsc,hdc->nhsd", x_kv, self.w_v)
-
-        def noop(score, b, h, q_idx, kv_idx):
-            return score
-
-        y = vanilla_attention(q, k, v)
-        # y = flex_attention(q, k, v, score_mod=noop)
-
+        y = attention(q, k, v)
         y = torch.einsum("nhtd,hdc->ntc", y, self.w_o)
 
         return y
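Hypothetical usage of the renamed module, assuming attae.py is importable and attention_dropout keeps its original default; with x_kv omitted, the module performs plain self-attention:

    import torch
    from attae import MHAttention

    mha = MHAttention(dim_model=64, dim_qk=16, dim_v=16, nb_heads=4)
    x = torch.randn(2, 10, 64)  # (batch, time, dim_model)
    y = mha(x)                  # x_kv defaults to x_q
    print(y.shape)              # torch.Size([2, 10, 64])
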
@@ -112,7 +98,7 @@ class AttentionAE(nn.Module):
         nb_heads,
         nb_blocks,
         dropout=0.0,
-        len_max=1024,
+        len_max=1e5,
     ):
         super().__init__()
 
@@ -123,7 +109,7 @@ class AttentionAE(nn.Module):
             nn.Dropout(dropout),
         )
 
-        self.positional_encoding = VaswaniPositionalEncoding(len_max=1e5)
+        self.positional_encoding = VaswaniPositionalEncoding(len_max)
 
         trunk_blocks = []
 
@@ -132,7 +118,7 @@ class AttentionAE(nn.Module):
                 WithResidual(
                     nn.LayerNorm((dim_model,)),
                     MHAttention(
-                        dim_in=dim_model,
+                        dim_model=dim_model,
                         dim_qk=dim_keys,
                         dim_v=dim_model // nb_heads,
                         nb_heads=nb_heads,
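With this change the caller-supplied len_max is actually honored; previously the constructor ignored its argument and hard-coded 1e5 (the new default preserves the old behavior). A quick way to see the effect, assuming attae.py is on the path:

    import torch
    from attae import VaswaniPositionalEncoding

    x = torch.zeros(1, 16, 8)
    pe_short = VaswaniPositionalEncoding(len_max=1024)
    pe_long = VaswaniPositionalEncoding(len_max=1e5)
    print(torch.allclose(pe_short(x), pe_long(x)))  # False: wavelengths differ
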
diff --git a/main.py b/main.py
index 0fea318..e090f86 100755
--- a/main.py
+++ b/main.py
@@ -1215,11 +1215,7 @@ def generate_ae_c_quizzes(models, nb, local_device=main_device):
         c_quizzes = torch.cat(record_c_quizzes, dim=0)
         agreements = torch.cat(record_agreements, dim=0)
 
-    return c_quizzes, agreements
-
-
-def thread_generate_ae_c_quizzes(models, nb, record, local_device=main_device):
-    record.append(generate_ae_c_quizzes(models, nb, local_device))
+    return c_quizzes.to("cpu"), agreements.to("cpu")
 
 
 ######################################################################
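Moving the .to("cpu") into generate_ae_c_quizzes() makes the dedicated thread wrapper unnecessary: each worker now hands back CPU tensors, so its device memory can be released as soon as it finishes rather than lingering in the shared record list. The pattern, sketched with hypothetical names:

    import threading
    import torch

    def worker(records, device):
        x = torch.randn(4, 4, device=device)
        records.append((x.to("cpu"),))  # move off the device inside the worker

    records = []
    t = threading.Thread(target=worker, args=(records, "cpu"))
    t.start()
    t.join()
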
@@ -1381,8 +1377,7 @@ def multithread_execution(fun, arguments):
 
     else:
         return [
-            torch.cat([x[k].to("cpu") for x in records], dim=0)
-            for k in range(len(records[0]))
+            torch.cat([x[k] for x in records], dim=0) for k in range(len(records[0]))
         ]
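
Since the records now already hold CPU tensors, the merge step no longer needs a per-element device transfer. A toy run of the simplified concatenation (data made up):

    import torch

    records = [(torch.ones(2, 3), torch.zeros(2)), (torch.ones(1, 3), torch.zeros(1))]
    merged = [torch.cat([x[k] for x in records], dim=0) for k in range(len(records[0]))]
    print([m.shape for m in merged])  # [torch.Size([3, 3]), torch.Size([3])]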