Move sigma_for_grids into QuizMachine, fix its grid offsets, and pass sigma through the autoregressive generation path.
author     François Fleuret <francois@fleuret.org>
           Sun, 4 Aug 2024 12:11:07 +0000 (14:11 +0200)
committer  François Fleuret <francois@fleuret.org>
           Sun, 4 Aug 2024 12:11:07 +0000 (14:11 +0200)
main.py
quiz_machine.py

diff --git a/main.py b/main.py
index 7361ae8..ebdce3f 100755
--- a/main.py
+++ b/main.py
@@ -360,20 +360,6 @@ log_string(f"vocabulary_size {vocabulary_size}")
 ######################################################################
 
 
-def sigma_for_grids(input):
-    l = input.size(1) // 4 - 1
-    sigma = input.new(input.size())
-    r = sigma.view(sigma.size(0), 4, sigma.size(1) // 4)
-    r[:, 0] = 0 * l
-    r[:, 1] = 1 * l
-    r[:, 2] = 2 * l
-    r[:, 3] = 3 * l
-    r[:, :, 1:] += (
-        torch.rand(input.size(0), 4, l, device=input.device).sort(dim=2).indices
-    ) + 1
-    return sigma
-
-
 def run_tests(model, quiz_machine, local_device=main_device):
     with torch.autograd.no_grad():
         model.eval().to(local_device)
@@ -386,7 +372,7 @@ def run_tests(model, quiz_machine, local_device=main_device):
 
         for input in tqdm.tqdm(src, dynamic_ncols=True, desc="test"):
             input = input.to(local_device)
-            sigma = sigma_for_grids(input)
+            sigma = quiz_machine.sigma_for_grids(input)
             output = model(mygpt.BracketedSequence(input), sigma).x
             loss = F.cross_entropy(output.transpose(1, 2), input)
             acc_test_loss += loss.item() * input.size(0)
@@ -432,7 +418,7 @@ def one_epoch(model, quiz_machine, local_device=main_device):
 
         targets = input
 
-        sigma = sigma_for_grids(input)
+        sigma = quiz_machine.sigma_for_grids(input)
         output = model(mygpt.BracketedSequence(input), sigma).x
         loss_per_token = F.cross_entropy(
             output.transpose(1, 2), targets, reduction="none"
diff --git a/quiz_machine.py b/quiz_machine.py
index 3fc1066..6fd6579 100755
--- a/quiz_machine.py
+++ b/quiz_machine.py
@@ -27,6 +27,7 @@ import threading
 def one_batch_masked_inplace_autoregression(
     model,
     input,
+    sigma,
     ar_mask,
     seq_logproba,
     deterministic_synthesis=False,
@@ -41,7 +42,7 @@ def one_batch_masked_inplace_autoregression(
             BracketedSequence(input, 0, to_generate.min())
         )  # Needed to initialize the model's cache
     for s in range(to_generate.min(), to_generate.max() + 1):
-        output = model(BracketedSequence(input, s, 1)).x
+        output = model(BracketedSequence(input, s, 1), sigma).x
 
         logits = output[:, s]
 
@@ -98,6 +99,19 @@ class QuizMachine:
 
     ######################################################################
 
+    def sigma_for_grids(self, input):
+        l = input.size(1) // 4
+        sigma = input.new(input.size())
+        r = sigma.view(sigma.size(0), 4, l)
+        r[:, 0] = 0 * l
+        r[:, 1] = 1 * l
+        r[:, 2] = 2 * l
+        r[:, 3] = 3 * l
+        r[:, :, 1:] += (
+            torch.rand(input.size(0), 4, l - 1, device=input.device).sort(dim=2).indices
+        ) + 1
+        return sigma
+
     def autoregression(
         self,
         model,
@@ -130,9 +144,11 @@ class QuizMachine:
             model.eval()
 
             for input, ar_mask, seq_logproba in batches:
+                sigma = self.sigma_for_grids(input)
                 one_batch_masked_inplace_autoregression(
                     model=model,
                     input=input,
+                    sigma=sigma,
                     ar_mask=ar_mask,
                     seq_logproba=seq_logproba,
                     deterministic_synthesis=False,
@@ -360,7 +376,8 @@ class QuizMachine:
                 ):
                     input = input.to(device)
                     ar_mask = self.make_ar_mask(input, struct=struct, mask=mask)
-                    output = model(mygpt.BracketedSequence(input)).x
+                    sigma = self.sigma_for_grids(input)
+                    output = model(mygpt.BracketedSequence(input), sigma).x
                     l[:, model.id] = (
                         -F.cross_entropy(
                             output.transpose(1, 2), input, reduction="none"
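
For reference, a minimal standalone sketch of what the relocated sigma_for_grids computes, assuming each quiz is four concatenated grids of equal length l (the test values below are illustrative, not from the repository):

import torch

def sigma_for_grids(input):
    # input: (batch, 4 * l) tensor of tokens; each quarter is one grid.
    l = input.size(1) // 4
    sigma = input.new(input.size())  # same dtype and device as input
    r = sigma.view(sigma.size(0), 4, l)
    # The first entry of each grid holds that grid's start offset.
    r[:, 0] = 0 * l
    r[:, 1] = 1 * l
    r[:, 2] = 2 * l
    r[:, 3] = 3 * l
    # Argsort of uniform noise yields a per-sample random permutation of
    # 0 .. l-2; shifting it by 1 scatters positions 1 .. l-1 within each
    # grid, so every grid keeps its first position fixed and shuffles the rest.
    r[:, :, 1:] += (
        torch.rand(input.size(0), 4, l - 1, device=input.device).sort(dim=2).indices
    ) + 1
    return sigma

x = torch.zeros(2, 12, dtype=torch.int64)  # two samples, four grids of length 3
print(sigma_for_grids(x))
# e.g. tensor([[0, 2, 1, 3, 5, 4, 6, 7, 8, 9, 11, 10], ...]); every value
# stays inside its own grid's index range.

Moving the helper into QuizMachine makes it available to both the training loops in main.py and the autoregressive sampler, which is why one_batch_masked_inplace_autoregression now takes sigma as an argument; how the model consumes sigma is left to mygpt's forward.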