Update.
authorFrançois Fleuret <francois@fleuret.org>
Tue, 3 Sep 2024 13:06:37 +0000 (15:06 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 3 Sep 2024 13:06:37 +0000 (15:06 +0200)
main.py

diff --git a/main.py b/main.py
index fb0b4df..2376868 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -53,7 +53,7 @@ parser.add_argument("--physical_batch_size", type=int, default=None)
 
 parser.add_argument("--inference_batch_size", type=int, default=25)
 
-parser.add_argument("--nb_train_samples", type=int, default=40000)
+parser.add_argument("--nb_train_samples", type=int, default=25000)
 
 parser.add_argument("--nb_test_samples", type=int, default=1000)
 
@@ -61,7 +61,7 @@ parser.add_argument("--nb_new_c_quizzes_for_train", type=int, default=None)
 
 parser.add_argument("--nb_new_c_quizzes_for_test", type=int, default=None)
 
-parser.add_argument("--c_quiz_multiplier", type=int, default=4)
+parser.add_argument("--c_quiz_multiplier", type=int, default=10)
 
 parser.add_argument("--learning_rate", type=float, default=5e-4)
 
@@ -1079,7 +1079,9 @@ def model_ae_proba_solutions(model, input, log_proba=False):
             mask_generate = quiz_machine.make_quiz_mask(
                 quizzes=q, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
             )
-            targets, logits = targets_and_prediction(model, q, mask_generate)
+            targets, logits = targets_and_prediction(
+                model, q, mask_generate, prompt_noise=args.prompt_noise
+            )
             loss_per_token = F.cross_entropy(
                 logits.transpose(1, 2), targets, reduction="none"
             )
@@ -1400,7 +1402,7 @@ def save_badness_statistics(
     log_string(f"wrote {filename}")
 
 
-def generate_ae_c_quizzes(models, local_device=main_device):
+def generate_ae_c_quizzes(models, nb, local_device=main_device):
     criteria = [
         # c_quiz_criterion_only_one,
         c_quiz_criterion_one_good_one_bad,
@@ -1411,8 +1413,8 @@ def generate_ae_c_quizzes(models, local_device=main_device):
         # c_quiz_criterion_some,
     ]
 
-    for m in models:
-        m.eval().to(local_device)
+    # To be thread-safe we must make copies
+    models = [copy.deepcopy(model).to(local_device) for model in models]
 
     quad_order = ("A", "f_A", "B", "f_B")
 
@@ -1426,7 +1428,7 @@ def generate_ae_c_quizzes(models, local_device=main_device):
 
     duration_max = 4 * 3600
 
-    wanted_nb = args.nb_train_samples // args.c_quiz_multiplier
+    wanted_nb = nb
     nb_to_save = 256
 
     with torch.autograd.no_grad():
@@ -1515,6 +1517,10 @@ def generate_ae_c_quizzes(models, local_device=main_device):
     return torch.cat(a, dim=0).unique(dim=0)
 
 
+def thread_generate_ae_c_quizzes(models, nb, record, local_device=main_device):
+    record.append(generate_ae_c_quizzes(models, nb, local_device))
+
+
 ######################################################################
 
 current_epoch = 0
@@ -1600,9 +1606,37 @@ for n_epoch in range(current_epoch, args.nb_epochs):
             save_badness_statistics(last_n_epoch_c_quizzes, models, c_quizzes, "after")
 
         last_n_epoch_c_quizzes = n_epoch
+        nb_c_quizzes_to_generate = args.nb_train_samples // args.c_quiz_multiplier
+
+        # --------------------------------------------------------------------
+
+        records, threads = [], []
+
         start_time = time.perf_counter()
-        c_quizzes = generate_ae_c_quizzes(models, local_device=main_device)
+
+        for gpu in gpus:
+            t = threading.Thread(
+                target=thread_generate_ae_c_quizzes,
+                daemon=True,
+                args=(models, nb_c_quizzes_to_generate, records, gpu),
+            )
+
+            # To get a different sequence between threads
+            log_string(f"dummy {torch.rand(1)}")
+            threads.append(t)
+            t.start()
+
+        for t in threads:
+            t.join()
+
         time_c_quizzes = time.perf_counter() - start_time
+
+        c_quizzes = torch.cat([q.to(main_device) for q in records], dim=0)
+
+        # --------------------------------------------------------------------
+
+        log_string(f"generated_c_quizzes {c_quizzes.size()=}")
+
         time_train = 0
         for model in models:
             model.test_accuracy = 0
@@ -1614,6 +1648,8 @@ for n_epoch in range(current_epoch, args.nb_epochs):
     else:
         log_string(f"nb_c_quizzes {c_quizzes.size(0)}")
 
+    # --------------------------------------------------------------------
+
     ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
     weakest_models = ranked_models[: len(gpus)]