Update: pass c_quiz_multiplier to data_input in run_tests, give QKVAttention an optional key/value sequence (bs_kv) for cross-attention, split the positional encoding out of the embedding, add ShiftByOne and a BlockSummarizer stub, and remove the autoencoder encode/decode/partial_forward path.
author     François Fleuret <francois@fleuret.org>
           Mon, 19 Aug 2024 15:19:55 +0000 (17:19 +0200)
committer  François Fleuret <francois@fleuret.org>
           Mon, 19 Aug 2024 15:19:55 +0000 (17:19 +0200)
main.py
mygpt.py

diff --git a/main.py b/main.py
index 046514d..d98031e 100755
--- a/main.py
+++ b/main.py
@@ -342,7 +342,7 @@ def run_tests(model, quiz_machine, local_device=main_device):
         nb_samples_accumulated = 0
 
         full_input, full_mask_loss = quiz_machine.data_input(
-            args.nb_test_samples, model.test_c_quiz_bags
+            args.nb_test_samples, model.test_c_quiz_bags, args.c_quiz_multiplier
         )
         src = zip(
             full_input.split(args.batch_size), full_mask_loss.split(args.batch_size)
@@ -370,10 +370,14 @@ def run_tests(model, quiz_machine, local_device=main_device):
 
         log_string(f"test_perplexity {n_epoch} model {model.id} {test_perplexity}")
 
+        input, _ = quiz_machine.data_input(
+            2000, model.test_c_quiz_bags, args.c_quiz_multiplier
+        )
+
         model.test_accuracy = quiz_machine.produce_results(
             n_epoch=n_epoch,
             model=model,
-            input=full_input[:2000],
+            input=input,
             result_dir=args.result_dir,
         )
 
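The test loop now forwards args.c_quiz_multiplier to quiz_machine.data_input, and the accuracy reported by produce_results is computed on a fresh draw of 2000 quizzes instead of the first 2000 samples of the perplexity batch. As a purely illustrative sketch of what such a multiplier could do when a batch is assembled (the helper name and the repeat-based strategy are assumptions, not taken from this commit):

    import torch

    def replicate_c_quizzes(c_quizzes, multiplier):
        # Hypothetical helper: repeat each culture quiz `multiplier` times so
        # that generated quizzes carry more weight in the sampled batches.
        if multiplier <= 1:
            return c_quizzes
        return c_quizzes.repeat_interleave(multiplier, dim=0)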
diff --git a/mygpt.py b/mygpt.py
index 041d28c..f716fe5 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -201,42 +201,48 @@ class QKVAttention(nn.Module):
         self.w_v = randw(nb_heads, dim_v, dim_in)
         self.w_o = randw(dim_v * nb_heads, dim_in)
 
-    def forward(self, bs_q):
+    def forward(self, bs_q, bs_kv=None):
+        if bs_kv is None:
+            bs_kv = bs_q
+
         x_q = bs_q.x
+        x_kv = bs_kv.x
 
-        if bs_q.first == 0:
-            self.cache_k = x_q.new_zeros(
-                x_q.size(0), self.w_k.size(0), x_q.size(1), self.w_k.size(1)
+        if bs_kv.first == 0:
+            self.cache_k = x_kv.new_zeros(
+                x_kv.size(0), self.w_k.size(0), x_kv.size(1), self.w_k.size(1)
             )
-            self.cache_v = x_q.new_zeros(
-                x_q.size(0), self.w_v.size(0), x_q.size(1), self.w_v.size(1)
+            self.cache_v = x_kv.new_zeros(
+                x_kv.size(0), self.w_v.size(0), x_kv.size(1), self.w_v.size(1)
             )
+
+        if bs_q.first == 0:
             self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
 
         q = torch.einsum(
             "ntc,hdc->nhtd", x_q[:, bs_q.first : bs_q.first + bs_q.nb], self.w_q
         )
 
-        self.cache_k[:, :, bs_q.first : bs_q.first + bs_q.nb] = torch.einsum(
-            "ntc,hdc->nhtd", x_q[:, bs_q.first : bs_q.first + bs_q.nb], self.w_k
+        self.cache_k[:, :, bs_kv.first : bs_kv.first + bs_kv.nb] = torch.einsum(
+            "ntc,hdc->nhtd", x_kv[:, bs_kv.first : bs_kv.first + bs_kv.nb], self.w_k
         )
-        self.cache_v[:, :, bs_q.first : bs_q.first + bs_q.nb] = torch.einsum(
-            "ntc,hdc->nhtd", x_q[:, bs_q.first : bs_q.first + bs_q.nb], self.w_v
+        self.cache_v[:, :, bs_kv.first : bs_kv.first + bs_kv.nb] = torch.einsum(
+            "ntc,hdc->nhtd", x_kv[:, bs_kv.first : bs_kv.first + bs_kv.nb], self.w_v
         )
 
         a = torch.einsum(
-            "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs_q.first + bs_q.nb]
+            "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs_kv.first + bs_kv.nb]
         ) / math.sqrt(self.w_q.size(1))
 
         if self.compute_attzero is not None:
             if bs_q.first == 0:
                 self.cache_attzero = self.compute_attzero(
                     torch.arange(x_q.size(1), device=q.device)[:, None],
-                    torch.arange(x_q.size(1), device=q.device)[None, :],
+                    torch.arange(x_kv.size(1), device=q.device)[None, :],
                 )[None, None, :, :]
             a = a.masked_fill(
                 self.cache_attzero[
-                    :, :, bs_q.first : bs_q.first + bs_q.nb, : bs_q.first + bs_q.nb
+                    :, :, bs_q.first : bs_q.first + bs_q.nb, : bs_kv.first + bs_kv.nb
                 ],
                 float("-inf"),
             )
@@ -249,7 +255,7 @@ class QKVAttention(nn.Module):
         a = F.dropout(a, self.attention_dropout, self.training)
 
         y = torch.einsum(
-            "nhts,nhsd->nthd", a, self.cache_v[:, :, : bs_q.first + bs_q.nb]
+            "nhts,nhsd->nthd", a, self.cache_v[:, :, : bs_kv.first + bs_kv.nb]
         ).flatten(2)
 
         self.cache_y[:, bs_q.first : bs_q.first + bs_q.nb] = y @ self.w_o
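With the new optional bs_kv argument, QKVAttention keeps its key/value caches for a sequence that may differ from the query sequence, which is what cross-attention needs; when bs_kv is omitted it falls back to the previous self-attention behaviour, and compute_attzero still returns the (query, key) positions whose attention scores are filled with -inf. A minimal usage sketch, assuming BracketedSequence(x, first, nb) and the keyword arguments visible in this commit, and that forward returns a BracketedSequence whose .x holds the output:

    import torch
    from mygpt import BracketedSequence, QKVAttention

    attn = QKVAttention(
        dim_in=64, dim_qk=16, dim_v=16, nb_heads=4,
        compute_attzero=None, attention_dropout=0.0,
    )

    x_q = torch.randn(2, 10, 64)   # query sequence, 10 positions
    x_kv = torch.randn(2, 25, 64)  # key/value sequence, 25 positions

    # Self-attention: bs_kv defaults to bs_q.
    y_self = attn(BracketedSequence(x_q, 0, 10))

    # Cross-attention: queries read from the separate key/value sequence.
    y_cross = attn(BracketedSequence(x_q, 0, 10), BracketedSequence(x_kv, 0, 25))
    print(y_cross.x.shape)  # expected: torch.Size([2, 10, 64])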
@@ -277,6 +283,37 @@ class NoiseInjector(nn.Module):
 ##############################
 
 
+class BlockSummarizer(nn.Module):
+    def __init__(self, nb_blocks, nb_tokens, dim_keys, dim_model, nb_heads, dropout=0.0):
+        super().__init__()
+        self.nb_blocks = nb_blocks
+        self.static_q = randw(nb_blocks - 1, nb_tokens, dim_keys)
+
+        def compute_block_attzero(t_q, t_k):
+            block_size = t_q.size(0)
+            return (t_q // block_size) <= (t_k // block_size)
+
+        self.qkv = QKVAttention(
+            dim_in=dim_model,
+            dim_qk=dim_keys,
+            dim_v=dim_model // nb_heads,
+            nb_heads=nb_heads,
+            compute_attzero=compute_block_attzero,
+            attention_dropout=dropout,
+        )
+
+    def forward(self, bs):
+        pass
+
+
+class ShiftByOne(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, bs):
+        return BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
+
+
 class MyGPT(nn.Module):
     def __init__(
         self,
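ShiftByOne factors out the input shift that was previously inlined in MyGPT.forward: F.pad(bs.x, (1, -1)) shifts every sequence one position to the right along the token dimension, dropping the last token and inserting the padding value 0 in front, so each output position is predicted from the tokens before it. A tiny worked example:

    import torch
    import torch.nn.functional as F

    x = torch.tensor([[3, 5, 7, 9]])
    print(F.pad(x, (1, -1)))  # tensor([[0, 3, 5, 7]]): right shift by one, 0 in front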
@@ -287,7 +323,6 @@ class MyGPT(nn.Module):
         nb_heads,
         nb_blocks,
         compute_attzero=None,
-        autoencoder_dim=-1,
         dropout=0.0,
         len_max=1e5,
     ):
@@ -297,11 +332,14 @@ class MyGPT(nn.Module):
 
         self.temperature = 1.0
 
+        self.shifter = ShiftByOne()
+
         self.embedding = nn.Sequential(
             CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
-            AddPositionalEncoding(len_max),
         )
 
+        self.positional_encoding = AddPositionalEncoding(len_max)
+
         trunk_blocks = []
 
         for b in range(nb_blocks):
@@ -338,26 +376,6 @@ class MyGPT(nn.Module):
             nn.Linear(in_features=dim_model, out_features=vocabulary_size)
         )
 
-        # -------------------------------------------------------
-        if autoencoder_dim > 0:
-            self.encoder = nn.Sequential(
-                *(
-                    trunk_blocks[: nb_blocks // 2]
-                    + [EncoderHead(dim_model, autoencoder_dim)]
-                )
-            )
-
-            self.decoder = nn.Sequential(
-                *(
-                    [
-                        DecoderBottom(autoencoder_dim, dim_model),
-                        AddPositionalEncoding(len_max),
-                    ]
-                    + trunk_blocks[nb_blocks // 2 :]
-                )
-            )
-        # -------------------------------------------------------
-
         with torch.no_grad():
             for m in self.modules():
                 if isinstance(m, nn.Embedding):
@@ -370,8 +388,9 @@ class MyGPT(nn.Module):
         for m in self.modules():
             m.loss = 0
 
-        bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
+        bs = self.shifter(bs)
         bs = self.embedding(bs)
+        bs = self.positional_encoding(bs)
         bs = self.trunk(bs)
         bs = self.readout(bs)
         bs.x[:, bs.first : bs.first + bs.nb] /= self.temperature
@@ -381,33 +400,6 @@ class MyGPT(nn.Module):
 
         return bs
 
-    def encode(self, bs):
-        bs = self.embedding(bs)
-        z = self.encoder(bs)
-        return z
-
-    def decode(self, z_shape):
-        bs = self.decoder(z_shape)
-        bs = self.readout(bs)
-        return bs
-
-    def partial_forward(self, bs, start_layer=None, end_layer=None):
-        if start_layer is None:
-            # print(f"GENERATE {bs.first} {bs.first+bs.nb}")
-            bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
-            bs = self.embedding(bs)
-            if end_layer is not None:
-                return self.trunk[:end_layer](bs)
-            else:
-                bs = self.trunk(bs)
-                bs = self.readout(bs)
-                return bs
-        else:
-            bs = self.trunk[start_layer:](bs)
-            bs = self.trunk(bs)
-            bs = self.readout(bs)
-            return bs
-
     def reset_transformations(self):
         self.temperature = 1.0
         for m in self.modules():