Update.
author     François Fleuret <francois@fleuret.org>
           Sat, 7 Sep 2024 07:20:27 +0000 (09:20 +0200)
committer  François Fleuret <francois@fleuret.org>
           Sat, 7 Sep 2024 07:20:27 +0000 (09:20 +0200)
main.py

diff --git a/main.py b/main.py
index b926f8e..264b5c7 100755
--- a/main.py
+++ b/main.py
@@ -991,9 +991,7 @@ def ae_generate(model, x_0, mask_generate, nb_iterations_max=50):
 
         hat_x_t_minus_1 = one_iteration_prediction * hat_x_0 + (
             1 - one_iteration_prediction
-        ) * sample_x_t_minus_1_given_x_0_x_t(
-            hat_x_0, x_t, max(1, args.nb_diffusion_iterations - it)
-        )
+        ) * sample_x_t_minus_1_given_x_0_x_t(hat_x_0, x_t)
 
         if hat_x_t_minus_1.equal(x_t):
             # log_string(f"exit after {it+1} iterations")
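
This hunk drops the explicit time-step argument max(1, args.nb_diffusion_iterations - it) from the re-noising call, so the noise level no longer depends on the iteration. A minimal sketch of a time-independent step consistent with the new two-argument signature; the Bernoulli mixing and the proba_keep constant are assumptions for illustration, not the repository's implementation:

    import torch

    def sample_x_t_minus_1_given_x_0_x_t(hat_x_0, x_t, proba_keep=0.9):
        # Keep each predicted token of hat_x_0 with a fixed probability,
        # revert to its current value in x_t otherwise. With no time-step
        # argument, the noise level is the same at every iteration.
        keep = (torch.rand(hat_x_0.size(), device=hat_x_0.device) < proba_keep).long()
        return keep * hat_x_0 + (1 - keep) * x_t

Under this reading, the loop settles when re-noising stops flipping tokens, which is exactly the hat_x_t_minus_1.equal(x_t) early exit above.
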
@@ -1035,11 +1033,11 @@ def model_ae_proba_solutions(model, input, log_proba=False):
         return (-loss).exp()
 
 
-def model_ae_argmax_nb_disagreements(model, input):
+def model_ae_argmax_nb_mistakes(model, input):
     record = []
 
     for x_0 in input.split(args.batch_size):
-        nb_disagreements = 0
+        nb_mistakes = 0
         for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
             mask_generate = quiz_machine.make_quiz_mask(
                 quizzes=x_0, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
@@ -1050,11 +1048,11 @@ def model_ae_argmax_nb_disagreements(model, input):
 
             predicted = logits.argmax(dim=-1)
 
-            nb_disagreements = nb_disagreements + (
+            nb_mistakes = nb_mistakes + (
                 mask_generate * predicted != mask_generate * x_0
             ).long().sum(dim=1)
 
-        record.append(nb_disagreements)
+        record.append(nb_mistakes)
 
     return torch.cat(record, dim=0)
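
The renamed helper counts, for each quiz, the tokens the model's argmax prediction gets wrong inside the generated part, accumulated over the four quad masks. A self-contained toy check of the masked count (shapes and values made up for illustration):

    import torch

    x_0       = torch.tensor([[1, 2, 3, 4, 5, 6],
                              [1, 2, 3, 4, 5, 6]])
    predicted = torch.tensor([[9, 9, 9, 4, 5, 0],
                              [1, 2, 3, 4, 0, 0]])
    mask      = torch.tensor([[0, 0, 0, 1, 1, 1],
                              [0, 0, 0, 1, 1, 1]])

    nb_mistakes = (mask * predicted != mask * x_0).long().sum(dim=1)
    print(nb_mistakes)  # tensor([1, 2]); errors outside the mask are ignored
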
 
@@ -1275,40 +1273,37 @@ def save_badness_statistics(
 ######################################################################
 
 
-def c_quiz_criterion_one_good_one_bad(probas):
-    return (probas.max(dim=1).values >= 0.75) & (probas.min(dim=1).values <= 0.25)
-
-
-def c_quiz_criterion_one_good_no_very_bad(probas):
-    return (
-        (probas.max(dim=1).values >= 0.75)
-        & (probas.min(dim=1).values <= 0.75)
-        & (probas.min(dim=1).values >= 0.25)
-    )
-
-
-def c_quiz_criterion_diff(probas):
-    return (probas.max(dim=1).values - probas.min(dim=1).values) >= 0.5
+def quiz_validation(models, c_quizzes, local_device):
+    nb_have_to_be_correct = args.nb_models // 2
+    nb_have_to_be_wrong = args.nb_models // 5
 
+    nb_runs = 3
+    nb_mistakes_to_be_wrong = 5
 
-def c_quiz_criterion_diff2(probas):
-    v = probas.sort(dim=1).values
-    return (v[:, -2] - v[:, 0]) >= 0.5
+    record_wrong = []
+    nb_correct, nb_wrong = 0, 0
 
+    for i, model in enumerate(models):
+        assert i == model.id  # a bit of paranoia
+        model = copy.deepcopy(model).to(local_device).eval()
+        correct, wrong = True, False
+        for _ in range(nb_runs):
+            n = model_ae_argmax_nb_mistakes(model, c_quizzes).long()
+            correct = correct & (n == 0)
+            wrong = wrong | (n >= nb_mistakes_to_be_wrong)
+        record_wrong.append(wrong[:, None])
+        nb_correct += correct.long()
+        nb_wrong += wrong.long()
 
-def c_quiz_criterion_few_good_one_bad(probas):
-    v = probas.sort(dim=1).values
-    return (v[:, 0] <= 0.25) & (v[:, -3] >= 0.5)
+    # print("nb_correct", nb_correct)
 
+    # print("nb_wrong", nb_wrong)
 
-def c_quiz_criterion_two_good(probas):
-    return ((probas >= 0.5).long().sum(dim=1) >= 2) & (probas.min(dim=1).values <= 0.2)
+    to_keep = (nb_correct >= nb_have_to_be_correct) & (nb_wrong >= nb_have_to_be_wrong)
 
+    wrong = torch.cat(record_wrong, dim=1)
 
-def c_quiz_criterion_some(probas):
-    return ((probas >= 0.8).long().sum(dim=1) >= 1) & (
-        (probas <= 0.2).long().sum(dim=1) >= 1
-    )
+    return to_keep, wrong
 
 
 def generate_ae_c_quizzes(models, nb, local_device=main_device):
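
quiz_validation replaces the per-criterion proba filters with a vote: a model is "correct" on a quiz only if it makes zero mistakes in all nb_runs attempts, and "wrong" if it makes at least nb_mistakes_to_be_wrong mistakes in any attempt. A quiz is kept when at least args.nb_models // 2 models are correct and at least args.nb_models // 5 are wrong. A toy illustration with made-up verdicts for five models and two quizzes, so the thresholds are 2 and 1:

    import torch

    # rows are models, columns are quizzes (hypothetical verdicts)
    correct = torch.tensor([[1, 1], [1, 0], [0, 0], [0, 1], [0, 0]]).bool()
    wrong   = torch.tensor([[0, 0], [0, 0], [1, 0], [0, 0], [1, 0]]).bool()

    nb_correct = correct.long().sum(dim=0)  # tensor([2, 2])
    nb_wrong   = wrong.long().sum(dim=0)    # tensor([2, 0])

    to_keep = (nb_correct >= 2) & (nb_wrong >= 1)
    print(to_keep)  # tensor([ True, False]): quiz 1 is too easy, no model fails it

The two-sided test keeps quizzes that are solvable by a majority yet discriminative, in that some models reliably fail them.
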
@@ -1346,33 +1341,12 @@ def generate_ae_c_quizzes(models, nb, local_device=main_device):
             c_quizzes = c_quizzes[to_keep]
 
             if c_quizzes.size(0) > 0:
-                # p = [
-                # model_ae_proba_solutions(model, c_quizzes)[:, None]
-                # for model in models
-                # ]
-
-                # probas = torch.cat(p, dim=1)
-                # to_keep = c_quiz_criterion_two_good(probas)
-
-                nb_disagreements = []
-                for i, model in enumerate(models):
-                    assert i == model.id  # a bit of paranoia
-                    model = copy_for_inference(model)
-                    nb_disagreements.append(
-                        model_ae_argmax_nb_disagreements(model, c_quizzes).long()[
-                            :, None
-                        ]
-                    )
-                nb_disagreements = torch.cat(nb_disagreements, dim=1)
-
-                v = nb_disagreements.sort(dim=1).values
-                to_keep = (v[:, 2] == 0) & (v[:, -1] >= 4)
-
+                to_keep, record_wrong = quiz_validation(models, c_quizzes, local_device)
                 q = c_quizzes[to_keep]
 
                 if q.size(0) > 0:
                     record_c_quizzes.append(q)
-                    a = (nb_disagreements == 0)[to_keep]
+                    a = (record_wrong == False)[to_keep]
                     record_agreements.append(a)
                     nb_c_quizzes_per_model += a.long().sum(dim=0)
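
For each kept quiz, the agreement matrix records which models solved it; (record_wrong == False) is a boolean negation and could equivalently be written ~record_wrong. A tiny runnable check of the indexing, with record_wrong of shape (nb_quizzes, nb_models):

    import torch

    record_wrong = torch.tensor([[False, True], [False, False], [True, True]])
    to_keep = torch.tensor([True, False, True])

    a = (~record_wrong)[to_keep]  # same as (record_wrong == False)[to_keep]
    print(a.long().sum(dim=0))    # tensor([1, 0]): kept quizzes solved per model
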
 
@@ -1405,25 +1379,23 @@ def generate_ae_c_quizzes(models, nb, local_device=main_device):
 
         subset_c_quizzes = c_quizzes[:nb_to_save]
 
-        # #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+        #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
         # for model in models:
-        # model = copy_for_inference(model)
-        # prediction = model_ae_argmax_predictions(model, subset_c_quizzes)
-        # filename = f"prediction_c_quiz_{n_epoch:04d}_{model.id}.png"
+        # for r in range(3):
+        # filename = f"culture_c_quiz_{n_epoch:04d}_prediction_{model.id}_{r}.png"
+        # p = model_ae_argmax_predictions(copy_for_inference(model), subset_c_quizzes)
         # quiz_machine.problem.save_quizzes_as_image(
         # args.result_dir,
         # filename,
-        # quizzes=prediction,
+        # quizzes=p,
+        # delta=True,
         # nrow=8,
         # )
         # log_string(f"wrote {filename}")
-        # exit(0)
-        #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+        #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
 
         filename = f"culture_c_quiz_{n_epoch:04d}.png"
 
-        # c_quizzes, predicted_parts, correct_parts = bag_to_tensors(record)
-
         l = [
             model_ae_proba_solutions(copy_for_inference(model), subset_c_quizzes)
             for model in models
@@ -1438,8 +1410,6 @@ def generate_ae_c_quizzes(models, nb, local_device=main_device):
             args.result_dir,
             filename,
             quizzes=subset_c_quizzes,
-            # predicted_parts=predicted_parts,
-            # correct_parts=correct_parts,
             comments=comments,
             delta=True,
             nrow=8,
@@ -1482,9 +1452,6 @@ if args.resume:
         log_string(f"successfully loaded {filename}")
         current_epoch = state["current_epoch"]
         c_quizzes = state["c_quizzes"]
-        # total_time_generating_c_quizzes = state["total_time_generating_c_quizzes"]
-        # total_time_training_models = state["total_time_training_models"]
-        # common_c_quiz_bags = state["common_c_quiz_bags"]
     except FileNotFoundError:
         log_string(f"cannot find {filename}")
         pass
@@ -1510,9 +1477,6 @@ for n_epoch in range(current_epoch, args.nb_epochs):
     state = {
         "current_epoch": n_epoch,
         "c_quizzes": c_quizzes,
-        # "total_time_generating_c_quizzes": total_time_generating_c_quizzes,
-        # "total_time_training_models": total_time_training_models,
-        # "common_c_quiz_bags": common_c_quiz_bags,
     }
 
     filename = "state.pth"
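
The actual save of this state dict falls outside the hunk; a sketch under the assumption that it uses torch.save, with a write-then-rename added here so a crash mid-save cannot leave a truncated state.pth for the resume code above:

    import os
    import torch

    def save_state(state, result_dir, filename="state.pth"):
        tmp = os.path.join(result_dir, filename + ".tmp")
        torch.save(state, tmp)  # serialize to a temporary file first
        os.rename(tmp, os.path.join(result_dir, filename))  # atomic on POSIX
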
@@ -1526,28 +1490,27 @@ for n_epoch in range(current_epoch, args.nb_epochs):
 
     # --------------------------------------------------------------------
 
-    # run_ae_test(
-    # model,
-    # alien_quiz_machine,
-    # n_epoch,
-    # c_quizzes=None,
-    # local_device=main_device,
-    # prefix="alien",
-    # )
-
-    # exit(0)
-
-    # one_ae_epoch(models[0], quiz_machine, n_epoch, None, main_device)
-    # exit(0)
-
     log_string(f"{time_train=} {time_c_quizzes=}")
 
     if (
         min([float(m.test_accuracy) for m in models]) > args.accuracy_to_make_c_quizzes
         and time_train >= time_c_quizzes
     ):
-        if c_quizzes is not None:
-            save_badness_statistics(last_n_epoch_c_quizzes, models, c_quizzes, "after")
+        if c_quizzes is None:
+            for model in models:
+                filename = f"ae_{model.id:03d}_naive.pth"
+                torch.save(
+                    {
+                        "state_dict": model.state_dict(),
+                        "optimizer_state_dict": model.optimizer.state_dict(),
+                        "test_accuracy": model.test_accuracy,
+                    },
+                    os.path.join(args.result_dir, filename),
+                )
+
+            log_string(f"wrote {filename}")
+
+        # --------------------------------------------------------------------
 
         last_n_epoch_c_quizzes = n_epoch
         nb_gpus = len(gpus)
@@ -1579,8 +1542,6 @@ for n_epoch in range(current_epoch, args.nb_epochs):
         c_quizzes = torch.cat([q.to(main_device) for q, _ in records], dim=0)
         agreements = torch.cat([a.to(main_device) for _, a in records], dim=0)
 
-        print(f"DEBUG {c_quizzes.size()=} {agreements.size()=}")
-
         # --------------------------------------------------------------------
 
         log_string(f"generated_c_quizzes {c_quizzes.size()=}")
@@ -1589,8 +1550,6 @@ for n_epoch in range(current_epoch, args.nb_epochs):
         for model in models:
             model.test_accuracy = 0
 
-        save_badness_statistics(n_epoch, models, c_quizzes, "before")
-
     if c_quizzes is None:
         log_string("no_c_quiz")
     else:
@@ -1603,9 +1562,6 @@ for n_epoch in range(current_epoch, args.nb_epochs):
 
     threads = []
 
-    # for model in models:
-    # log_string(f"DEBUG {model.id} {sum([ p.sum() for p in model.parameters()]).item()}")
-
     start_time = time.perf_counter()
 
     for gpu, model in zip(gpus, weakest_models):
@@ -1639,10 +1595,6 @@ for n_epoch in range(current_epoch, args.nb_epochs):
                 "state_dict": model.state_dict(),
                 "optimizer_state_dict": model.optimizer.state_dict(),
                 "test_accuracy": model.test_accuracy,
-                # "gen_test_accuracy": model.gen_test_accuracy,
-                # "gen_state_dict": model.gen_state_dict,
-                # "train_c_quiz_bags": model.train_c_quiz_bags,
-                # "test_c_quiz_bags": model.test_c_quiz_bags,
             },
             os.path.join(args.result_dir, filename),
         )
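
A matching restore for these per-model checkpoints, mirroring the keys written above; the helper name is hypothetical:

    import os
    import torch

    def load_model_checkpoint(model, result_dir, filename):
        state = torch.load(os.path.join(result_dir, filename), map_location="cpu")
        model.load_state_dict(state["state_dict"])
        model.optimizer.load_state_dict(state["optimizer_state_dict"])
        model.test_accuracy = state["test_accuracy"]
        return model
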