Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 31 Aug 2024 16:53:34 +0000 (18:53 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 31 Aug 2024 16:53:34 +0000 (18:53 +0200)
main.py

diff --git a/main.py b/main.py
index c11f5c2..ab625cc 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -313,7 +313,7 @@ log_string(f"vocabulary_size {vocabulary_size}")
 
 
 def bag_len(bag):
-    return sum([x[0].size(0) for x in bag])
+    return sum([x.size(0) for x in bag])
 
 
 def bag_to_tensors(bag):
@@ -1033,8 +1033,6 @@ def ae_generate(model, input, mask_generate, noise_proba, nb_iterations_max=50):
     changed = True
 
     for it in range(nb_iterations_max):
-        print(f"{it=} {nb_iterations_max=}")
-
         input_with_mask = NTC_channel_cat(input, mask_generate)
         logits = model(input_with_mask)
         dist = torch.distributions.categorical.Categorical(logits=logits)
@@ -1260,10 +1258,6 @@ def one_ae_epoch(model, other_models, quiz_machine, n_epoch, local_device=main_d
 
         targets, logits = targets_and_prediction(model, input, mask_generate)
 
-        print(
-            f"{input.device=} {logits.device=} {targets.device=} {logits.device=} {mask_loss.device=}"
-        )
-
         loss = NTC_masked_cross_entropy(logits, targets, mask_loss)
         acc_train_loss += loss.item() * input.size(0)
         nb_train_samples += input.size(0)
@@ -1342,46 +1336,56 @@ def generate_ae_c_quizzes(models, local_device=main_device):
         quizzes=template, quad_order=quad_order, quad_mask=(1, 1, 1, 1)
     )
 
-    records = [[] for _ in criteria]
+    duration_max = 600  # 3 * 3600
 
     with torch.autograd.no_grad():
-        while min([bag_len(bag) for bag in records]) < 128:
+        records = [[] for _ in criteria]
+
+        start_time = time.perf_counter()
+
+        while (
+            time.perf_counter() < start_time + duration_max
+            and min([bag_len(bag) for bag in records]) < 128
+        ):
             bl = [bag_len(bag) for bag in records]
             log_string(f"bag_len {bl}")
 
             model = models[torch.randint(len(models), (1,)).item()]
-            result = ae_generate(model, template, mask_generate, 0.0)
+            result = ae_generate(model, template, mask_generate, noise_proba)
 
             probas = torch.cat(
                 [model_ae_proba_solutions(model, result)[:, None] for model in models],
                 dim=1,
             )
+
             for c, r in zip(criteria, records):
                 q = result[c(probas)]
                 if q.size(0) > 0:
                     r.append(q)
 
-    # for f, record in [("prediction", record_d), ("generation", record_nd)]:
-    # filename = f"culture_{f}_{n_epoch:04d}_{model.id:02d}.png"
+    for n, u in enumerate(records):
+        quizzes = torch.cat(u, dim=0)[:128]
+        filename = f"culture_{n_epoch:04d}_{n:02d}.png"
 
-    # result, predicted_parts, correct_parts = bag_to_tensors(record)
+        # result, predicted_parts, correct_parts = bag_to_tensors(record)
 
-    # l = [model_ae_proba_solutions(model, result) for model in models]
-    # probas = torch.cat([x[:, None] for x in l], dim=1)
-    # comments = []
+        # l = [model_ae_proba_solutions(model, result) for model in models]
+        # probas = torch.cat([x[:, None] for x in l], dim=1)
+        # comments = []
+
+        # for l in probas:
+        # comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l]))
 
-    # for l in probas:
-    # comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l]))
+        quiz_machine.problem.save_quizzes_as_image(
+            args.result_dir,
+            filename,
+            quizzes=result,
+            # predicted_parts=predicted_parts,
+            # correct_parts=correct_parts,
+            # comments=comments,
+        )
 
-    # quiz_machine.problem.save_quizzes_as_image(
-    # args.result_dir,
-    # filename,
-    # quizzes=result,
-    # predicted_parts=predicted_parts,
-    # correct_parts=correct_parts,
-    # comments=comments,
-    # )
-    # log_string(f"wrote {filename}")
+        log_string(f"wrote {filename}")
 
 
 ######################################################################
@@ -1449,6 +1453,9 @@ for n_epoch in range(current_epoch, args.nb_epochs):
     # one_ae_epoch(models[0], models, quiz_machine, n_epoch, main_device)
     # exit(0)
 
+    if min([m.test_accuracy for m in models]) > args.accuracy_to_make_c_quizzes:
+        generate_ae_c_quizzes(models, local_device=main_device)
+
     ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
     weakest_models = ranked_models[: len(gpus)]
 
@@ -1472,8 +1479,6 @@ for n_epoch in range(current_epoch, args.nb_epochs):
     for t in threads:
         t.join()
 
-    generate_ae_c_quizzes(models, local_device=main_device)
-
     # --------------------------------------------------------------------
 
     for model in models: