From ab1c19fdacabd6d0ab3e06bd0be37f8a25ac0b94 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 4 Sep 2024 09:27:24 +0200 Subject: [PATCH] Update. --- main.py | 39 +++++++++++++++++++++------------------ 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/main.py b/main.py index 61fc090..02c9fc6 100755 --- a/main.py +++ b/main.py @@ -957,27 +957,30 @@ def run_ae_test(model, quiz_machine, n_epoch, c_quizzes=None, local_device=main_ model.test_accuracy = nb_correct / nb_total - # for f, record in [("prediction", record_d), ("generation", record_nd)]: - # filename = f"culture_{f}_{n_epoch:04d}_{model.id:02d}.png" + # Save some images - # result, predicted_parts, correct_parts = bag_to_tensors(record) + for f, record in [("prediction", record_d), ("generation", record_nd)]: + filename = f"culture_{f}_{n_epoch:04d}_{model.id:02d}.png" - # l = [model_ae_proba_solutions(model, result) for model in other_models] - # probas = torch.cat([x[:, None] for x in l], dim=1) - # comments = [] + result, predicted_parts, correct_parts = bag_to_tensors(record) - # for l in probas: - # comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l])) + # l = [model_ae_proba_solutions(model, result) for model in other_models] + # probas = torch.cat([x[:, None] for x in l], dim=1) + # comments = [] - # quiz_machine.problem.save_quizzes_as_image( - # args.result_dir, - # filename, - # quizzes=result, - # predicted_parts=predicted_parts, - # correct_parts=correct_parts, - # comments=comments, - # ) - # log_string(f"wrote {filename}") + # for l in probas: + # comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l])) + + quiz_machine.problem.save_quizzes_as_image( + args.result_dir, + filename, + quizzes=result[:128], + predicted_parts=predicted_parts[:128], + correct_parts=correct_parts[:128], + # comments=comments, + ) + + log_string(f"wrote {filename}") # Prediction with functional perturbations @@ -1046,7 +1049,7 @@ def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_devi f"train_loss {n_epoch} model {model.id} {acc_train_loss/nb_train_samples}" ) - run_ae_test(model, quiz_machine, n_epoch, local_device=local_device) + run_ae_test(model, quiz_machine, n_epoch, c_quizzes=None, local_device=local_device) ###################################################################### -- 2.39.5