From: François Fleuret Date: Wed, 4 Sep 2024 07:27:24 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=ab1c19fdacabd6d0ab3e06bd0be37f8a25ac0b94;p=culture.git Update. --- diff --git a/main.py b/main.py index 61fc090..02c9fc6 100755 --- a/main.py +++ b/main.py @@ -957,27 +957,30 @@ def run_ae_test(model, quiz_machine, n_epoch, c_quizzes=None, local_device=main_ model.test_accuracy = nb_correct / nb_total - # for f, record in [("prediction", record_d), ("generation", record_nd)]: - # filename = f"culture_{f}_{n_epoch:04d}_{model.id:02d}.png" + # Save some images - # result, predicted_parts, correct_parts = bag_to_tensors(record) + for f, record in [("prediction", record_d), ("generation", record_nd)]: + filename = f"culture_{f}_{n_epoch:04d}_{model.id:02d}.png" - # l = [model_ae_proba_solutions(model, result) for model in other_models] - # probas = torch.cat([x[:, None] for x in l], dim=1) - # comments = [] + result, predicted_parts, correct_parts = bag_to_tensors(record) - # for l in probas: - # comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l])) + # l = [model_ae_proba_solutions(model, result) for model in other_models] + # probas = torch.cat([x[:, None] for x in l], dim=1) + # comments = [] - # quiz_machine.problem.save_quizzes_as_image( - # args.result_dir, - # filename, - # quizzes=result, - # predicted_parts=predicted_parts, - # correct_parts=correct_parts, - # comments=comments, - # ) - # log_string(f"wrote {filename}") + # for l in probas: + # comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l])) + + quiz_machine.problem.save_quizzes_as_image( + args.result_dir, + filename, + quizzes=result[:128], + predicted_parts=predicted_parts[:128], + correct_parts=correct_parts[:128], + # comments=comments, + ) + + log_string(f"wrote {filename}") # Prediction with functional perturbations @@ -1046,7 +1049,7 @@ def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_devi f"train_loss {n_epoch} model {model.id} {acc_train_loss/nb_train_samples}" ) - run_ae_test(model, quiz_machine, n_epoch, local_device=local_device) + run_ae_test(model, quiz_machine, n_epoch, c_quizzes=None, local_device=local_device) ######################################################################