Update.
authorFrançois Fleuret <francois@fleuret.org>
Mon, 19 Aug 2024 20:19:46 +0000 (22:19 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Mon, 19 Aug 2024 20:19:46 +0000 (22:19 +0200)
main.py

diff --git a/main.py b/main.py
index cd1e10f..901e91c 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -456,9 +456,9 @@ def model_modifier_cold(model):
 
 
 c_quizzes_procedure = [
-    (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_modifier_hot),
-    (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), model_modifier_cold),
-    (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_modifier_cold),
+    (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_modifier_hot),
+    (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), model_modifier_cold),
+    (("A", "f_A", "B", "f_B"), (1, 1, 1, 1), model_modifier_hot),
     # (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), model_modifier_cold),
 ]
 
@@ -562,8 +562,9 @@ def create_c_quizzes(
     train_c_quiz_bags,
     nb_for_test,
     test_c_quiz_bags,
+    local_device=main_device,
 ):
-    nb_validated, nb_to_validate = 0, (nb_for_train + nb_for_test) * len(models)
+    nb_validated, nb_to_validate = 0, nb_for_train + nb_for_test
     nb_generated, nb_to_generate_per_iteration = 0, nb_to_validate
 
     start_time = time.perf_counter()
@@ -600,10 +601,9 @@ def create_c_quizzes(
             mask=(0, 0, 0, 1),
         )
 
-        keep = (
-            model_proba_solutions(main_model, main_solution)
-            < args.proba_not_understands
-        )
+        main_probas = model_proba_solutions(main_model, main_solution)
+        log_string(f"main_probas {main_probas}")
+        keep = main_probas < args.proba_not_understands
         c_quizzes = c_quizzes[keep]
 
         # If there are some quizzes that the main model cannot solve,
@@ -624,6 +624,7 @@ def create_c_quizzes(
                 )
 
                 probas = model_proba_solutions(model, solution)
+                log_string(f"probas {probas}")
                 keep = probas >= c_quizzes_proba
                 c_quizzes = solution[keep]
                 c_quizzes_proba[keep] = probas[keep]
@@ -652,19 +653,9 @@ def create_c_quizzes(
     # Save some images
 
     c_quizzes = torch.cat(recorded, dim=0)
-
-    l = [
-        model_proba_solutions(model, c_quizzes) for model in [main_model] + other_models
-    ]
-    probas = torch.cat([x[:, None] for x in l], dim=1)
-    comments = []
-    for l in probas:
-        comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l]))
-
-    filename = f"culture_c_quiz_{n_epoch:04d}_{model.id:02d}.png"
-    quiz_machine.problem.save_quizzes_as_image(
-        args.result_dir, filename, c_quizzes[:128], comments=comments
-    )
+    n = (c_quizzes.size(0) * nb_for_train) // (nb_for_train + nb_for_test)
+    train_c_quiz_bags.append(c_quizzes[:n])
+    test_c_quiz_bags.append(c_quizzes[n:])
 
     log_string(
         f"nb_c_quizzes model {model.id} train {sum([q.size(0) for q in train_c_quiz_bags ])} test {sum([q.size(0) for q in test_c_quiz_bags ])}"
@@ -750,8 +741,8 @@ if args.resume:
         state = torch.load(os.path.join(args.result_dir, filename))
         log_string(f"successfully loaded {filename}")
         current_epoch = state["current_epoch"]
-        train_c_quiz_bags = d["train_c_quiz_bags"]
-        test_c_quiz_bags = d["test_c_quiz_bags"]
+        train_c_quiz_bags = state["train_c_quiz_bags"]
+        test_c_quiz_bags = state["test_c_quiz_bags"]
     except FileNotFoundError:
         log_string(f"cannot find {filename}")
         pass
@@ -865,7 +856,7 @@ for n_epoch in range(current_epoch, args.nb_epochs):
     log_string(f"--- epoch {n_epoch} ----------------------------------------")
 
     cta = " ".join([f"{float(m.test_accuracy):.04f}" for m in models])
-    log_string(f"current_test_accuracies {cta}")
+    log_string(f"test_accuracies {cta}")
 
     ##################################################
 
@@ -880,6 +871,17 @@ for n_epoch in range(current_epoch, args.nb_epochs):
             test_c_quiz_bags=test_c_quiz_bags,
         )
 
+        c_quizzes = train_c_quiz_bags[-128:]
+        l = [model_proba_solutions(model, c_quizzes) for model in models]
+        probas = torch.cat([x[:, None] for x in l], dim=1)
+        comments = []
+        for l in probas:
+            comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l]))
+        filename = f"culture_c_quiz_{n_epoch:04d}.png"
+        quiz_machine.problem.save_quizzes_as_image(
+            args.result_dir, filename, c_quizzes, comments=comments
+        )
+
         for model in models:
             new_model = mygpt.MyGPT(
                 vocabulary_size=vocabulary_size,
@@ -897,13 +899,7 @@ for n_epoch in range(current_epoch, args.nb_epochs):
     ##################################################
     # Select, improve, and eval the worst model(s)
 
-    ranked_models = sorted(
-        models,
-        # This ugly recipe will pick the worst if there some below
-        # args.accuracy_to_make_c_quizzes or one at random if they
-        # are all above
-        key=lambda m: float(m.test_accuracy),
-    )
+    ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
 
     weakest_models = ranked_models[: len(gpus)]
 
@@ -925,8 +921,8 @@ for n_epoch in range(current_epoch, args.nb_epochs):
     for t in threads:
         t.join()
 
-    for model in weakest_models:
-        save_additional_results(n_epoch, model, models, c_quizzes_procedure)
+    for model in weakest_models:
+    # save_additional_results(n_epoch, model, models, c_quizzes_procedure)
 
     # Save the models to disk