Update.
author     François Fleuret <francois@fleuret.org>
Tue, 13 Aug 2024 15:47:44 +0000 (17:47 +0200)
committer  François Fleuret <francois@fleuret.org>
Tue, 13 Aug 2024 15:47:44 +0000 (17:47 +0200)
main.py

diff --git a/main.py b/main.py
index 34c0987..b2a9591 100755
--- a/main.py
+++ b/main.py
@@ -545,7 +545,7 @@ def model_proba_solutions(m, quizzes):
 
 
 def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test):
-    nb_validated, nb_to_validate = 0, nb_for_train + nb_for_test
+    nb_validated, nb_to_validate = 0, (nb_for_train + nb_for_test) * len(models)
     nb_generated, nb_to_generate_per_iteration = 0, nb_to_validate
 
     start_time = time.perf_counter()
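
This change makes the validation target scale with the number of models instead of being a single quota shared by all of them. A minimal sketch of the arithmetic, with hypothetical values for nb_for_train, nb_for_test, and the model count:

    # Illustrative values only; not taken from the configuration.
    nb_for_train, nb_for_test, nb_models = 1000, 100, 4

    old_target = nb_for_train + nb_for_test               # 1100 c-quizzes shared by all models
    new_target = (nb_for_train + nb_for_test) * nb_models  # 4400 c-quizzes, one full quota per model
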
@@ -593,22 +593,26 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test):
                 mask=(0, 0, 0, 1),
             )
 
-            u = model_proba_solutions(model, solved_c_quizzes[:, model.id])
-
-            proba_own_solution[:, model.id] = u
+            proba_own_solution[:, model.id] = model_proba_solutions(
+                model, solved_c_quizzes[:, model.id]
+            )
 
         # Now, for every model that is not confident in its own response, we
         # pick the most consistent solution from a model that is confident
 
         for s in range(proba_own_solution.size(0)):
-            dont_get_it = proba_own_solution[s, :] < args.proba_understands
-            if not dont_get_it.all():
+            dont_get_this_quiz = proba_own_solution[s, :] < args.proba_understands
+            if not dont_get_this_quiz.all():
                 for model in models:
-                    if dont_get_it[model.id]:
+                    if dont_get_this_quiz[model.id]:
+                        assert proba_own_solution[s, model.id] < args.proba_understands
                         proba_other_solutions = model_proba_solutions(
                             model, solved_c_quizzes[s]
                         )
-                        proba_other_solutions[dont_get_it] = -1
+                        proba_other_solutions[dont_get_this_quiz] = -1
+                        # print(
+                        # f"\nDEBUG {proba_own_solution[s,model.id]=} {proba_other_solutions=}\n"
+                        # )
                         i = proba_other_solutions.argmax()
                         model.recorded_c_quizzes.append(solved_c_quizzes[s, i])
                         teaching_count[i, model.id] += 1
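
To make the selection loop above concrete, here is a minimal toy sketch of the masking-and-argmax step; the tensor values and the 0.9 threshold are illustrative, not taken from args.proba_understands:

    import torch

    # One row of proba_own_solution: three models' confidence in their own solution.
    proba_own = torch.tensor([0.20, 0.95, 0.85])
    dont_get_this_quiz = proba_own < 0.90            # -> [True, False, True]

    # Model 0 is not confident; it scores the three candidate solutions
    # (illustrative values) and discards those from non-confident models.
    proba_other_solutions = torch.tensor([0.30, 0.70, 0.60])
    proba_other_solutions[dont_get_this_quiz] = -1   # -> [-1.00, 0.70, -1.00]

    i = proba_other_solutions.argmax()               # model 1's solution is picked
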
@@ -647,34 +651,12 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test):
 
             c_quizzes = new_bag[:128]
 
-            l = [
-                quiz_machine.models_logprobas(
-                    model,
-                    c_quizzes,
-                    ("A", "f_A", "B", "f_B"),
-                    (0, 0, 0, 1),
-                    (0, 0, 1, 0),
-                )
-                + quiz_machine.models_logprobas(
-                    model,
-                    c_quizzes,
-                    ("f_A", "A", "f_B", "B"),
-                    (0, 0, 0, 1),
-                    (0, 0, 1, 0),
-                )
-                for model in models
-            ]
-
-            seq_logprobas = torch.cat([x[:, None] for x in l], dim=1)
-
-            probas = seq_logprobas.exp()
-
+            l = [model_proba_solutions(model, c_quizzes) for model in models]
+            probas = torch.cat([x[:, None] for x in l], dim=1)
             comments = []
 
-            for l in seq_logprobas:
-                comments.append(
-                    "proba " + " ".join([f"{x.exp().item():.02f}" for x in l])
-                )
+            for l in probas:
+                comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l]))
 
             filename = f"culture_c_quiz_{n_epoch:04d}_{model.id:02d}.png"
             quiz_machine.problem.save_quizzes_as_image(
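
Judging from the lines removed above, model_proba_solutions presumably sums the sequence log-probabilities over the two quiz orderings and exponentiates the result, so the loop below can print probabilities directly. A minimal sketch under that assumption, with the models_logprobas signature taken from the removed code:

    def model_proba_solutions(m, quizzes):
        # Assumed equivalent of the removed computation: add the log-probabilities
        # of both presentation orders, then map to probabilities.
        l = quiz_machine.models_logprobas(
            m, quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
        ) + quiz_machine.models_logprobas(
            m, quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
        )
        return l.exp()
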
@@ -699,7 +681,7 @@ if args.schedule_free:
     import schedulefree
 
 for k in range(args.nb_gpts):
-    log_string(f"creating model {k} and its w_quizzes")
+    log_string(f"creating model {k}")
 
     model = mygpt.MyGPT(
         vocabulary_size=vocabulary_size,
@@ -764,10 +746,10 @@ log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
 ######################################################################
 
 if args.nb_new_c_quizzes_for_train is None:
-    args.nb_new_c_quizzes_for_train = args.nb_train_samples // 100
+    args.nb_new_c_quizzes_for_train = args.nb_train_samples // 40
 
 if args.nb_new_c_quizzes_for_test is None:
-    args.nb_new_c_quizzes_for_test = args.nb_test_samples // 100
+    args.nb_new_c_quizzes_for_test = args.nb_test_samples // 40
 
 log_string(
     f"nb_new_c_quizzes_for_train {args.nb_new_c_quizzes_for_train} nb_new_c_quizzes_for_test {args.nb_new_c_quizzes_for_test}"