Update.
authorFrançois Fleuret <francois@fleuret.org>
Sun, 4 Aug 2024 14:38:59 +0000 (16:38 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sun, 4 Aug 2024 14:38:59 +0000 (16:38 +0200)
mygpt.py
quiz_machine.py

index b1cdf4d..812a139 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -348,8 +348,16 @@ class MyGPT(nn.Module):
                     m.bias.zero_()
                     m.weight.fill_(1.0)
 
+    # x[      0 ], x[      1 ], ..., x[      T-2 ], x[      T-1 ]
+    # x[sigma[0]], x[sigma[1]], ..., x[sigma[T-2]], x[sigma[T-1]]
+    # x[     -1 ], x[sigma[0]], ..., x[sigma[T-3]], x[sigma[T-2]]
+
+    # y[sigma[0]], y[sigma[1]], ..., y[sigma[T-2]], y[sigma[T-1]]
+    # y[      0 ], y[      1 ], ..., y[      T-2 ], y[      T-1 ]
+
     def forward(self, bs, sigma=None):
         if sigma is not None:
+            # x[n,t] = x[n,sigma[n,t]]
             bs.x = bs.x.gather(dim=1, index=sigma)
         bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
         bs = self.embedding(bs)
@@ -358,9 +366,9 @@ class MyGPT(nn.Module):
         bs = self.readout(bs)
         bs.x[:, bs.first : bs.first + bs.nb] /= self.temperature
         if sigma is not None:
-            bs.x.scatter_(
-                dim=1, index=sigma[:, :, None].expand_as(bs.x), src=bs.x.clone()
-            )
+            y = bs.x.new_zeros(bs.x.size())
+            y.scatter_(dim=1, index=sigma[:, :, None].expand_as(bs.x), src=bs.x)
+            bs.x = y
         return bs
 
     def encode(self, bs):
index 6fd6579..386969a 100755 (executable)
@@ -35,16 +35,11 @@ def one_batch_masked_inplace_autoregression(
     if input.size(0) == 0:
         return
 
-    to_generate = (ar_mask.sum(0) > 0).nonzero()
-
-    if to_generate.min() > 0:
-        model(
-            BracketedSequence(input, 0, to_generate.min())
-        )  # Needed to initialize the model's cache
-    for s in range(to_generate.min(), to_generate.max() + 1):
+    for s in range(input.size(1)):
         output = model(BracketedSequence(input, s, 1), sigma).x
-
-        logits = output[:, s]
+        all_n = torch.arange(input.size(0), device=input.device)
+        u = sigma[:, s]
+        logits = output[all_n, u]
 
         if deterministic_synthesis:
             t_next = logits.argmax(-1)
@@ -52,11 +47,11 @@ def one_batch_masked_inplace_autoregression(
             dist = torch.distributions.categorical.Categorical(logits=logits)
             t_next = dist.sample()
 
-        all_n = torch.arange(t_next.size(0))
-
         seq_logproba += logits[all_n, t_next]
 
-        input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
+        input[all_n, u] = (
+            ar_mask[all_n, u] * t_next + (1 - ar_mask[all_n, u]) * input[all_n, u]
+        )
 
 
 ######################################################################