From f9a093b44641d5231db9aad352a9b35b47a8c312 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 4 Aug 2024 16:38:59 +0200 Subject: [PATCH] Update. --- mygpt.py | 14 +++++++++++--- quiz_machine.py | 19 +++++++------------ 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/mygpt.py b/mygpt.py index b1cdf4d..812a139 100755 --- a/mygpt.py +++ b/mygpt.py @@ -348,8 +348,16 @@ class MyGPT(nn.Module): m.bias.zero_() m.weight.fill_(1.0) + # x[ 0 ], x[ 1 ], ..., x[ T-2 ], x[ T-1 ] + # x[sigma[0]], x[sigma[1]], ..., x[sigma[T-2]], x[sigma[T-1]] + # x[ -1 ], x[sigma[0]], ..., x[sigma[T-3]], x[sigma[T-2]] + + # y[sigma[0]], y[sigma[1]], ..., y[sigma[T-2]], y[sigma[T-1]] + # y[ 0 ], y[ 1 ], ..., y[ T-2 ], y[ T-1 ] + def forward(self, bs, sigma=None): if sigma is not None: + # x[n,t] = x[n,sigma[n,t]] bs.x = bs.x.gather(dim=1, index=sigma) bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb) bs = self.embedding(bs) @@ -358,9 +366,9 @@ class MyGPT(nn.Module): bs = self.readout(bs) bs.x[:, bs.first : bs.first + bs.nb] /= self.temperature if sigma is not None: - bs.x.scatter_( - dim=1, index=sigma[:, :, None].expand_as(bs.x), src=bs.x.clone() - ) + y = bs.x.new_zeros(bs.x.size()) + y.scatter_(dim=1, index=sigma[:, :, None].expand_as(bs.x), src=bs.x) + bs.x = y return bs def encode(self, bs): diff --git a/quiz_machine.py b/quiz_machine.py index 6fd6579..386969a 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -35,16 +35,11 @@ def one_batch_masked_inplace_autoregression( if input.size(0) == 0: return - to_generate = (ar_mask.sum(0) > 0).nonzero() - - if to_generate.min() > 0: - model( - BracketedSequence(input, 0, to_generate.min()) - ) # Needed to initialize the model's cache - for s in range(to_generate.min(), to_generate.max() + 1): + for s in range(input.size(1)): output = model(BracketedSequence(input, s, 1), sigma).x - - logits = output[:, s] + all_n = torch.arange(input.size(0), device=input.device) + u = sigma[:, s] + logits = output[all_n, u] if deterministic_synthesis: t_next = logits.argmax(-1) @@ -52,11 +47,11 @@ def one_batch_masked_inplace_autoregression( dist = torch.distributions.categorical.Categorical(logits=logits) t_next = dist.sample() - all_n = torch.arange(t_next.size(0)) - seq_logproba += logits[all_n, t_next] - input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s] + input[all_n, u] = ( + ar_mask[all_n, u] * t_next + (1 - ar_mask[all_n, u]) * input[all_n, u] + ) ###################################################################### -- 2.39.5