Update.
authorFrançois Fleuret <francois@fleuret.org>
Thu, 15 Aug 2024 14:58:40 +0000 (16:58 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Thu, 15 Aug 2024 14:58:40 +0000 (16:58 +0200)
main.py

diff --git a/main.py b/main.py
index 27404e8..e2e9a59 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -364,7 +364,7 @@ def run_tests(model, quiz_machine, local_device=main_device):
 
         log_string(f"test_perplexity {n_epoch} model {model.id} {test_perplexity}")
 
-        model.main_test_accuracy = quiz_machine.produce_results(
+        model.test_accuracy = quiz_machine.produce_results(
             n_epoch=n_epoch,
             model=model,
             input=full_input[:2000],
@@ -890,7 +890,7 @@ for k in range(args.nb_gpts):
     else:
         model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
 
-    model.main_test_accuracy = 0.0
+    model.test_accuracy = 0.0
 
     models.append(model)
 
@@ -906,7 +906,9 @@ if args.resume:
             d = torch.load(os.path.join(args.result_dir, filename))
             model.load_state_dict(d["state_dict"])
             model.optimizer.load_state_dict(d["optimizer_state_dict"])
-            model.main_test_accuracy = d["main_test_accuracy"]
+            model.test_accuracy = d["test_accuracy"]
+            model.best_test_accuracy = d["best_test_accuracy"]
+            model.best_dict = d["best_dict"]
             model.train_c_quiz_bags = d["train_c_quiz_bags"]
             model.test_c_quiz_bags = d["test_c_quiz_bags"]
             log_string(f"successfully loaded {filename}")
@@ -1109,14 +1111,23 @@ for n_epoch in range(current_epoch, args.nb_epochs):
 
     log_string(f"--- epoch {n_epoch} ----------------------------------------")
 
-    cta = " ".join([f"{float(m.main_test_accuracy):.04f}" for m in models])
+    cta = " ".join([f"{float(m.test_accuracy):.04f}" for m in models])
     log_string(f"current_test_accuracies {cta}")
 
     ##################################################
     # If all the models are good enough, generate new quizzes and
     # re-compute the test errors
 
-    if min([m.main_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes:
+    for model in models:
+        if model.test_accuracy >= model.best_test_accuracy:
+            model.best_dict = copy.deepcopy(model.state_dict())
+            model.best_test_accuracy = model.test_accuracy
+
+    if min([m.best_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes:
+        for model in models:
+            model.current_dict = copy.deepcopy(model.state_dict())
+            model.load_state_dict(model.best_dict)
+
         record_new_c_quizzes(
             models,
             quiz_machine,
@@ -1126,12 +1137,12 @@ for n_epoch in range(current_epoch, args.nb_epochs):
 
         # Force one epoch of training
         for model in models:
-            model.main_test_accuracy = 0.0
+            model.load_state_dict(model.current_dict)
 
     ##################################################
     # Select, improve, and eval the worst model(s)
 
-    ranked_models = sorted(models, key=lambda m: float(m.main_test_accuracy))
+    ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
 
     weakest_models = ranked_models[: len(gpus)]
 
@@ -1159,7 +1170,9 @@ for n_epoch in range(current_epoch, args.nb_epochs):
             {
                 "state_dict": model.state_dict(),
                 "optimizer_state_dict": model.optimizer.state_dict(),
-                "main_test_accuracy": model.main_test_accuracy,
+                "test_accuracy": model.test_accuracy,
+                "best_test_accuracy": model.best_test_accuracy,
+                "best_dict": model.best_dict,
                 "train_c_quiz_bags": model.train_c_quiz_bags,
                 "test_c_quiz_bags": model.test_c_quiz_bags,
             },