Update.
authorFrançois Fleuret <francois@fleuret.org>
Sun, 11 Aug 2024 08:26:07 +0000 (10:26 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sun, 11 Aug 2024 08:26:07 +0000 (10:26 +0200)
quiz_machine.py

index ceb523d..92da03d 100755 (executable)
@@ -198,7 +198,7 @@ class QuizMachine:
                     quizzes[i] = self.problem.inject_noise(
                         quizzes[i], self.prompt_noise, struct=struct, mask=mask_noise
                     )
-                    quiz_mask_loss[i] = self.make_ar_mask(
+                    quiz_mask_loss[i] = self.make_quiz_mask(
                         quizzes=quizzes[i], struct=struct, mask=mask_loss
                     )
 
@@ -206,14 +206,14 @@ class QuizMachine:
 
     ######################################################################
 
-    def make_ar_mask(self, quizzes, struct, mask):
+    def make_quiz_mask(self, quizzes, struct, mask):
         assert struct in [s for s, _, _, _ in self.train_structures]
-        return self.problem.make_ar_mask(quizzes, struct=struct, mask=mask)
+        return self.problem.make_quiz_mask(quizzes, struct=struct, mask=mask)
 
     ######################################################################
 
     def predict(self, model, quizzes, struct, mask):
-        ar_mask = self.make_ar_mask(quizzes=quizzes, struct=struct, mask=mask)
+        ar_mask = self.make_quiz_mask(quizzes=quizzes, struct=struct, mask=mask)
         result = quizzes * (1 - ar_mask)
 
         seq_logproba = torch.empty(quizzes.size(0), device=self.device)
@@ -374,7 +374,7 @@ class QuizMachine:
                     seq_logproba.split(self.batch_size),
                 ):
                     input = input.to(device)
-                    quiz_mask_loss = self.make_ar_mask(
+                    quiz_mask_loss = self.make_quiz_mask(
                         input, struct=struct, mask=mask_loss
                     )
                     output = model(mygpt.BracketedSequence(input)).x
@@ -410,7 +410,7 @@ class QuizMachine:
             self.autoregression(
                 model=model_for_generation,
                 input=c_quizzes,
-                ar_mask=self.make_ar_mask(c_quizzes, s, m),
+                ar_mask=self.make_quiz_mask(c_quizzes, s, m),
                 seq_logproba=seq_logproba,
             )