Update.
authorFrançois Fleuret <francois@fleuret.org>
Thu, 5 Sep 2024 20:44:02 +0000 (22:44 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Thu, 5 Sep 2024 20:44:02 +0000 (22:44 +0200)
main.py

diff --git a/main.py b/main.py
index 8e938db..f609fd8 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -881,6 +881,32 @@ def model_ae_argmax_nb_disagreements(model, input):
     return torch.cat(record, dim=0)
 
 
+######################################################################
+
+
+def model_ae_argmax_predictions(model, input):
+    result = input.clone()
+    # result[...] = 0
+
+    for r, q in zip(result.split(args.batch_size), input.split(args.batch_size)):
+        for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
+            mask_generate = quiz_machine.make_quiz_mask(
+                quizzes=q, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
+            )
+            targets, logits = targets_and_prediction(
+                model, q, mask_generate, prompt_noise=args.prompt_noise
+            )
+
+            predicted = logits.argmax(dim=-1)
+
+            r[...] = (1 - mask_generate) * r + mask_generate * predicted
+
+    return result
+
+
+######################################################################
+
+
 def degrade_input_to_generate(input, mask_generate, steps_nb_iterations):
     noise = torch.randint(
         quiz_machine.problem.nb_colors, input.size(), device=input.device
@@ -942,7 +968,9 @@ def targets_and_prediction(model, input, mask_generate, prompt_noise=0.0):
 def run_ae_test(
     model, quiz_machine, n_epoch, c_quizzes=None, local_device=main_device, prefix=None
 ):
-    if prefix is not None:
+    if prefix is None:
+        prefix = ""
+    else:
         prefix = prefix + "_"
 
     with torch.autograd.no_grad():
@@ -1216,14 +1244,15 @@ def generate_ae_c_quizzes(models, nb, local_device=main_device):
 
     wanted_nb = nb
     nb_to_save = 256
+    nb_c_quizzes_per_model = torch.zeros(len(models), device=local_device)
 
     with torch.autograd.no_grad():
-        records = []
+        record_c_quizzes, record_agreements = [], []
 
         last_log = -1
         start_time = time.perf_counter()
 
-        while bag_len(records) < wanted_nb:
+        while nb_c_quizzes_per_model.min() < wanted_nb:
             model = copy_for_inference(models[torch.randint(len(models), (1,)).item()])
             generator_id = model.id
 
@@ -1242,7 +1271,8 @@ def generate_ae_c_quizzes(models, nb, local_device=main_device):
                 # to_keep = c_quiz_criterion_two_good(probas)
 
                 nb_disagreements = []
-                for model in models:
+                for i, model in enumerate(models):
+                    assert i == model.id  # a bit of paranoia
                     model = copy_for_inference(model)
                     nb_disagreements.append(
                         model_ae_argmax_nb_disagreements(model, c_quizzes).long()[
@@ -1252,15 +1282,18 @@ def generate_ae_c_quizzes(models, nb, local_device=main_device):
                 nb_disagreements = torch.cat(nb_disagreements, dim=1)
 
                 v = nb_disagreements.sort(dim=1).values
-                to_keep = (v[:, 1] == 0) & (v[:, -1] > 3)
+                to_keep = (v[:, 2] == 0) & (v[:, -1] >= 4)
 
                 q = c_quizzes[to_keep]
 
                 if q.size(0) > 0:
-                    records.append(q)
+                    record_c_quizzes.append(q)
+                    a = (nb_disagreements == 0)[to_keep]
+                    record_agreements.append(a)
+                    nb_c_quizzes_per_model += a.long().sum(dim=0)
 
             duration = time.perf_counter() - start_time
-            nb_generated = bag_len(records)
+            nb_generated = nb_c_quizzes_per_model.min().item()
 
             if last_log < 0 or duration > last_log + 5:
                 last_log = duration
@@ -1276,17 +1309,33 @@ def generate_ae_c_quizzes(models, nb, local_device=main_device):
                     e = "???"
 
                 log_string(
-                    f"nb_generated {bag_len(records)} model {generator_id} (finishes {e} -- {int((nb_generated * 3600)/duration)}/h)"
+                    f"nb_generated {bag_len(record_c_quizzes)} model {generator_id} (finishes {e} -- {int((nb_generated * 3600)/duration)}/h)"
                 )
 
         duration = time.perf_counter() - start_time
 
         log_string(f"generate_c_quizz_speed {int(3600 * wanted_nb / duration)}/h")
 
-        c_quizzes = torch.cat(records, dim=0).unique(dim=0)
+        c_quizzes = torch.cat(record_c_quizzes, dim=0)
+        agreements = torch.cat(record_agreements, dim=0)
 
         subset_c_quizzes = c_quizzes[:nb_to_save]
 
+        # #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+        # for model in models:
+        # model = copy_for_inference(model)
+        # prediction = model_ae_argmax_predictions(model, subset_c_quizzes)
+        # filename = f"prediction_c_quiz_{n_epoch:04d}_{model.id}.png"
+        # quiz_machine.problem.save_quizzes_as_image(
+        # args.result_dir,
+        # filename,
+        # quizzes=prediction,
+        # nrow=8,
+        # )
+        # log_string(f"wrote {filename}")
+        # exit(0)
+        #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+
         filename = f"culture_c_quiz_{n_epoch:04d}.png"
 
         # c_quizzes, predicted_parts, correct_parts = bag_to_tensors(record)
@@ -1314,7 +1363,7 @@ def generate_ae_c_quizzes(models, nb, local_device=main_device):
 
         log_string(f"wrote {filename}")
 
-    return c_quizzes
+    return c_quizzes, agreements
 
 
 def thread_generate_ae_c_quizzes(models, nb, record, local_device=main_device):
@@ -1443,7 +1492,10 @@ for n_epoch in range(current_epoch, args.nb_epochs):
 
         time_c_quizzes = int(time.perf_counter() - start_time)
 
-        c_quizzes = torch.cat([q.to(main_device) for q in records], dim=0)
+        c_quizzes = torch.cat([q.to(main_device) for q, _ in records], dim=0)
+        agreements = torch.cat([a.to(main_device) for _, a in records], dim=0)
+
+        print(f"DEBUG {c_quizzes.size()=} {agreements.size()=}")
 
         # --------------------------------------------------------------------
 
@@ -1474,11 +1526,15 @@ for n_epoch in range(current_epoch, args.nb_epochs):
 
     for gpu, model in zip(gpus, weakest_models):
         log_string(f"training model {model.id} (accuracy {model.test_accuracy})")
+        if c_quizzes is None:
+            c_quizzes_for_this_model = None
+        else:
+            c_quizzes_for_this_model = c_quizzes[agreements[:, model.id]]
 
         t = threading.Thread(
             target=one_ae_epoch,
             daemon=True,
-            args=(model, quiz_machine, n_epoch, c_quizzes, gpu),
+            args=(model, quiz_machine, n_epoch, c_quizzes_for_this_model, gpu),
         )
 
         threads.append(t)