From 702e672dcf9ebcfad11ae4034e64117f2c67ead5 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 24 Jun 2024 12:13:01 +0200 Subject: [PATCH] Update. --- mygpt.py | 7 ++++++- tasks.py | 4 ++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/mygpt.py b/mygpt.py index a178491..c58bea1 100755 --- a/mygpt.py +++ b/mygpt.py @@ -292,11 +292,16 @@ class MyGPT(nn.Module): ) # Needed to initialize the model's cache for s in range(to_generate.min(), to_generate.max() + 1): output = self(BracketedSequence(input, s, 1)).x - logits = output[:, s] / temperature + logits = output[:, s] + + logits = logits.log_softmax(dim=-1) / temperature + if forbidden_tokens is not None: logits = logits.masked_fill(forbidden_tokens, float("-inf")) + if forced_biases is not None: logits = logits + forced_biases[None, :] + if deterministic_synthesis: t_next = logits.argmax(1) else: diff --git a/tasks.py b/tasks.py index b967465..5edb472 100755 --- a/tasks.py +++ b/tasks.py @@ -274,6 +274,10 @@ class World(Task): device=self.device, ) + # Should not be necessary though, the autoregression is done + # in eval mode + sum_logits = sum_logits.detach() + average_logits = sum_logits / quizzes.numel() # It's a bit brutal to do it twice, we should probably have a -- 2.39.5