Update.
authorFrançois Fleuret <francois@fleuret.org>
Mon, 2 Sep 2024 15:26:39 +0000 (17:26 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Mon, 2 Sep 2024 15:26:39 +0000 (17:26 +0200)
main.py

diff --git a/main.py b/main.py
index b48d2a8..b6aa328 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -1070,7 +1070,7 @@ def ae_generate(model, input, mask_generate, nb_iterations_max=50):
 ######################################################################
 
 
-def model_ae_proba_solutions(model, input):
+def model_ae_proba_solutions(model, input, log_proba=False):
     record = []
 
     for q in input.split(args.batch_size):
@@ -1089,7 +1089,10 @@ def model_ae_proba_solutions(model, input):
 
     loss = torch.cat(record, dim=0)
 
-    return (-loss).exp()
+    if log_proba:
+        return -loss
+    else:
+        return (-loss).exp()
 
 
 nb_diffusion_iterations = 25
@@ -1351,6 +1354,34 @@ def c_quiz_criterion_some(probas):
     )
 
 
+def save_badness_statistics(
+    n_epoch, models, c_quizzes, suffix=None, local_device=main_device
+):
+    for model in models:
+        models.eval().to(local_device)
+    c_quizzes = c_quizzes.to(local_device)
+    with torch.autograd.no_grad():
+        log_probas = sum(
+            [model_ae_proba_solutions(model, c_quizzes) for model in models]
+        )
+        i = log_probas.sort().values
+
+    suffix = "" if suffix is None else "_" + suffix
+
+    filename = f"culture_badness_{n_epoch:04d}{suffix}.png"
+
+    quiz_machine.problem.save_quizzes_as_image(
+        args.result_dir,
+        filename,
+        quizzes=quizzes[i[:128]],
+        # predicted_parts=predicted_parts,
+        # correct_parts=correct_parts,
+        comments=comments,
+        delta=True,
+        nrow=8,
+    )
+
+
 def generate_ae_c_quizzes(models, local_device=main_device):
     criteria = [
         # c_quiz_criterion_only_one,
@@ -1493,6 +1524,7 @@ if args.resume:
         state = torch.load(os.path.join(args.result_dir, filename))
         log_string(f"successfully loaded {filename}")
         current_epoch = state["current_epoch"]
+        c_quizzes = state["c_quizzes"]
         # total_time_generating_c_quizzes = state["total_time_generating_c_quizzes"]
         # total_time_training_models = state["total_time_training_models"]
         # common_c_quiz_bags = state["common_c_quiz_bags"]
@@ -1520,10 +1552,12 @@ for n_epoch in range(current_epoch, args.nb_epochs):
 
     state = {
         "current_epoch": n_epoch,
+        "c_quizzes": c_quizzes,
         # "total_time_generating_c_quizzes": total_time_generating_c_quizzes,
         # "total_time_training_models": total_time_training_models,
         # "common_c_quiz_bags": common_c_quiz_bags,
     }
+
     filename = "state.pth"
     torch.save(state, os.path.join(args.result_dir, filename))
     log_string(f"wrote {filename}")
@@ -1541,10 +1575,12 @@ for n_epoch in range(current_epoch, args.nb_epochs):
     log_string(f"{time_train=} {time_c_quizzes=}")
 
     if (
-        n_epoch >= 200
-        and min([m.test_accuracy for m in models]) > args.accuracy_to_make_c_quizzes
+        min([m.test_accuracy for m in models]) > args.accuracy_to_make_c_quizzes
         and time_train >= time_c_quizzes
     ):
+        if c_quizzes is not None:
+            save_badness_statistics(models, c_quizzes)
+
         last_n_epoch_c_quizzes = n_epoch
         start_time = time.perf_counter()
         c_quizzes = generate_ae_c_quizzes(models, local_device=main_device)