Update.
authorFrançois Fleuret <francois@fleuret.org>
Fri, 16 Aug 2024 18:22:00 +0000 (20:22 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Fri, 16 Aug 2024 18:22:00 +0000 (20:22 +0200)
main.py

diff --git a/main.py b/main.py
index e203a71..133f536 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -543,7 +543,7 @@ def model_proba_solutions(model, quizzes):
 
 def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test):
     nb_validated, nb_to_validate = 0, (nb_for_train + nb_for_test) * len(models)
-    nb_generated, nb_to_generate_per_iteration = 0, nb_to_validate // 10
+    nb_generated, nb_to_generate_per_iteration = 0, nb_to_validate
 
     start_time = time.perf_counter()
 
@@ -919,6 +919,10 @@ for k in range(args.nb_gpts):
 ######################################################################
 
 current_epoch = 0
+
+# We balance the computing time between training the models and
+# generating c_quizzes
+
 total_time_generating_c_quizzes = 0
 total_time_training_models = 0
 
@@ -945,8 +949,8 @@ if args.resume:
         state = torch.load(os.path.join(args.result_dir, filename))
         log_string(f"successfully loaded {filename}")
         current_epoch = state["current_epoch"]
-        total_time_generating_c_quizzes = state["total_time_generating_c_quizzes"]
-        total_time_training_models = state["total_time_training_models"]
+        total_time_generating_c_quizzes = state["total_time_generating_c_quizzes"]
+        total_time_training_models = state["total_time_training_models"]
     except FileNotFoundError:
         log_string(f"cannot find {filename}")
         pass