Update.
author François Fleuret <francois@fleuret.org>
Fri, 16 Aug 2024 17:52:32 +0000 (19:52 +0200)
committer François Fleuret <francois@fleuret.org>
Fri, 16 Aug 2024 17:52:32 +0000 (19:52 +0200)
main.py

diff --git a/main.py b/main.py
index 0ca8153..f1fb834 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -919,6 +919,8 @@ for k in range(args.nb_gpts):
 ######################################################################
 
 current_epoch = 0
+total_time_generating_c_quizzes = 0
+total_time_training_models = 0
 
 if args.resume:
     for model in models:
@@ -943,6 +945,8 @@ if args.resume:
         state = torch.load(os.path.join(args.result_dir, filename))
         log_string(f"successfully loaded {filename}")
         current_epoch = state["current_epoch"]
+        total_time_generating_c_quizzes = state["total_time_generating_c_quizzes"]
+        total_time_training_models = state["total_time_training_models"]
     except FileNotFoundError:
         log_string(f"cannot find {filename}")
         pass
@@ -973,7 +977,6 @@ if args.dirty_debug:
     args.nb_new_c_quizzes_for_test = 10
 
 ######################################################################
-######################################################################
 
 
 class Folder(nn.Module):
@@ -1126,7 +1129,11 @@ if args.test == "reject":
 ######################################################################
 
 for n_epoch in range(current_epoch, args.nb_epochs):
-    state = {"current_epoch": n_epoch}
+    state = {
+        "current_epoch": n_epoch,
+        "total_time_training_models": total_time_training_models,
+        "total_time_generating_c_quizzes": total_time_generating_c_quizzes,
+    }
     filename = "state.pth"
     torch.save(state, os.path.join(args.result_dir, filename))
     log_string(f"wrote {filename}")
@@ -1148,17 +1155,26 @@ for n_epoch in range(current_epoch, args.nb_epochs):
             model.best_dict = copy.deepcopy(model.state_dict())
             model.best_test_accuracy = model.test_accuracy
 
-    if min([m.best_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes:
+    # we restart
+    if total_time_generating_c_quizzes == 0:
+        total_time_training_models = 0
+
+    if (
+        min([m.best_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes
+        and total_time_training_models > total_time_generating_c_quizzes
+    ):
         for model in models:
             model.current_dict = copy.deepcopy(model.state_dict())
             model.load_state_dict(model.best_dict)
 
+        start_time = time.perf_counter()
         record_new_c_quizzes(
             models,
             quiz_machine,
             args.nb_new_c_quizzes_for_train,
             args.nb_new_c_quizzes_for_test,
         )
+        total_time_generating_c_quizzes += time.perf_counter() - start_time
 
         # Force one epoch of training
         for model in models:
@@ -1167,29 +1183,33 @@ for n_epoch in range(current_epoch, args.nb_epochs):
     ##################################################
     # Select, improve, and eval the worst model(s)
 
-    ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
+    if total_time_training_models <= total_time_generating_c_quizzes:
+        ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
 
-    weakest_models = ranked_models[: len(gpus)]
+        weakest_models = ranked_models[: len(gpus)]
 
-    threads = []
+        threads = []
 
-    for gpu, model in zip(gpus, weakest_models):
-        log_string(f"training model {model.id}")
+        start_time = time.perf_counter()
+        for gpu, model in zip(gpus, weakest_models):
+            log_string(f"training model {model.id}")
 
-        t = threading.Thread(
-            target=one_epoch, daemon=True, args=(model, quiz_machine, gpu)
-        )
+            t = threading.Thread(
+                target=one_epoch, daemon=True, args=(model, quiz_machine, gpu)
+            )
+
+            threads.append(t)
 
-        threads.append(t)
+            t.start()
 
-        t.start()
+        for t in threads:
+            t.join()
 
-    for t in threads:
-        t.join()
+        total_time_training_models += time.perf_counter() - start_time
 
     # Save the models to disk
 
-    for model in weakest_models:
+    for model in models:
         filename = f"gpt_{model.id:03d}.pth"
         torch.save(
             {