Update.
authorFrançois Fleuret <francois@fleuret.org>
Mon, 2 Sep 2024 16:40:41 +0000 (18:40 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Mon, 2 Sep 2024 16:40:41 +0000 (18:40 +0200)
main.py
quiz_machine.py

diff --git a/main.py b/main.py
index b6aa328..4860073 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -1254,6 +1254,7 @@ def one_ae_epoch(
     model, other_models, quiz_machine, n_epoch, c_quizzes, local_device=main_device
 ):
     model.train().to(local_device)
+    optimizer_to(model.optimizer, local_device)
 
     nb_train_samples, acc_train_loss = 0, 0.0
 
@@ -1358,13 +1359,13 @@ def save_badness_statistics(
     n_epoch, models, c_quizzes, suffix=None, local_device=main_device
 ):
     for model in models:
-        models.eval().to(local_device)
+        model.eval().to(local_device)
     c_quizzes = c_quizzes.to(local_device)
     with torch.autograd.no_grad():
         log_probas = sum(
             [model_ae_proba_solutions(model, c_quizzes) for model in models]
         )
-        i = log_probas.sort().values
+        i = log_probas.sort().indices
 
     suffix = "" if suffix is None else "_" + suffix
 
@@ -1373,14 +1374,16 @@ def save_badness_statistics(
     quiz_machine.problem.save_quizzes_as_image(
         args.result_dir,
         filename,
-        quizzes=quizzes[i[:128]],
+        quizzes=c_quizzes[i[:128]],
         # predicted_parts=predicted_parts,
         # correct_parts=correct_parts,
-        comments=comments,
+        comments=comments,
         delta=True,
         nrow=8,
     )
 
+    log_string(f"wrote {filename}")
+
 
 def generate_ae_c_quizzes(models, local_device=main_device):
     criteria = [
@@ -1575,11 +1578,11 @@ for n_epoch in range(current_epoch, args.nb_epochs):
     log_string(f"{time_train=} {time_c_quizzes=}")
 
     if (
-        min([m.test_accuracy for m in models]) > args.accuracy_to_make_c_quizzes
+        min([float(m.test_accuracy) for m in models]) > args.accuracy_to_make_c_quizzes
         and time_train >= time_c_quizzes
     ):
         if c_quizzes is not None:
-            save_badness_statistics(models, c_quizzes)
+            save_badness_statistics(last_n_epoch_c_quizzes, models, c_quizzes, "after")
 
         last_n_epoch_c_quizzes = n_epoch
         start_time = time.perf_counter()
@@ -1589,6 +1592,8 @@ for n_epoch in range(current_epoch, args.nb_epochs):
         for model in models:
             model.test_accuracy = 0
 
+        save_badness_statistics(n_epoch, models, c_quizzes, "before")
+
     if c_quizzes is None:
         log_string("no_c_quiz")
     else:
index af24c92..ce4d4f5 100755 (executable)
@@ -269,7 +269,7 @@ class QuizMachine:
             f"test_accuracy {n_epoch} model {model.id} val {nb_correct} / {nb_total}"
         )
 
-        test_accuracy = nb_correct / nb_total
+        test_accuracy = (nb_correct / nb_total).item()
 
         ##############################