Update.
authorFrançois Fleuret <francois@fleuret.org>
Mon, 1 Jul 2024 07:11:09 +0000 (10:11 +0300)
committerFrançois Fleuret <francois@fleuret.org>
Mon, 1 Jul 2024 07:11:09 +0000 (10:11 +0300)
quizz_machine.py

index 84bb558..7b0b877 100755 (executable)
@@ -312,9 +312,7 @@ class QuizzMachine:
         else:
             self.test_c_quizzes.append(new_c_quizzes)
 
-    def comput_correctness(self, c_quizzes, models_for_validation):
-        # Create the reverse quizzes
-
+    def reverse_time(self, c_quizzes):
         token_forward, token_backward = self.problem.direction_tokens()
 
         l = (c_quizzes.size(1) - 1) // 2
@@ -322,9 +320,11 @@ class QuizzMachine:
         direction = self.problem.token_forward * (
             direction == self.problem.token_backward
         ) + self.problem.token_backward * (direction == self.problem.token_forward)
-        reverse_c_quizzes = torch.cat(
-            [c_quizzes[:, l + 1 :], direction, c_quizzes[:, :l]], dim=1
-        )
+
+        return torch.cat([c_quizzes[:, l + 1 :], direction, c_quizzes[:, :l]], dim=1)
+
+    def comput_correctness(self, c_quizzes, models_for_validation):
+        reversed_c_quizzes = self.reverse_time(c_quizzes)
 
         ar_mask = self.make_ar_mask(c_quizzes)
         seq_logproba = torch.empty(ar_mask.size(0), device=self.device)
@@ -350,12 +350,12 @@ class QuizzMachine:
 
             correct = (c_quizzes == result).long().min(dim=-1).values
 
-            reverse_result = reverse_c_quizzes.clone()
+            reversed_result = reversed_c_quizzes.clone()
 
             masked_inplace_autoregression(
                 model=model,
                 batch_size=self.batch_size,
-                input=reverse_result,
+                input=reversed_result,
                 ar_mask=ar_mask,
                 seq_logproba=seq_logproba,
                 temperature=1.0,
@@ -364,17 +364,19 @@ class QuizzMachine:
                 device=self.device,
             )
 
-            reverse_correct = (
-                (reverse_c_quizzes == reverse_result).long().min(dim=-1).values
+            reversed_correct = (
+                (reversed_c_quizzes == reversed_result).long().min(dim=-1).values
             )
 
-            nb_correct += correct * reverse_correct
+            nb_correct += correct * reversed_correct
 
         return nb_correct
 
     ###############################################################
 
-    def generate_quizzes(self, nb, model_for_generation, min_ave_seq_logproba):
+    def generate_quizzes(
+        self, nb, model_for_generation, min_ave_seq_logproba, reverse_cleanup=False
+    ):
         c_quizzes = torch.empty(
             nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64
         )
@@ -384,11 +386,12 @@ class QuizzMachine:
         ar_mask_solve = 1 - ar_mask_prompt
         seq_logproba = torch.empty(ar_mask_prompt.size(0), device=self.device)
 
-        warnings.warn("noise injection", RuntimeWarning)
+        warnings.warn("noise injection", RuntimeWarning)
         temperature = 1
-        noise_std = torch.rand(1).item()
-        self.logger(f"{noise_std=}")
-        mygpt.set_noise_injection(model_for_generation, noise_std)
+        # noise_std = torch.rand(1).item()
+        # self.logger(f"{noise_std=}")
+
+        # mygpt.set_noise_injection(model_for_generation, noise_std)
 
         masked_inplace_autoregression(
             model=model_for_generation,
@@ -402,6 +405,8 @@ class QuizzMachine:
             device=self.device,
         )
 
+        # mygpt.set_noise_injection(model_for_generation, 0.0)
+
         ave_seq_logproba = seq_logproba.mean()
 
         masked_inplace_autoregression(
@@ -416,7 +421,19 @@ class QuizzMachine:
             device=self.device,
         )
 
-        mygpt.set_noise_injection(model_for_generation, 0.0)
+        if reverse_cleanup:
+            c_quizzes = self.reverse_time(c_quizzes)
+            masked_inplace_autoregression(
+                model=model_for_generation,
+                batch_size=self.batch_size,
+                input=c_quizzes,
+                ar_mask=ar_mask_solve,
+                seq_logproba=seq_logproba,
+                temperature=temperature,
+                deterministic_synthesis=True,
+                # progress_bar_desc="sampling c_quizzes",
+                device=self.device,
+            )
 
         return c_quizzes, seq_logproba.mean()