Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 31 Aug 2024 21:53:33 +0000 (23:53 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 31 Aug 2024 21:53:33 +0000 (23:53 +0200)
main.py

diff --git a/main.py b/main.py
index 879d9fd..34bf920 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -964,11 +964,14 @@ def ae_batches(
     nb,
     data_structures,
     local_device,
+    c_quizzes=None,
     desc=None,
     batch_size=args.batch_size,
 ):
+    c_quiz_bags = [] if c_quizzes is None else [c_quizzes.to("cpu")]
+
     full_input, full_mask_generate, full_mask_loss = quiz_machine.data_input(
-        nb, data_structures=data_structures
+        nb, c_quiz_bags, data_structures=data_structures
     )
 
     src = zip(
@@ -1237,7 +1240,9 @@ def run_ae_test(model, quiz_machine, n_epoch, local_device=main_device):
 ######################################################################
 
 
-def one_ae_epoch(model, other_models, quiz_machine, n_epoch, local_device=main_device):
+def one_ae_epoch(
+    model, other_models, quiz_machine, n_epoch, c_quizzes, local_device=main_device
+):
     model.train().to(local_device)
 
     nb_train_samples, acc_train_loss = 0, 0.0
@@ -1247,6 +1252,7 @@ def one_ae_epoch(model, other_models, quiz_machine, n_epoch, local_device=main_d
         args.nb_train_samples,
         data_structures,
         local_device,
+        c_quizzes,
         "training",
     ):
         input = input.to(local_device)
@@ -1325,9 +1331,9 @@ def c_quiz_criterion_some(probas):
 def generate_ae_c_quizzes(models, local_device=main_device):
     criteria = [
         c_quiz_criterion_one_good_one_bad,
-        c_quiz_criterion_diff,
-        c_quiz_criterion_two_certains,
-        c_quiz_criterion_some,
+        c_quiz_criterion_diff,
+        c_quiz_criterion_two_certains,
+        c_quiz_criterion_some,
     ]
 
     for m in models:
@@ -1343,9 +1349,10 @@ def generate_ae_c_quizzes(models, local_device=main_device):
         quizzes=template, quad_order=quad_order, quad_mask=(1, 1, 1, 1)
     )
 
-    duration_max = 3600
+    duration_max = 4 * 3600
 
-    wanted_nb = 512
+    wanted_nb = 10000
+    nb_to_save = 128
 
     with torch.autograd.no_grad():
         records = [[] for _ in criteria]
@@ -1386,7 +1393,7 @@ def generate_ae_c_quizzes(models, local_device=main_device):
         )
 
         for n, u in enumerate(records):
-            quizzes = torch.cat(u, dim=0)[:wanted_nb]
+            quizzes = torch.cat(u, dim=0)[:nb_to_save]
             filename = f"culture_c_{n_epoch:04d}_{n:02d}.png"
 
             # result, predicted_parts, correct_parts = bag_to_tensors(record)
@@ -1405,11 +1412,14 @@ def generate_ae_c_quizzes(models, local_device=main_device):
                 # predicted_parts=predicted_parts,
                 # correct_parts=correct_parts,
                 comments=comments,
-                nrow=8,
             )
 
             log_string(f"wrote {filename}")
 
+    a = [torch.cat(u, dim=0) for u in records]
+
+    return torch.cat(a, dim=0).unique(dim=0)
+
 
 ######################################################################
 
@@ -1453,6 +1463,10 @@ log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
 
 ######################################################################
 
+last_n_epoch_c_quizzes = 0
+
+c_quizzes = None
+
 for n_epoch in range(current_epoch, args.nb_epochs):
     start_time = time.perf_counter()
 
@@ -1476,8 +1490,17 @@ for n_epoch in range(current_epoch, args.nb_epochs):
     # one_ae_epoch(models[0], models, quiz_machine, n_epoch, main_device)
     # exit(0)
 
-    if min([m.test_accuracy for m in models]) > args.accuracy_to_make_c_quizzes:
-        generate_ae_c_quizzes(models, local_device=main_device)
+    if (
+        min([m.test_accuracy for m in models]) > args.accuracy_to_make_c_quizzes
+        and n_epoch >= last_n_epoch_c_quizzes + 10
+    ):
+        last_n_epoch_c_quizzes = n_epoch
+        c_quizzes = generate_ae_c_quizzes(models, local_device=main_device)
+
+    if c_quizzes is None:
+        log_string("no_c_quiz")
+    else:
+        log_string(f"nb_c_quizzes {c_quizzes.size(0)}")
 
     ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
     weakest_models = ranked_models[: len(gpus)]
@@ -1492,7 +1515,7 @@ for n_epoch in range(current_epoch, args.nb_epochs):
         t = threading.Thread(
             target=one_ae_epoch,
             daemon=True,
-            args=(model, models, quiz_machine, n_epoch, gpu),
+            args=(model, models, quiz_machine, n_epoch, c_quizzes, gpu),
         )
 
         threads.append(t)