Update.
author: François Fleuret <francois@fleuret.org>
Wed, 21 Aug 2024 20:29:53 +0000 (22:29 +0200)
committer: François Fleuret <francois@fleuret.org>
Wed, 21 Aug 2024 20:29:53 +0000 (22:29 +0200)
main.py

diff --git a/main.py b/main.py
index 410ce7d..66b7ffd 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -1089,55 +1089,66 @@ for n_epoch in range(current_epoch, args.nb_epochs):
     if total_time_generating_c_quizzes == 0:
         total_time_training_models = 0
 
-    if (
-        min([m.gen_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes
-        and total_time_training_models >= total_time_generating_c_quizzes
-    ):
-        ######################################################################
-        # Re-initalize if there are enough culture quizzes
-
+    if min([m.gen_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes:
         if args.reboot:
-            nb_c_quizzes_per_model = [
-                sum([x.size(0) for x in model.train_c_quiz_bags]) for model in models
-            ]
+            for model in models:
+                model.current_dict = copy.deepcopy(model.state_dict())
+                model.load_state_dict(model.gen_state_dict)
+
+            while True:
+                record_new_c_quizzes(
+                    models,
+                    quiz_machine,
+                    args.nb_new_c_quizzes_for_train,
+                    args.nb_new_c_quizzes_for_test,
+                )
 
-            p = tuple(
-                f"{(x*100)/args.nb_train_samples:.02f}%" for x in nb_c_quizzes_per_model
-            )
+                nb_c_quizzes_per_model = [
+                    sum([x.size(0) for x in model.train_c_quiz_bags])
+                    for model in models
+                ]
 
-            log_string(f"nb_c_quizzes_per_model {p}")
+                p = tuple(
+                    f"{(x*100)/args.nb_train_samples:.02f}%"
+                    for x in nb_c_quizzes_per_model
+                )
 
-            m = max(nb_c_quizzes_per_model)
+                log_string(f"nb_c_quizzes_per_model {p}")
 
-            if m >= args.nb_train_samples:
-                model = models[nb_c_quizzes_per_model.index(m)]
-                common_c_quiz_bags.append(torch.cat(model.train_c_quiz_bags, dim=0))
-                nb_common_c_quizzes = sum([x.size(0) for x in common_c_quiz_bags])
-                log_string(
-                    f"rebooting the models with {nb_common_c_quizzes} culture quizzes"
-                )
+                m = max(nb_c_quizzes_per_model)
 
-                models = create_models()
-                total_time_generating_c_quizzes = 0
-                total_time_training_models = 0
+                if m >= args.nb_train_samples:
+                    break
 
-        for model in models:
-            model.current_dict = copy.deepcopy(model.state_dict())
-            model.load_state_dict(model.gen_state_dict)
+            model = models[nb_c_quizzes_per_model.index(m)]
+            common_c_quiz_bags.append(torch.cat(model.train_c_quiz_bags, dim=0))
+            nb_common_c_quizzes = sum([x.size(0) for x in common_c_quiz_bags])
+            log_string(
+                f"rebooting the models with {nb_common_c_quizzes} culture quizzes"
+            )
 
-        start_time = time.perf_counter()
+            models = create_models()
+            total_time_generating_c_quizzes = 0
+            total_time_training_models = 0
 
-        record_new_c_quizzes(
-            models,
-            quiz_machine,
-            args.nb_new_c_quizzes_for_train,
-            args.nb_new_c_quizzes_for_test,
-        )
+        elif total_time_training_models >= total_time_generating_c_quizzes:
+            for model in models:
+                model.current_dict = copy.deepcopy(model.state_dict())
+                model.load_state_dict(model.gen_state_dict)
 
-        total_time_generating_c_quizzes += time.perf_counter() - start_time
+            start_time = time.perf_counter()
 
-        for model in models:
-            model.load_state_dict(model.current_dict)
+            record_new_c_quizzes(
+                models,
+                quiz_machine,
+                args.nb_new_c_quizzes_for_train,
+                args.nb_new_c_quizzes_for_test,
+            )
+
+            total_time_generating_c_quizzes += time.perf_counter() - start_time
+
+            for model in models:
+                model.load_state_dict(model.current_dict)
 
     ##################################################
     # Select, improve, and eval the worst model(s)