From a146f072b14ae189fda9e866d1978e46e25701b4 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 4 Aug 2024 14:11:07 +0200 Subject: [PATCH] Update. --- main.py | 18 ++---------------- quiz_machine.py | 21 +++++++++++++++++++-- 2 files changed, 21 insertions(+), 18 deletions(-) diff --git a/main.py b/main.py index 7361ae8..ebdce3f 100755 --- a/main.py +++ b/main.py @@ -360,20 +360,6 @@ log_string(f"vocabulary_size {vocabulary_size}") ###################################################################### -def sigma_for_grids(input): - l = input.size(1) // 4 - 1 - sigma = input.new(input.size()) - r = sigma.view(sigma.size(0), 4, sigma.size(1) // 4) - r[:, 0] = 0 * l - r[:, 1] = 1 * l - r[:, 2] = 2 * l - r[:, 3] = 3 * l - r[:, :, 1:] += ( - torch.rand(input.size(0), 4, l, device=input.device).sort(dim=2).indices - ) + 1 - return sigma - - def run_tests(model, quiz_machine, local_device=main_device): with torch.autograd.no_grad(): model.eval().to(local_device) @@ -386,7 +372,7 @@ def run_tests(model, quiz_machine, local_device=main_device): for input in tqdm.tqdm(src, dynamic_ncols=True, desc="test"): input = input.to(local_device) - sigma = sigma_for_grids(input) + sigma = quiz_machine.sigma_for_grids(input) output = model(mygpt.BracketedSequence(input), sigma).x loss = F.cross_entropy(output.transpose(1, 2), input) acc_test_loss += loss.item() * input.size(0) @@ -432,7 +418,7 @@ def one_epoch(model, quiz_machine, local_device=main_device): targets = input - sigma = sigma_for_grids(input) + sigma = quiz_machine.sigma_for_grids(input) output = model(mygpt.BracketedSequence(input), sigma).x loss_per_token = F.cross_entropy( output.transpose(1, 2), targets, reduction="none" diff --git a/quiz_machine.py b/quiz_machine.py index 3fc1066..6fd6579 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -27,6 +27,7 @@ import threading def one_batch_masked_inplace_autoregression( model, input, + sigma, ar_mask, seq_logproba, deterministic_synthesis=False, @@ -41,7 +42,7 @@ def one_batch_masked_inplace_autoregression( BracketedSequence(input, 0, to_generate.min()) ) # Needed to initialize the model's cache for s in range(to_generate.min(), to_generate.max() + 1): - output = model(BracketedSequence(input, s, 1)).x + output = model(BracketedSequence(input, s, 1), sigma).x logits = output[:, s] @@ -98,6 +99,19 @@ class QuizMachine: ###################################################################### + def sigma_for_grids(self, input): + l = input.size(1) // 4 + sigma = input.new(input.size()) + r = sigma.view(sigma.size(0), 4, l) + r[:, 0] = 0 * l + r[:, 1] = 1 * l + r[:, 2] = 2 * l + r[:, 3] = 3 * l + r[:, :, 1:] += ( + torch.rand(input.size(0), 4, l - 1, device=input.device).sort(dim=2).indices + ) + 1 + return sigma + def autoregression( self, model, @@ -130,9 +144,11 @@ class QuizMachine: model.eval() for input, ar_mask, seq_logproba in batches: + sigma = self.sigma_for_grids(input) one_batch_masked_inplace_autoregression( model=model, input=input, + sigma=sigma, ar_mask=ar_mask, seq_logproba=seq_logproba, deterministic_synthesis=False, @@ -360,7 +376,8 @@ class QuizMachine: ): input = input.to(device) ar_mask = self.make_ar_mask(input, struct=struct, mask=mask) - output = model(mygpt.BracketedSequence(input)).x + sigma = self.sigma_for_grids(input) + output = model(mygpt.BracketedSequence(input), sigma).x l[:, model.id] = ( -F.cross_entropy( output.transpose(1, 2), input, reduction="none" -- 2.39.5