Update.
author    François Fleuret <francois@fleuret.org>
Sun, 8 Sep 2024 07:44:33 +0000 (09:44 +0200)
committer François Fleuret <francois@fleuret.org>
Sun, 8 Sep 2024 07:44:33 +0000 (09:44 +0200)
attae.py [new file with mode: 0755]

diff --git a/attae.py b/attae.py
new file mode 100755 (executable)
index 0000000..3a9f105
--- /dev/null
+++ b/attae.py
@@ -0,0 +1,170 @@
+#!/usr/bin/env python
+
+import math
+
+import torch
+
+from torch import nn
+from torch.nn import functional as F
+from torch.nn.attention.flex_attention import flex_attention
+
+######################################################################
+
+
+class VaswaniPositionalEncoding(nn.Module):
+    def __init__(self, len_max):
+        super().__init__()
+        self.len_max = len_max
+
+    # [Vaswani et al 2017] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
+
+    def forward(self, x):
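+        # x is (N, T, D); t indexes the position, j the channel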
+        t = torch.arange(x.size(1), dtype=x.dtype, device=x.device)[:, None]
+        j = torch.arange(x.size(2), dtype=x.dtype, device=x.device)[None, :]
+        k = j % 2
+
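+        # k is 0 for even channels and 1 for odd ones; shifting the phase by
+        # pi/2 turns the sin into a cos, so the single expression below covers
+        # both the even and the odd channels of the encoding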
+        pe = torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi / 2 * k)
+
+        y = x + pe
+
+        return y
+
+
+######################################################################
+
+
+class WithResidual(nn.Module):
+    def __init__(self, *f):
+        super().__init__()
+        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
+
+    def forward(self, x):
+        return x + self.f(x)
+
+
+######################################################################
+
+
+class MHAttention(nn.Module):
+    def __init__(
+        self,
+        dim_in,
+        dim_qk,
+        dim_v,
+        nb_heads=1,
+        attention_dropout=0.0,
+    ):
+        super().__init__()
+
+        def randw(*d):
+            return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
+
+        self.attention_dropout = attention_dropout
+        self.record_attention = False
+
+        self.w_q = randw(nb_heads, dim_qk, dim_in)
+        self.w_k = randw(nb_heads, dim_qk, dim_in)
+        self.w_v = randw(nb_heads, dim_v, dim_in)
+        self.w_o = randw(nb_heads, dim_v, dim_in)
+
+    def forward(self, x_q, x_kv=None):
+        if x_kv is None:
+            x_kv = x_q
+
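+        # per-head projections: (N, T, dim_in) -> (N, nb_heads, T, dim_qk or dim_v)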
+        q = torch.einsum("ntc,hdc->nhtd", x_q, self.w_q)
+        k = torch.einsum("ntc,hdc->nhtd", x_kv, self.w_k)
+        v = torch.einsum("ntc,hdc->nhtd", x_kv, self.w_v)
+
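+        # scaled dot-product attention over all heads; note that
+        # attention_dropout is stored in __init__ but not applied here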
+        y = flex_attention(q, k, v)
+
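+        # per-head output projections, summed over the heads:
+        # (N, nb_heads, T, dim_v) -> (N, T, dim_in)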
+        y = torch.einsum("nhtd,hdc->ntc", y, self.w_o)
+
+        return y
+
+
+######################################################################
+
+
+class AttentionAE(nn.Module):
+    def __init__(
+        self,
+        vocabulary_size,
+        dim_model,
+        dim_keys,
+        dim_hidden,
+        nb_heads,
+        nb_blocks,
+        dropout=0.0,
+        len_max=1024,
+    ):
+        super().__init__()
+
+        assert dim_model % nb_heads == 0
+
+        self.embedding = nn.Sequential(
+            nn.Embedding(2 * vocabulary_size, dim_model),
+            nn.Dropout(dropout),
+        )
+
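+        # the value passed here is the base of the sinusoid wavelengths,
+        # hard-coded to 1e5 rather than taken from len_max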
+        self.positional_encoding = VaswaniPositionalEncoding(len_max=1e5)
+
+        trunk_blocks = []
+
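+        # each block is a pre-norm residual attention layer followed by a
+        # pre-norm residual two-layer MLP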
+        for b in range(nb_blocks):
+            trunk_blocks += [
+                WithResidual(
+                    nn.LayerNorm((dim_model,)),
+                    MHAttention(
+                        dim_in=dim_model,
+                        dim_qk=dim_keys,
+                        dim_v=dim_model // nb_heads,
+                        nb_heads=nb_heads,
+                        attention_dropout=dropout,
+                    ),
+                ),
+                WithResidual(
+                    nn.LayerNorm((dim_model,)),
+                    nn.Linear(in_features=dim_model, out_features=dim_hidden),
+                    nn.ReLU(),
+                    nn.Linear(in_features=dim_hidden, out_features=dim_model),
+                    nn.Dropout(dropout),
+                ),
+            ]
+
+        self.trunk = nn.Sequential(*trunk_blocks)
+
+        self.readout = nn.Linear(in_features=dim_model, out_features=vocabulary_size)
+
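+        # initialization: small normal for the embeddings, weight 1 / bias 0
+        # for the layer-norm affine transforms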
+        with torch.no_grad():
+            for m in self.modules():
+                if isinstance(m, nn.Embedding):
+                    m.weight.normal_(mean=0, std=2e-2)
+                elif isinstance(m, nn.LayerNorm):
+                    m.bias.zero_()
+                    m.weight.fill_(1.0)
+
+    def forward(self, x, mask=None):
+        x = self.embedding(x)
+        x = self.positional_encoding(x)
+        x = self.trunk(x)
+        x = self.readout(x)
+        return x
+
+
+######################################################################
+
+
+if __name__ == "__main__":
+    model = AttentionAE(
+        vocabulary_size=100,
+        dim_model=16,
+        dim_keys=64,
+        dim_hidden=32,
+        nb_heads=4,
+        nb_blocks=4,
+        dropout=0.1,
+    )
+
+    x = torch.randint(100, (10, 50))
+
+    y = model(x)
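+
+    # y holds one vector of logits per input token, of shape
+    # (10, 50, vocabulary_size) here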