Update.
authorFrançois Fleuret <francois@fleuret.org>
Fri, 16 Aug 2024 17:53:55 +0000 (19:53 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Fri, 16 Aug 2024 17:53:55 +0000 (19:53 +0200)
main.py

diff --git a/main.py b/main.py
index f1fb834..dc3fd1a 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -945,8 +945,8 @@ if args.resume:
         state = torch.load(os.path.join(args.result_dir, filename))
         log_string(f"successfully loaded {filename}")
         current_epoch = state["current_epoch"]
-        total_time_generating_c_quizzes = state["total_time_generating_c_quizzes"]
-        total_time_training_models = state["total_time_training_models"]
+        total_time_generating_c_quizzes = state["total_time_generating_c_quizzes"]
+        total_time_training_models = state["total_time_training_models"]
     except FileNotFoundError:
         log_string(f"cannot find {filename}")
         pass
@@ -1161,7 +1161,7 @@ for n_epoch in range(current_epoch, args.nb_epochs):
 
     if (
         min([m.best_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes
-        and total_time_training_models > total_time_generating_c_quizzes
+        and total_time_training_models >= total_time_generating_c_quizzes
     ):
         for model in models:
             model.current_dict = copy.deepcopy(model.state_dict())
@@ -1183,7 +1183,7 @@ for n_epoch in range(current_epoch, args.nb_epochs):
     ##################################################
     # Select, improve, and eval the worst model(s)
 
-    if total_time_training_models <= total_time_generating_c_quizzes:
+    if total_time_training_models < total_time_generating_c_quizzes:
         ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
 
         weakest_models = ranked_models[: len(gpus)]
@@ -1191,6 +1191,7 @@ for n_epoch in range(current_epoch, args.nb_epochs):
         threads = []
 
         start_time = time.perf_counter()
+
         for gpu, model in zip(gpus, weakest_models):
             log_string(f"training model {model.id}")