Update.
authorFrançois Fleuret <francois@fleuret.org>
Thu, 4 Jul 2024 11:30:36 +0000 (14:30 +0300)
committerFrançois Fleuret <francois@fleuret.org>
Thu, 4 Jul 2024 11:30:36 +0000 (14:30 +0300)
quizz_machine.py

index 6e57fb4..ee7af90 100755 (executable)
@@ -122,12 +122,13 @@ class QuizzMachine:
         forward_to_backward = torch.cat(
             [
                 quizzes[:, 0:1],
-                quizzes[:, 2 + self.prompt_len :],
-                quizzes[:, 1 + self.prompt_len : 2 + self.prompt_len],
+                quizzes[:, 2 + self.prompt_len : 2 + self.prompt_len + self.answer_len],
+                quizzes[:, 1 + self.prompt_len : 1 + self.prompt_len + 1],
                 quizzes[:, 1 : 1 + self.prompt_len],
             ],
             dim=1,
         )
+
         forward_to_backward[:, 0] = self.token_backward
         forward_to_backward[:, 1 + self.answer_len] = self.token_backward
 
@@ -234,14 +235,14 @@ class QuizzMachine:
 
         if result_dir is not None:
             self.save_quizzes(
-                result_dir, "culture_w_quizzes", self.train_w_quizzes[:72]
+                result_dir,
+                "culture_w_quizzes",
+                self.train_w_quizzes[:72],
+                prediction=True,
             )
 
-            # toto = self.reverse_time(self.train_w_quizzes[:72])
-            # self.save_quizzes(result_dir, "toto", toto)
-            # exit(0)
-
     def save_quizzes(self, result_dir, filename_prefix, quizzes, prediction=False):
+        quizzes = quizzes.clone()
         forward = quizzes[quizzes[:, 0] == self.token_forward]
         ib = quizzes[:, 0] == self.token_backward
         backward = quizzes[ib]
@@ -326,6 +327,21 @@ class QuizzMachine:
                 device=self.device,
             )
 
+            #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+            self.save_quizzes(
+                result_dir,
+                f"DEBUG_input_{n_epoch}_{result.size(0):04d}",
+                quizzes=input[:72],
+                prediction=True,
+            )
+            self.save_quizzes(
+                result_dir,
+                f"DEBUG_result_{n_epoch}_{result.size(0):04d}",
+                quizzes=result[:72],
+                prediction=True,
+            )
+            #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+
             if self.back_accuracy:
                 n_forward = input[:, 0] == self.token_forward
                 nb_total = input[n_forward].size(0)
@@ -334,21 +350,33 @@ class QuizzMachine:
                     .long()
                     .min(dim=1)
                     .values.sum()
+                    .item()
+                )
+
+                self.logger(
+                    f"back_accuracy {n_epoch=} {model.id=} {nb_correct=} {nb_total=}"
                 )
 
                 n_backward = input[:, 0] == self.token_backward
                 back_input = self.reverse_time(result[n_backward])
+
                 if back_input.size(0) > 0:
                     back_input[:, 2 + self.prompt_len :] = input[
-                        n_backward, 2 + self.prompt_len :
+                        n_backward, 1 : 1 + self.answer_len
                     ]
                     back_nb_total, back_nb_correct = compute_accuracy(back_input)
+                    self.logger(
+                        f"back_accuracy {n_epoch=} {model.id=} {back_nb_correct=} {back_nb_total=}"
+                    )
                     nb_total += back_nb_total
                     nb_correct += back_nb_correct
+
             else:
                 nb_total = input.size(0)
                 nb_correct = (input == result).long().min(dim=1).values.sum()
 
+            exit(0)
+
             return nb_total, nb_correct
 
         train_nb_total, train_nb_correct = compute_accuracy(self.train_w_quizzes[:nmax])