Update.
author François Fleuret <francois@fleuret.org>
Thu, 5 Sep 2024 09:50:34 +0000 (11:50 +0200)
committer François Fleuret <francois@fleuret.org>
Thu, 5 Sep 2024 09:50:34 +0000 (11:50 +0200)
main.py

diff --git a/main.py b/main.py
index 934940e..e95c4f6 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -839,7 +839,28 @@ def model_ae_proba_solutions(model, input, log_proba=False):
         return (-loss).exp()
 
 
-nb_diffusion_iterations = 25
+def model_ae_argmax_nb_disagreements(model, input):
+    record = []
+
+    for q in input.split(args.batch_size):
+        nb_disagreements = 0
+        for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
+            mask_generate = quiz_machine.make_quiz_mask(
+                quizzes=q, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
+            )
+            targets, logits = targets_and_prediction(
+                model, q, mask_generate, prompt_noise=args.prompt_noise
+            )
+
+            predicted = logits.argmax(dim=-1)
+
+            nb_disagreements = nb_disagreements + (
+                mask_generate * predicted != mask_generate * targets
+            ).long().sum(dim=1)
+
+        record.append(nb_disagreements)
+
+    return torch.cat(record, dim=0)
 
 
 def degrade_input_to_generate(input, mask_generate, steps_nb_iterations):
@@ -1152,20 +1173,7 @@ def c_quiz_criterion_some(probas):
 
 
 def generate_ae_c_quizzes(models, nb, local_device=main_device):
-    criteria = [
-        c_quiz_criterion_few_good_one_bad,
-        # c_quiz_criterion_only_one,
-        # c_quiz_criterion_one_good_one_bad,
-        # c_quiz_criterion_one_good_no_very_bad,
-        # c_quiz_criterion_diff,
-        # c_quiz_criterion_diff2,
-        # c_quiz_criterion_two_good,
-        # c_quiz_criterion_some,
-    ]
-
     # To be thread-safe we must make copies
-    models = [copy.deepcopy(model).to(local_device) for model in models]
-
     quad_order = ("A", "f_A", "B", "f_B")
 
     template = quiz_machine.problem.create_empty_quizzes(
@@ -1176,45 +1184,57 @@ def generate_ae_c_quizzes(models, nb, local_device=main_device):
         quizzes=template, quad_order=quad_order, quad_mask=(1, 1, 1, 1)
     )
 
-    duration_max = 4 * 3600
+    def copy_for_inference(model):
+        return copy.deepcopy(model).to(local_device).eval()
 
     wanted_nb = nb
     nb_to_save = 256
 
     with torch.autograd.no_grad():
-        records = [[] for _ in criteria]
+        records = []
 
         last_log = -1
         start_time = time.perf_counter()
 
-        while (
-            time.perf_counter() < start_time + duration_max
-            and min([bag_len(bag) for bag in records]) < wanted_nb
-        ):
-            model = models[torch.randint(len(models), (1,)).item()]
+        while bag_len(records) < wanted_nb:
+            model = copy_for_inference(models[torch.randint(len(models), (1,)).item()])
+
             c_quizzes = ae_generate(model, template, mask_generate)
 
             to_keep = quiz_machine.problem.trivial(c_quizzes) == False
             c_quizzes = c_quizzes[to_keep]
 
             if c_quizzes.size(0) > 0:
-                probas = torch.cat(
-                    [
-                        model_ae_proba_solutions(model, c_quizzes)[:, None]
-                        for model in models
-                    ],
-                    dim=1,
-                )
+                # p = [
+                # model_ae_proba_solutions(model, c_quizzes)[:, None]
+                # for model in models
+                # ]
+
+                # probas = torch.cat(p, dim=1)
+                # to_keep = c_quiz_criterion_two_good(probas)
 
-                for c, r in zip(criteria, records):
-                    q = c_quizzes[c(probas)]
-                    if q.size(0) > 0:
-                        r.append(q)
+                nb_disagreements = []
+                for model in models:
+                    model = copy_for_inference(model)
+                    nb_disagreements.append(
+                        model_ae_argmax_nb_disagreements(model, c_quizzes).long()[
+                            :, None
+                        ]
+                    )
+                nb_disagreements = torch.cat(nb_disagreements, dim=1)
+
+                v = nb_disagreements.sort(dim=1).values
+                to_keep = (v[:, 1] == 0) & (v[:, -1] > 3)
+
+                q = c_quizzes[to_keep]
+
+                if q.size(0) > 0:
+                    records.append(q)
 
             duration = time.perf_counter() - start_time
-            nb_generated = min([bag_len(bag) for bag in records])
+            nb_generated = bag_len(records)
 
-            if last_log < 0 or duration > last_log + 60:
+            if last_log < 0 or duration > last_log + 5:
                 last_log = duration
                 if nb_generated > 0:
                     if nb_generated < wanted_nb:
@@ -1227,44 +1247,46 @@ def generate_ae_c_quizzes(models, nb, local_device=main_device):
                 else:
                     e = "???"
 
-                bl = [bag_len(bag) for bag in records]
                 log_string(
-                    f"bag_len {bl} model {model.id} (finishes {e} -- {int((nb_generated * 3600)/duration)}/h)"
+                    f"nb_generated {bag_len(records)} model {model.id} (finishes {e} -- {int((nb_generated * 3600)/duration)}/h)"
                 )
 
         duration = time.perf_counter() - start_time
 
         log_string(f"generate_c_quizz_speed {int(3600 * wanted_nb / duration)}/h")
 
-        for n, u in enumerate(records):
-            quizzes = torch.cat(u, dim=0)[:nb_to_save]
-            filename = f"culture_c_quiz_{n_epoch:04d}_{n:02d}.png"
+        c_quizzes = torch.cat(records, dim=0).unique(dim=0)
 
-            # c_quizzes, predicted_parts, correct_parts = bag_to_tensors(record)
+        subset_c_quizzes = c_quizzes[:nb_to_save]
 
-            l = [model_ae_proba_solutions(model, quizzes) for model in models]
-            probas = torch.cat([x[:, None] for x in l], dim=1)
-            comments = []
+        filename = f"culture_c_quiz_{n_epoch:04d}.png"
 
-            for l in probas:
-                comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l]))
+        # c_quizzes, predicted_parts, correct_parts = bag_to_tensors(record)
 
-            quiz_machine.problem.save_quizzes_as_image(
-                args.result_dir,
-                filename,
-                quizzes=quizzes,
-                # predicted_parts=predicted_parts,
-                # correct_parts=correct_parts,
-                comments=comments,
-                delta=True,
-                nrow=8,
-            )
+        l = [
+            model_ae_proba_solutions(copy_for_inference(model), subset_c_quizzes)
+            for model in models
+        ]
+        probas = torch.cat([x[:, None] for x in l], dim=1)
+        comments = []
 
-            log_string(f"wrote {filename}")
+        for l in probas:
+            comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l]))
 
-    a = [torch.cat(u, dim=0) for u in records]
+        quiz_machine.problem.save_quizzes_as_image(
+            args.result_dir,
+            filename,
+            quizzes=subset_c_quizzes,
+            # predicted_parts=predicted_parts,
+            # correct_parts=correct_parts,
+            comments=comments,
+            delta=True,
+            nrow=8,
+        )
+
+        log_string(f"wrote {filename}")
 
-    return torch.cat(a, dim=0).unique(dim=0)
+    return c_quizzes
 
 
 def thread_generate_ae_c_quizzes(models, nb, record, local_device=main_device):