Update.
author    François Fleuret <francois@fleuret.org>
Mon, 19 Aug 2024 18:38:42 +0000 (20:38 +0200)
committer François Fleuret <francois@fleuret.org>
Mon, 19 Aug 2024 18:38:42 +0000 (20:38 +0200)
main.py

diff --git a/main.py b/main.py
index 1cbff39..cd1e10f 100755
--- a/main.py
+++ b/main.py
@@ -93,11 +93,7 @@ parser.add_argument("--gpus", type=str, default="all")
 
 # ----------------------------------
 
-parser.add_argument("--nb_gpts", type=int, default=5)
-
-parser.add_argument("--min_succeed_to_validate", type=int, default=2)
-
-parser.add_argument("--max_fail_to_validate", type=int, default=3)
+parser.add_argument("--nb_gpts", type=int, default=2)
 
 parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.95)
 
@@ -340,8 +336,9 @@ def run_tests(model, quiz_machine, local_device=main_device):
         nb_samples_accumulated = 0
 
         full_input, full_mask_loss = quiz_machine.data_input(
-            args.nb_test_samples, model.test_c_quiz_bags
+            args.nb_test_samples, test_c_quiz_bags
         )
+
         src = zip(
             full_input.split(args.batch_size), full_mask_loss.split(args.batch_size)
         )
@@ -368,7 +365,7 @@ def run_tests(model, quiz_machine, local_device=main_device):
 
         log_string(f"test_perplexity {n_epoch} model {model.id} {test_perplexity}")
 
-        input, _ = quiz_machine.data_input(2000, model.test_c_quiz_bags)
+        input, _ = quiz_machine.data_input(1000, test_c_quiz_bags)
 
         model.test_accuracy = quiz_machine.produce_results(
             n_epoch=n_epoch,
@@ -391,7 +388,7 @@ def one_epoch(model, quiz_machine, local_device=main_device):
     nb_train_samples, acc_train_loss = 0, 0.0
 
     full_input, full_mask_loss = quiz_machine.data_input(
-        args.nb_train_samples, model.train_c_quiz_bags
+        args.nb_train_samples, train_c_quiz_bags
     )
     src = zip(full_input.split(args.batch_size), full_mask_loss.split(args.batch_size))
 
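Both run_tests and one_epoch now draw their c-quizzes from bags shared by all models rather than from per-model attributes. A minimal sketch of the new bookkeeping, assuming each bag is a tensor whose first dimension indexes quizzes; the helper name is hypothetical and only the counting idiom comes from the log_string further down:

    # Shared, module-level pools of generated c-quizzes.
    train_c_quiz_bags = []   # each entry: a tensor of quizzes from one round
    test_c_quiz_bags = []

    def nb_c_quizzes(bags):
        # hypothetical helper; same counting idiom as the log_string below
        return sum(q.size(0) for q in bags)
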
@@ -557,7 +554,15 @@ def model_proba_solutions(model, quizzes):
     return l.exp()
 
 
-def create_c_quizzes(main_model, other_models, quiz_machine, nb_for_train, nb_for_test):
+def create_c_quizzes(
+    main_model,
+    other_models,
+    quiz_machine,
+    nb_for_train,
+    train_c_quiz_bags,
+    nb_for_test,
+    test_c_quiz_bags,
+):
     nb_validated, nb_to_validate = 0, (nb_for_train + nb_for_test) * len(models)
     nb_generated, nb_to_generate_per_iteration = 0, nb_to_validate
 
@@ -641,7 +646,7 @@ def create_c_quizzes(main_model, other_models, quiz_machine, nb_for_train, nb_fo
             e = "???"
 
         log_string(
-            f"keep c_quizzes model {model_for_generation.id} validated nb_validated {nb_validated} / {nb_to_validate} (finishes {e} -- {int((nb_validated * 3600)/duration)}/h) proportion_kept {nb_validated * 100 / nb_generated:.02f}%"
+            f"keep c_quizzes model {main_model.id} validated nb_validated {nb_validated} / {nb_to_validate} (finishes {e} -- {int((nb_validated * 3600)/duration)}/h) proportion_kept {nb_validated * 100 / nb_generated:.02f}%"
         )
 
     # Save some images
@@ -661,10 +666,9 @@ def create_c_quizzes(main_model, other_models, quiz_machine, nb_for_train, nb_fo
         args.result_dir, filename, c_quizzes[:128], comments=comments
     )
 
-
-log_string(
-    f"nb_c_quizzes model {model.id} train {sum([q.size(0) for q in model.train_c_quiz_bags ])} test {sum([q.size(0) for q in model.test_c_quiz_bags ])}"
-)
+    log_string(
+        f"nb_c_quizzes model {model.id} train {sum([q.size(0) for q in train_c_quiz_bags ])} test {sum([q.size(0) for q in test_c_quiz_bags ])}"
+    )
 
 
 ######################################################################
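create_c_quizzes now receives the shared bags and the train/test quotas explicitly instead of writing into per-model attributes. These hunks do not show how the validated quizzes are split between the two bags; a hedged sketch under the assumption that the first nb_for_train validated quizzes feed the train bag and the remainder the test bag:

    # Hypothetical bookkeeping; the real splitting logic is outside this diff.
    def store_c_quizzes(c_quizzes, nb_for_train, train_c_quiz_bags, test_c_quiz_bags):
        train_c_quiz_bags.append(c_quizzes[:nb_for_train])
        test_c_quiz_bags.append(c_quizzes[nb_for_train:])
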
@@ -709,8 +713,6 @@ for k in range(args.nb_gpts):
         )
 
     model.id = k
-    model.train_c_quiz_bags = []
-    model.test_c_quiz_bags = []
 
     if args.schedule_free:
         model.optimizer = schedulefree.AdamWScheduleFree(
@@ -724,6 +726,9 @@ for k in range(args.nb_gpts):
 
 ######################################################################
 
+train_c_quiz_bags = []
+test_c_quiz_bags = []
+
 current_epoch = 0
 
 if args.resume:
@@ -735,8 +740,6 @@ if args.resume:
             model.load_state_dict(d["state_dict"])
             model.optimizer.load_state_dict(d["optimizer_state_dict"])
             model.test_accuracy = d["test_accuracy"]
-            model.train_c_quiz_bags = d["train_c_quiz_bags"]
-            model.test_c_quiz_bags = d["test_c_quiz_bags"]
             log_string(f"successfully loaded {filename}")
         except FileNotFoundError:
             log_string(f"cannot find {filename}")
@@ -747,6 +750,8 @@ if args.resume:
         state = torch.load(os.path.join(args.result_dir, filename))
         log_string(f"successfully loaded {filename}")
         current_epoch = state["current_epoch"]
+        train_c_quiz_bags = state["train_c_quiz_bags"]
+        test_c_quiz_bags = state["test_c_quiz_bags"]
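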
     except FileNotFoundError:
         log_string(f"cannot find {filename}")
         pass
@@ -759,10 +764,10 @@ log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
 ######################################################################
 
 if args.nb_new_c_quizzes_for_train is None:
-    args.nb_new_c_quizzes_for_train = args.nb_train_samples // 250
+    args.nb_new_c_quizzes_for_train = args.nb_train_samples
 
 if args.nb_new_c_quizzes_for_test is None:
-    args.nb_new_c_quizzes_for_test = args.nb_test_samples // 250
+    args.nb_new_c_quizzes_for_test = args.nb_test_samples
 
 log_string(
     f"nb_new_c_quizzes_for_train {args.nb_new_c_quizzes_for_train} nb_new_c_quizzes_for_test {args.nb_new_c_quizzes_for_test}"
@@ -850,6 +855,8 @@ def save_generated_c_quizzes(model, filename, nb=64):
 for n_epoch in range(current_epoch, args.nb_epochs):
     state = {
         "current_epoch": n_epoch,
+        "train_c_quiz_bags": train_c_quiz_bags,
+        "test_c_quiz_bags": test_c_quiz_bags,
     }
     filename = "state.pth"
     torch.save(state, os.path.join(args.result_dir, filename))
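The shared bags now travel with state.pth rather than with each model checkpoint. A sketch of the resulting save/resume round-trip, using only the keys and filename visible in this diff:

    # per-epoch save of the shared state
    state = {
        "current_epoch": n_epoch,
        "train_c_quiz_bags": train_c_quiz_bags,
        "test_c_quiz_bags": test_c_quiz_bags,
    }
    torch.save(state, os.path.join(args.result_dir, "state.pth"))

    # resume
    state = torch.load(os.path.join(args.result_dir, "state.pth"))
    current_epoch = state["current_epoch"]
    train_c_quiz_bags = state["train_c_quiz_bags"]
    test_c_quiz_bags = state["test_c_quiz_bags"]
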
@@ -863,11 +870,14 @@ for n_epoch in range(current_epoch, args.nb_epochs):
     ##################################################
 
     if min([m.test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes:
-        record_new_c_quizzes(
-            models,
-            quiz_machine,
-            args.nb_new_c_quizzes_for_train,
-            args.nb_new_c_quizzes_for_test,
+        create_c_quizzes(
+            main_model=models[0],
+            other_models=models[1:],
+            quiz_machine=quiz_machine,
+            nb_for_train=args.nb_new_c_quizzes_for_train,
+            train_c_quiz_bags=train_c_quiz_bags,
+            nb_for_test=args.nb_new_c_quizzes_for_test,
+            test_c_quiz_bags=test_c_quiz_bags,
         )
 
         for model in models:
@@ -883,8 +893,6 @@ for n_epoch in range(current_epoch, args.nb_epochs):
             ).to(main_device)
             model.load_state_dict(new_model.state_dict())
             model.test_accuracy = 0.0
-            model.best_test_accuracy = 0.0
-            model.best_dict = copy.deepcopy(model.state_dict())
 
     ##################################################
     # Select, improve, and eval the worst model(s)
@@ -894,11 +902,7 @@ for n_epoch in range(current_epoch, args.nb_epochs):
         # This ugly recipe will pick the worst if there are some below
         # args.accuracy_to_make_c_quizzes or one at random if they
         # are all above
-        key=lambda m: float(
-            m.test_accuracy
-            if m.test_accuracy < args.accuracy_to_make_c_quizzes
-            else args.accuracy_to_make_c_quizzes + torch.rand(1).item()
-        ),
+        key=lambda m: float(m.test_accuracy),
     )
 
     weakest_models = ranked_models[: len(gpus)]
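The selection key is reduced to plain test accuracy, so the randomized tie-breaking above the c-quiz threshold disappears. The sorted() call itself sits outside the hunk; presumably it is along the lines of:

    # Sketch of the simplified selection (the sorted() call is assumed).
    ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
    weakest_models = ranked_models[: len(gpus)]   # as many models as GPUs
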
@@ -921,8 +925,6 @@ for n_epoch in range(current_epoch, args.nb_epochs):
     for t in threads:
         t.join()
 
-    total_time_training_models += time.perf_counter() - start_time
-
     for model in weakest_models:
         save_additional_results(n_epoch, model, models, c_quizzes_procedure)
 
@@ -935,10 +937,6 @@ for n_epoch in range(current_epoch, args.nb_epochs):
                 "state_dict": model.state_dict(),
                 "optimizer_state_dict": model.optimizer.state_dict(),
                 "test_accuracy": model.test_accuracy,
-                "best_test_accuracy": model.best_test_accuracy,
-                "best_dict": model.best_dict,
-                "train_c_quiz_bags": model.train_c_quiz_bags,
-                "test_c_quiz_bags": model.test_c_quiz_bags,
             },
             os.path.join(args.result_dir, filename),
         )