Update.
author    François Fleuret <francois@fleuret.org>
          Sat, 31 Aug 2024 16:37:02 +0000 (18:37 +0200)
committer François Fleuret <francois@fleuret.org>
          Sat, 31 Aug 2024 16:37:02 +0000 (18:37 +0200)
main.py

diff --git a/main.py b/main.py
index 8726c96..c11f5c2 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -95,6 +95,8 @@ parser.add_argument("--gpus", type=str, default="all")
 
 parser.add_argument("--nb_models", type=int, default=5)
 
+parser.add_argument("--nb_diffusion_iterations", type=int, default=25)
+
 parser.add_argument("--min_succeed_to_validate", type=int, default=2)
 
 parser.add_argument("--max_fail_to_validate", type=int, default=3)
@@ -309,6 +311,17 @@ log_string(f"vocabulary_size {vocabulary_size}")
 
 ######################################################################
 
+
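+# Total number of samples in a bag: sum of the first tensor's size(0) over all stored tuples.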
+def bag_len(bag):
+    return sum([x[0].size(0) for x in bag])
+
+
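+# Concatenate a bag of tuples of tensors field-wise into a single tuple of full tensors.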
+def bag_to_tensors(bag):
+    return tuple(torch.cat([x[i] for x in bag], dim=0) for i in range(len(bag[0])))
+
+
+######################################################################
+
 # If we need to move an optimizer to a different device
 
 
@@ -945,10 +958,6 @@ class FunctionalAE(nn.Module):
 
 ######################################################################
 
-nb_iterations = 25
-probs_iterations = 0.1 ** torch.linspace(0, 1, nb_iterations, device=main_device)
-probs_iterations = probs_iterations[None, :] / probs_iterations.sum()
-
 
 def ae_batches(
     quiz_machine,
@@ -1024,6 +1033,8 @@ def ae_generate(model, input, mask_generate, noise_proba, nb_iterations_max=50):
     changed = True
 
     for it in range(nb_iterations_max):
+        print(f"{it=} {nb_iterations_max=}")
+
         input_with_mask = NTC_channel_cat(input, mask_generate)
         logits = model(input_with_mask)
         dist = torch.distributions.categorical.Categorical(logits=logits)
@@ -1044,7 +1055,8 @@ def ae_generate(model, input, mask_generate, noise_proba, nb_iterations_max=50):
             changed = changed & (update != input).max(dim=1).values
             input[changed] = update[changed]
 
-    log_string(f"remains {changed.long().sum()}")
+    # `it` peaks at nb_iterations_max - 1, so compare against that to
+    # detect an exhausted iteration budget.
+    if it == nb_iterations_max - 1:
+        log_string(f"remains {changed.long().sum()}")
 
     return input
 
@@ -1062,9 +1074,7 @@ def model_ae_proba_solutions(model, input):
             mask_generate = quiz_machine.make_quiz_mask(
                 quizzes=q, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
             )
-            targets, logits = targets_and_prediction(
-                probs_iterations, model, q, mask_generate
-            )
+            targets, logits = targets_and_prediction(model, q, mask_generate)
             loss_per_token = F.cross_entropy(
                 logits.transpose(1, 2), targets, reduction="none"
             )
@@ -1076,6 +1086,9 @@ def model_ae_proba_solutions(model, input):
     return (-loss).exp()
 
 
+nb_diffusion_iterations = 25
+
+
 def degrade_input(input, mask_generate, nb_iterations, noise_proba):
     noise = torch.randint(
         quiz_machine.problem.nb_colors, input.size(), device=input.device
@@ -1094,14 +1107,18 @@ def degrade_input(input, mask_generate, nb_iterations, noise_proba):
     return result
 
 
-def targets_and_prediction(probs_iterations, model, input, mask_generate):
+def targets_and_prediction(model, input, mask_generate):
     d = deterministic(mask_generate)
-    p = probs_iterations.expand(input.size(0), -1)
-    dist = torch.distributions.categorical.Categorical(probs=p)
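+    # Distribution over the number of diffusion iterations, decaying
+    # geometrically from 1 to 0.1 and normalized to sum to one; recomputed
+    # per call from args.nb_diffusion_iterations.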
+    probs_iterations = 0.1 ** torch.linspace(
+        0, 1, args.nb_diffusion_iterations, device=input.device
+    )
+    probs_iterations = probs_iterations[None, :] / probs_iterations.sum()
+    probs_iterations = probs_iterations.expand(input.size(0), -1)
+    dist = torch.distributions.categorical.Categorical(probs=probs_iterations)
     N0 = dist.sample()
     N1 = N0 + 1
     N0 = (1 - d) * N0
-    N1 = (1 - d) * N1 + d * nb_iterations
+    N1 = (1 - d) * N1 + d * args.nb_diffusion_iterations
 
     targets, input = degrade_input(
         input, mask_generate, (0 * N1, N1), noise_proba=noise_proba
@@ -1113,7 +1130,7 @@ def targets_and_prediction(probs_iterations, model, input, mask_generate):
     return targets, logits
 
 
-def run_ae_test(model, other_models, quiz_machine, n_epoch, local_device=main_device):
+def run_ae_test(model, quiz_machine, n_epoch, local_device=main_device):
     with torch.autograd.no_grad():
         model.eval().to(local_device)
 
@@ -1128,9 +1145,7 @@ def run_ae_test(model, other_models, quiz_machine, n_epoch, local_device=main_device):
             local_device,
             "test",
         ):
-            targets, logits = targets_and_prediction(
-                probs_iterations, model, input, mask_generate
-            )
+            targets, logits = targets_and_prediction(model, input, mask_generate)
             loss = NTC_masked_cross_entropy(logits, targets, mask_loss)
             acc_test_loss += loss.item() * input.size(0)
             nb_test_samples += input.size(0)
@@ -1173,28 +1188,27 @@ def run_ae_test(model, other_models, quiz_machine, n_epoch, local_device=main_device):
 
         model.test_accuracy = nb_correct / nb_total
 
-        for f, record in [("prediction", record_d), ("generation", record_nd)]:
-            filename = f"culture_{f}_{n_epoch:04d}_{model.id:02d}.png"
-            result, predicted_parts, correct_parts = (
-                torch.cat([x[i] for x in record])[:128] for i in [0, 1, 2]
-            )
+        # for f, record in [("prediction", record_d), ("generation", record_nd)]:
+        #     filename = f"culture_{f}_{n_epoch:04d}_{model.id:02d}.png"
 
-            l = [model_ae_proba_solutions(model, result) for model in other_models]
-            probas = torch.cat([x[:, None] for x in l], dim=1)
-            comments = []
+        #     result, predicted_parts, correct_parts = bag_to_tensors(record)
 
-            for l in probas:
-                comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l]))
+        #     l = [model_ae_proba_solutions(model, result) for model in other_models]
+        #     probas = torch.cat([x[:, None] for x in l], dim=1)
+        #     comments = []
 
-            quiz_machine.problem.save_quizzes_as_image(
-                args.result_dir,
-                filename,
-                quizzes=result,
-                predicted_parts=predicted_parts,
-                correct_parts=correct_parts,
-                comments=comments,
-            )
-            log_string(f"wrote {filename}")
+        #     for l in probas:
+        #         comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l]))
+
+        #     quiz_machine.problem.save_quizzes_as_image(
+        #         args.result_dir,
+        #         filename,
+        #         quizzes=result,
+        #         predicted_parts=predicted_parts,
+        #         correct_parts=correct_parts,
+        #         comments=comments,
+        #     )
+        #     log_string(f"wrote {filename}")
 
         # Prediction with functional perturbations
 
@@ -1237,12 +1251,19 @@ def one_ae_epoch(model, other_models, quiz_machine, n_epoch, local_device=main_device):
         local_device,
         "training",
     ):
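+        # Make sure the batch tensors live on this worker's device.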
+        input = input.to(local_device)
+        mask_generate = mask_generate.to(local_device)
+        mask_loss = mask_loss.to(local_device)
+
         if nb_train_samples % args.batch_size == 0:
             model.optimizer.zero_grad()
 
-        targets, logits = targets_and_prediction(
-            probs_iterations, model, input, mask_generate
+        targets, logits = targets_and_prediction(model, input, mask_generate)
+
+        print(
+            f"{input.device=} {logits.device=} {targets.device=} {logits.device=} {mask_loss.device=}"
         )
+
         loss = NTC_masked_cross_entropy(logits, targets, mask_loss)
         acc_train_loss += loss.item() * input.size(0)
         nb_train_samples += input.size(0)
@@ -1256,7 +1277,7 @@ def one_ae_epoch(model, other_models, quiz_machine, n_epoch, local_device=main_d
         f"train_loss {n_epoch} model {model.id} {acc_train_loss/nb_train_samples}"
     )
 
-    run_ae_test(model, other_models, quiz_machine, n_epoch, local_device=local_device)
+    run_ae_test(model, quiz_machine, n_epoch, local_device=local_device)
 
 
 ######################################################################
@@ -1288,6 +1309,83 @@ for i in range(args.nb_models):
 
 ######################################################################
 
+
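+# At least one model is confident (p >= 0.8) while another clearly fails (p <= 0.2).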
+def c_quiz_criterion_one_good_one_bad(probas):
+    return (probas.max(dim=1).values >= 0.8) & (probas.min(dim=1).values <= 0.2)
+
+
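+# The gap between the most and least confident model is at least 0.5.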
+def c_quiz_criterion_diff(probas):
+    return (probas.max(dim=1).values - probas.min(dim=1).values) >= 0.5
+
+
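+# At least two models are near-certain (p >= 0.99) while the weakest is at p <= 0.5.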
+def c_quiz_criterion_two_certains(probas):
+    return ((probas >= 0.99).long().sum(dim=1) >= 2) & (probas.min(dim=1).values <= 0.5)
+
+
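+# Generate c-quizzes with randomly picked models and sort them into one bag
+# per selection criterion, until every bag holds at least 128 samples.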
+def generate_ae_c_quizzes(models, local_device=main_device):
+    criteria = [
+        c_quiz_criterion_one_good_one_bad,
+        c_quiz_criterion_diff,
+        c_quiz_criterion_two_certains,
+    ]
+
+    for m in models:
+        m.eval().to(local_device)
+
+    quad_order = ("A", "f_A", "B", "f_B")
+
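+    # Empty quizzes with all four quadrants masked, so the quiz is generated in full.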
+    template = quiz_machine.problem.create_empty_quizzes(
+        nb=args.batch_size, quad_order=quad_order
+    ).to(local_device)
+
+    mask_generate = quiz_machine.make_quiz_mask(
+        quizzes=template, quad_order=quad_order, quad_mask=(1, 1, 1, 1)
+    )
+
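+    # One bag of accepted c-quizzes per criterion.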
+    records = [[] for _ in criteria]
+
+    with torch.autograd.no_grad():
+        while min([bag_len(bag) for bag in records]) < 128:
+            bl = [bag_len(bag) for bag in records]
+            log_string(f"bag_len {bl}")
+
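+            # Pick a model at random and let it generate a batch of candidates.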
+            model = models[torch.randint(len(models), (1,)).item()]
+            result = ae_generate(model, template, mask_generate, 0.0)
+
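+            # Probability that each model assigns to the solutions of the generated quizzes.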
+            probas = torch.cat(
+                [model_ae_proba_solutions(model, result)[:, None] for model in models],
+                dim=1,
+            )
+            for c, r in zip(criteria, records):
+                q = result[c(probas)]
+                if q.size(0) > 0:
+                    r.append(q)
+
+    # for f, record in [("prediction", record_d), ("generation", record_nd)]:
+    #     filename = f"culture_{f}_{n_epoch:04d}_{model.id:02d}.png"
+
+    #     result, predicted_parts, correct_parts = bag_to_tensors(record)
+
+    #     l = [model_ae_proba_solutions(model, result) for model in models]
+    #     probas = torch.cat([x[:, None] for x in l], dim=1)
+    #     comments = []
+
+    #     for l in probas:
+    #         comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l]))
+
+    #     quiz_machine.problem.save_quizzes_as_image(
+    #         args.result_dir,
+    #         filename,
+    #         quizzes=result,
+    #         predicted_parts=predicted_parts,
+    #         correct_parts=correct_parts,
+    #         comments=comments,
+    #     )
+    #     log_string(f"wrote {filename}")
+
+
+######################################################################
+
 current_epoch = 0
 
 if args.resume:
@@ -1374,6 +1472,8 @@ for n_epoch in range(current_epoch, args.nb_epochs):
     for t in threads:
         t.join()
 
+    generate_ae_c_quizzes(models, local_device=main_device)
+
     # --------------------------------------------------------------------
 
     for model in models: