Update: generate the token order grid-wise with sigma_for_grids instead of as a uniform random permutation of the whole sequence.
author     François Fleuret <francois@fleuret.org>
           Sun, 4 Aug 2024 08:50:01 +0000 (10:50 +0200)
committer  François Fleuret <francois@fleuret.org>
           Sun, 4 Aug 2024 08:50:01 +0000 (10:50 +0200)
main.py

diff --git a/main.py b/main.py
index 63f6cce..55615ee 100755
--- a/main.py
+++ b/main.py
@@ -360,6 +360,29 @@ log_string(f"vocabulary_size {vocabulary_size}")
 ######################################################################
 
 
+def sigma_for_grids(input):
+    # The input is a batch of sequences, each the concatenation of
+    # four blocks of identical size (the four grids of a quiz). The
+    # resulting sigma keeps the blocks in their natural order, ranks
+    # the first token of each block first in its block, and visits
+    # the other tokens of the block in random order.
+    l = input.size(1) // 4 - 1
+    sigma = input.new_zeros(input.size())
+    r = sigma.view(sigma.size(0), 4, sigma.size(1) // 4)
+    # Random ranks 1..l for the l remaining tokens of every block;
+    # the first token of each block keeps rank 0 from new_zeros
+    r[:, :, 1:] = (
+        torch.rand(input.size(0), 4, l, device=input.device).sort(dim=2).indices
+    ) + 1
+    # Shift every block by its full size l+1 so that sigma is a
+    # permutation of 0..input.size(1)-1
+    r[:, 0] += 0 * (l + 1)
+    r[:, 1] += 1 * (l + 1)
+    r[:, 2] += 2 * (l + 1)
+    r[:, 3] += 3 * (l + 1)
+    return sigma
+
+
 def run_tests(model, quiz_machine, local_device=main_device):
     with torch.autograd.no_grad():
         model.eval().to(local_device)
@@ -372,7 +395,7 @@ def run_tests(model, quiz_machine, local_device=main_device):
 
         for input in tqdm.tqdm(src, dynamic_ncols=True, desc="test"):
             input = input.to(local_device)
-            sigma = torch.rand(input.size(), device=input.device).sort(dim=1).indices
+            sigma = sigma_for_grids(input)
             output = model(mygpt.BracketedSequence(input), sigma).x
             loss = F.cross_entropy(output.transpose(1, 2), input)
             acc_test_loss += loss.item() * input.size(0)
@@ -418,7 +441,7 @@ def one_epoch(model, quiz_machine, local_device=main_device):
 
         targets = input
 
-        sigma = torch.rand(input.size(), device=input.device).sort(dim=1).indices
+        sigma = sigma_for_grids(input)
         output = model(mygpt.BracketedSequence(input), sigma).x
         loss_per_token = F.cross_entropy(
             output.transpose(1, 2), targets, reduction="none"
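
A minimal standalone sketch of what sigma_for_grids produces, with a toy batch of 2 sequences and blocks of 1+3 tokens. The sizes, and the reading of the four blocks as "marker token plus grid cells", are inferred from the code of the commit, not stated in it:

import torch

def sigma_for_grids(input):  # same definition as in the diff above
    l = input.size(1) // 4 - 1
    sigma = input.new_zeros(input.size())
    r = sigma.view(sigma.size(0), 4, sigma.size(1) // 4)
    r[:, :, 1:] = (
        torch.rand(input.size(0), 4, l, device=input.device).sort(dim=2).indices
    ) + 1
    r[:, 0] += 0 * (l + 1)
    r[:, 1] += 1 * (l + 1)
    r[:, 2] += 2 * (l + 1)
    r[:, 3] += 3 * (l + 1)
    return sigma

# 2 sequences of 4 blocks, each block one marker token followed by 3 cells
input = torch.zeros(2, 4 * (1 + 3), dtype=torch.int64)
sigma = sigma_for_grids(input)

# Each row is a permutation of 0..15, structured block by block, e.g.
# [ 0, 2, 3, 1 | 4, 6, 5, 7 | 8, 10, 9, 11 | 12, 13, 15, 14 ]
assert (sigma.sort(dim=1).values == torch.arange(input.size(1))).all()

The previous sigma, torch.rand(input.size(), ...).sort(dim=1).indices, was a uniform random permutation of all positions; the new one presumably makes the model visit the quiz one grid at a time, marker first, in random order within each grid.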