From 896806aefa647e546977b1bea9b09362fa59f1ea Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 4 Aug 2024 10:50:01 +0200 Subject: [PATCH] Update. --- main.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/main.py b/main.py index 63f6cce..55615ee 100755 --- a/main.py +++ b/main.py @@ -360,6 +360,20 @@ log_string(f"vocabulary_size {vocabulary_size}") ###################################################################### +def sigma_for_grids(input): + l = input.size(1) // 4 - 1 + sigma = input.new(input.size()) + r = sigma.view(sigma.size(0), 4, sigma.size(1) // 4) + r[:, :, 1:] = ( + torch.rand(input.size(0), 4, l, device=input.device).sort(dim=2).indices + ) + 1 + r[:, 0] += 0 * l + r[:, 1] += 1 * l + r[:, 2] += 2 * l + r[:, 3] += 3 * l + return sigma + + def run_tests(model, quiz_machine, local_device=main_device): with torch.autograd.no_grad(): model.eval().to(local_device) @@ -372,7 +386,7 @@ def run_tests(model, quiz_machine, local_device=main_device): for input in tqdm.tqdm(src, dynamic_ncols=True, desc="test"): input = input.to(local_device) - sigma = torch.rand(input.size(), device=input.device).sort(dim=1).indices + sigma = sigma_for_grids(input) output = model(mygpt.BracketedSequence(input), sigma).x loss = F.cross_entropy(output.transpose(1, 2), input) acc_test_loss += loss.item() * input.size(0) @@ -418,7 +432,7 @@ def one_epoch(model, quiz_machine, local_device=main_device): targets = input - sigma = torch.rand(input.size(), device=input.device).sort(dim=1).indices + sigma = sigma_for_grids(input) output = model(mygpt.BracketedSequence(input), sigma).x loss_per_token = F.cross_entropy( output.transpose(1, 2), targets, reduction="none" -- 2.39.5