From 1b4d0d66aaeb0b1b50a0c7bf2a84a781d0d0de8c Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 2 Sep 2024 18:40:41 +0200 Subject: [PATCH] Update. --- main.py | 17 +++++++++++------ quiz_machine.py | 2 +- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/main.py b/main.py index b6aa328..4860073 100755 --- a/main.py +++ b/main.py @@ -1254,6 +1254,7 @@ def one_ae_epoch( model, other_models, quiz_machine, n_epoch, c_quizzes, local_device=main_device ): model.train().to(local_device) + optimizer_to(model.optimizer, local_device) nb_train_samples, acc_train_loss = 0, 0.0 @@ -1358,13 +1359,13 @@ def save_badness_statistics( n_epoch, models, c_quizzes, suffix=None, local_device=main_device ): for model in models: - models.eval().to(local_device) + model.eval().to(local_device) c_quizzes = c_quizzes.to(local_device) with torch.autograd.no_grad(): log_probas = sum( [model_ae_proba_solutions(model, c_quizzes) for model in models] ) - i = log_probas.sort().values + i = log_probas.sort().indices suffix = "" if suffix is None else "_" + suffix @@ -1373,14 +1374,16 @@ def save_badness_statistics( quiz_machine.problem.save_quizzes_as_image( args.result_dir, filename, - quizzes=quizzes[i[:128]], + quizzes=c_quizzes[i[:128]], # predicted_parts=predicted_parts, # correct_parts=correct_parts, - comments=comments, + # comments=comments, delta=True, nrow=8, ) + log_string(f"wrote {filename}") + def generate_ae_c_quizzes(models, local_device=main_device): criteria = [ @@ -1575,11 +1578,11 @@ for n_epoch in range(current_epoch, args.nb_epochs): log_string(f"{time_train=} {time_c_quizzes=}") if ( - min([m.test_accuracy for m in models]) > args.accuracy_to_make_c_quizzes + min([float(m.test_accuracy) for m in models]) > args.accuracy_to_make_c_quizzes and time_train >= time_c_quizzes ): if c_quizzes is not None: - save_badness_statistics(models, c_quizzes) + save_badness_statistics(last_n_epoch_c_quizzes, models, c_quizzes, "after") last_n_epoch_c_quizzes = n_epoch start_time = time.perf_counter() @@ -1589,6 +1592,8 @@ for n_epoch in range(current_epoch, args.nb_epochs): for model in models: model.test_accuracy = 0 + save_badness_statistics(n_epoch, models, c_quizzes, "before") + if c_quizzes is None: log_string("no_c_quiz") else: diff --git a/quiz_machine.py b/quiz_machine.py index af24c92..ce4d4f5 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -269,7 +269,7 @@ class QuizMachine: f"test_accuracy {n_epoch} model {model.id} val {nb_correct} / {nb_total}" ) - test_accuracy = nb_correct / nb_total + test_accuracy = (nb_correct / nb_total).item() ############################## -- 2.39.5