Update. sigma
authorFrançois Fleuret <francois@fleuret.org>
Mon, 5 Aug 2024 05:33:44 +0000 (07:33 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Mon, 5 Aug 2024 05:33:44 +0000 (07:33 +0200)
mygpt.py

index 812a139..4ffaa3d 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -17,6 +17,36 @@ import torch
 from torch import nn
 from torch.nn import functional as F
 
+######################################################################
+
+#
+# This function gets a NxT tensor of long that encodes the group id of
+# each token, and returns a NxT tensor sigma of long such that for any
+# n sigma[n, :] is a permutation of {0...T-1} sampled uniformly among
+# the permutations that verify
+#
+# for any n, i, j: group[n,i] < group[n,j] => sigma[n,i] < sigma[n,j]
+#
+# For instance
+#
+#   block_sigma(torch.tensor([[2, 2, 0, 0, 0, 1, 1, 1, 1, 2]]))
+#
+# could be
+#
+#   tensor([[8, 7, 1, 0, 2, 5, 4, 3, 6, 9]])
+#
+
+
+def block_sigma(groups):
+    g = (groups[:, None, :] == torch.arange(groups.max() + 1)[None, :, None]).long()
+    r = g * torch.rand(g.size()) + (1 - g) * 2
+    a = torch.arange(r.size(2)).repeat(r.size(0), r.size(1), 1)
+    s = a.new(r.size()).scatter_(dim=2, index=r.argsort(dim=2), src=a) * g
+    m = g.sum(dim=2).cumsum(dim=1)
+    s[:, 1:, :] += m[:, :-1, None]
+    return (s * g).sum(dim=1)
+
+
 ######################################################################
 
 # A BracketedSequence is a BxTx... tensor with a first and a nb time