Update: save per-model and run-state checkpoints each epoch, and reload them when args.resume is set.
author    François Fleuret <francois@fleuret.org>
          Tue, 27 Aug 2024 14:59:16 +0000 (16:59 +0200)
committer François Fleuret <francois@fleuret.org>
          Tue, 27 Aug 2024 14:59:16 +0000 (16:59 +0200)
main.py

diff --git a/main.py b/main.py
index 2e8ec43..0ecd492 100755
--- a/main.py
+++ b/main.py
@@ -1133,7 +1133,66 @@ for i in range(args.nb_models):
 
     models.append(model)
 
+######################################################################
+
+current_epoch = 0
+
+if args.resume:
+    for model in models:
+        filename = f"ae_{model.id:03d}.pth"
+
+        try:
+            d = torch.load(os.path.join(args.result_dir, filename))
+            model.load_state_dict(d["state_dict"])
+            model.optimizer.load_state_dict(d["optimizer_state_dict"])
+            model.test_accuracy = d["test_accuracy"]
+            # model.gen_test_accuracy = d["gen_test_accuracy"]
+            # model.gen_state_dict = d["gen_state_dict"]
+            # model.train_c_quiz_bags = d["train_c_quiz_bags"]
+            # model.test_c_quiz_bags = d["test_c_quiz_bags"]
+            log_string(f"successfully loaded {filename}")
+        except FileNotFoundError:
+            log_string(f"cannot find {filename}")
+            pass
+
+    try:
+        filename = "state.pth"
+        state = torch.load(os.path.join(args.result_dir, filename))
+        log_string(f"successfully loaded {filename}")
+        current_epoch = state["current_epoch"]
+        # total_time_generating_c_quizzes = state["total_time_generating_c_quizzes"]
+        # total_time_training_models = state["total_time_training_models"]
+        # common_c_quiz_bags = state["common_c_quiz_bags"]
+    except FileNotFoundError:
+        log_string(f"cannot find {filename}")
+        pass
+
+######################################################################
+
+nb_parameters = sum(p.numel() for p in models[0].parameters())
+log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
+
+
+######################################################################
+
 for n_epoch in range(args.nb_epochs):
+    state = {
+        "current_epoch": n_epoch,
+        # "total_time_generating_c_quizzes": total_time_generating_c_quizzes,
+        # "total_time_training_models": total_time_training_models,
+        # "common_c_quiz_bags": common_c_quiz_bags,
+    }
+    filename = "state.pth"
+    torch.save(state, os.path.join(args.result_dir, filename))
+    log_string(f"wrote {filename}")
+
+    log_string(f"--- epoch {n_epoch} ----------------------------------------")
+
+    cta = " ".join([f"{float(m.test_accuracy):.04f}" for m in models])
+    log_string(f"current_test_accuracies {cta}")
+
+    # --------------------------------------------------------------------
+
     ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
     weakest_models = ranked_models[: len(gpus)]
 
@@ -1155,6 +1214,24 @@ for n_epoch in range(args.nb_epochs):
         for t in threads:
             t.join()
 
+    # --------------------------------------------------------------------
+
+    for model in models:
+        filename = f"ae_{model.id:03d}.pth"
+        torch.save(
+            {
+                "state_dict": model.state_dict(),
+                "optimizer_state_dict": model.optimizer.state_dict(),
+                "test_accuracy": model.test_accuracy,
+                # "gen_test_accuracy": model.gen_test_accuracy,
+                # "gen_state_dict": model.gen_state_dict,
+                # "train_c_quiz_bags": model.train_c_quiz_bags,
+                # "test_c_quiz_bags": model.test_c_quiz_bags,
+            },
+            os.path.join(args.result_dir, filename),
+        )
+        log_string(f"wrote {filename}")
+
 ######################################################################
 
 
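For reference, below is a minimal, self-contained sketch of the save/resume round trip introduced in the hunks above: model weights, optimizer state, and the test-accuracy scalar are bundled into a single torch.save dict and restored with load_state_dict when resuming, with a missing checkpoint handled gracefully. The toy model, optimizer choice, and paths are illustrative assumptions, not taken from main.py.

# Hypothetical sketch of the checkpoint round trip; the model, optimizer
# and paths below are placeholders, not the ones used in main.py.
import os

import torch
from torch import nn

result_dir = "results"  # assumed stand-in for args.result_dir
os.makedirs(result_dir, exist_ok=True)

model = nn.Linear(16, 16)  # placeholder for the real model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
test_accuracy = 0.0

filename = "ae_000.pth"

# Save: bundle weights, optimizer state, and metrics in one dict.
torch.save(
    {
        "state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "test_accuracy": test_accuracy,
    },
    os.path.join(result_dir, filename),
)

# Resume: restore everything, tolerating a missing checkpoint file.
try:
    d = torch.load(os.path.join(result_dir, filename))
    model.load_state_dict(d["state_dict"])
    optimizer.load_state_dict(d["optimizer_state_dict"])
    test_accuracy = d["test_accuracy"]
    print(f"successfully loaded {filename}")
except FileNotFoundError:
    print(f"cannot find {filename}")
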
@@ -1213,14 +1290,14 @@ models = create_models()
 
 ######################################################################
 
-current_epoch = 0
-
 # We balance the computing time between training the models and
 # generating c_quizzes
 
 total_time_generating_c_quizzes = 0
 total_time_training_models = 0
 
+current_epoch = 0
+
 if args.resume:
     for model in models:
         filename = f"gpt_{model.id:03d}.pth"