Update.
authorFrançois Fleuret <francois@fleuret.org>
Fri, 16 Aug 2024 06:41:24 +0000 (08:41 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Fri, 16 Aug 2024 06:41:24 +0000 (08:41 +0200)
main.py
quiz_machine.py

diff --git a/main.py b/main.py
index dcd76ad..8802752 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -473,6 +473,12 @@ def save_additional_results(n_epoch, model, models, c_quizzes_procedure):
         + quiz_machine.models_logprobas(
             model, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
         )
+        + quiz_machine.models_logprobas(
+            model, c_quizzes, ("B", "f_B", "A", "f_A"), (0, 0, 0, 1), (0, 0, 1, 0)
+        )
+        + quiz_machine.models_logprobas(
+            model, c_quizzes, ("f_B", "B", "f_A", "A"), (0, 0, 0, 1), (0, 0, 1, 0)
+        )
         for model in models
     ]
 
@@ -516,11 +522,20 @@ def save_additional_results(n_epoch, model, models, c_quizzes_procedure):
 ######################################################################
 
 
-def model_proba_solutions(m, quizzes):
-    l = quiz_machine.models_logprobas(
-        m, quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
-    ) + quiz_machine.models_logprobas(
-        m, quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
+def model_proba_solutions(model, quizzes):
+    l = (
+        quiz_machine.models_logprobas(
+            model, quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
+        )
+        + quiz_machine.models_logprobas(
+            model, quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
+        )
+        + quiz_machine.models_logprobas(
+            model, quizzes, ("B", "f_B", "A", "f_A"), (0, 0, 0, 1), (0, 0, 1, 0)
+        )
+        + quiz_machine.models_logprobas(
+            model, quizzes, ("f_B", "B", "f_A", "A"), (0, 0, 0, 1), (0, 0, 1, 0)
+        )
     )
 
     return l.exp()
@@ -583,21 +598,27 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test):
         # the most consistent from a model which is confident
 
         for s in range(proba_own_solution.size(0)):
+            # At least one GPT does not understand at all
             if proba_own_solution[s, :].min() < args.proba_not_understands:
                 dont_get_this_quiz = proba_own_solution[s, :] < args.proba_understands
                 nb_fails = dont_get_this_quiz.long().sum()
+                # At most max_fail_to_validate do not understand (default 3/5)
                 if nb_fails >= 1 and nb_fails <= args.max_fail_to_validate:
                     for model in models:
+                        # If a GPT does not get that quiz
                         if dont_get_this_quiz[model.id]:
                             assert (
                                 proba_own_solution[s, model.id] < args.proba_understands
                             )
+                            # Look at its estimate of the others'solutions
                             proba_other_solutions = model_proba_solutions(
                                 model, solved_c_quizzes[s]
                             )
+                            # Randomize a bit the orders for the frequent P=1
                             proba_other_solutions += (
                                 torch.rand(proba_other_solutions.size()) * 1e-6
                             )
+                            # Remove the under threshold confidence solutions
                             proba_other_solutions[dont_get_this_quiz] = -1
                             i = proba_other_solutions.argmax()
                             model.recorded_c_quizzes.append(solved_c_quizzes[s, i])
@@ -1121,6 +1142,9 @@ for n_epoch in range(current_epoch, args.nb_epochs):
 
     for model in models:
         if model.test_accuracy >= model.best_test_accuracy:
+            log_string(
+                f"storing_best model {model.id} accuracy {model.best_test_accuracy} -> {model.test_accuracy}"
+            )
             model.best_dict = copy.deepcopy(model.state_dict())
             model.best_test_accuracy = model.test_accuracy
 
index 6da9075..3c4a865 100755 (executable)
@@ -85,14 +85,11 @@ class QuizMachine:
         self.train_structures = [
             (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)),
             (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)),
-            (("B", "f_B", "A", "f_A"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
-            (("f_B", "B", "f_A", "A"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
+            (("B", "f_B", "A", "f_A"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)),
+            (("f_B", "B", "f_A", "A"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)),
+            # (("B", "f_B", "A", "f_A"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
+            # (("f_B", "B", "f_A", "A"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
             (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
-            # (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 0)),
-            # (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 0, 1)),
-            # (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 0)),
-            # (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 0, 1)),
-            # (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
         ]
 
         self.test_structures = self.train_structures
@@ -198,7 +195,7 @@ class QuizMachine:
             input=result,
             ar_mask=ar_mask,
             seq_logprobas=seq_logprobas,
-            progress_bar_desc="accuracy",
+            progress_bar_desc="autoregression",
         )
 
         correct = (result == quizzes).min(dim=1).values.long()