def block_sigma(groups):
    """Sample, per row, a random permutation that keeps groups in order.

    `groups` is a NxT tensor of non-negative integers giving the group
    id of each token. The result is a NxT tensor `sigma` of long such
    that, for any n, sigma[n, :] is a permutation of {0...T-1} sampled
    uniformly among the permutations that verify

      for any n, i, j: groups[n,i] < groups[n,j] => sigma[n,i] < sigma[n,j]

    For instance

      block_sigma(torch.tensor([[2, 2, 0, 0, 0, 1, 1, 1, 1, 2]]))

    could be

      tensor([[8, 7, 1, 0, 2, 5, 4, 3, 6, 9]])

    The computation builds NxGxT intermediates, where G is the largest
    group id plus one. All intermediates are created on `groups.device`.
    An empty input yields an empty long tensor of the same shape.

    A negative id matches no group row, so the corresponding token
    gets sigma 0 and the row is then not a permutation.
    """

    # groups.max() is undefined on an empty tensor, and there is
    # nothing to permute anyway
    if groups.numel() == 0:
        return torch.empty_like(groups, dtype=torch.long)

    device = groups.device

    group_ids = torch.arange(int(groups.max()) + 1, device=device)

    # member[n, k, t] is 1 iff token t of row n belongs to group k
    member = (groups[:, None, :] == group_ids[None, :, None]).long()

    # Random keys in [0,1) for the members of a group, 2 for the
    # others, so that sorting puts the members first in random order
    keys = member * torch.rand(member.size(), device=device) + (1 - member) * 2

    positions = torch.arange(keys.size(2), device=device).repeat(
        keys.size(0), keys.size(1), 1
    )

    # Invert the sorting permutation: rank[n, k, t] is the rank of
    # token t among the members of group k. argsort is a permutation
    # of the last dim, hence scatter_ writes every entry of the
    # uninitialized tensor. Non-members are zeroed by the mask.
    rank = (
        torch.empty_like(positions).scatter_(
            dim=2, index=keys.argsort(dim=2), src=positions
        )
        * member
    )

    # Shift the ranks of group k by the total size of groups 0..k-1
    cumulated_sizes = member.sum(dim=2).cumsum(dim=1)
    rank[:, 1:, :] += cumulated_sizes[:, :-1, None]

    # The shift also hit non-members, mask again before collapsing
    # the group dimension
    return (rank * member).sum(dim=1)