From c697b67def0600e2862313e93f69394fac9bf6ad Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 18 Sep 2024 09:21:53 +0200 Subject: [PATCH] Update. --- main.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/main.py b/main.py index 05bb108..22854a9 100755 --- a/main.py +++ b/main.py @@ -589,7 +589,8 @@ def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device): one_epoch(model, n_epoch, c_quizzes, local_device=local_device, train=True) one_epoch(model, n_epoch, c_quizzes, local_device=local_device, train=False) - #!!!!!!!!!!!!!!!!!!!!!!!!! + # Save some original world quizzes and the full prediction (the four grids) + quizzes = quiz_machine.quiz_set(25, c_quizzes, args.c_quiz_multiplier).to( local_device ) @@ -600,11 +601,12 @@ def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device): quiz_machine.problem.save_quizzes_as_image( args.result_dir, f"test_{n_epoch}_{model.id}_predict_full.png", quizzes=result ) - #!!!!!!!!!!!!!!!!!!!!!!!!! - # predict + # Save some images of the prediction results (one grid at random) - quizzes = quiz_machine.quiz_set(150, c_quizzes, args.c_quiz_multiplier) + quizzes = quiz_machine.quiz_set( + args.nb_test_samples, c_quizzes, args.c_quiz_multiplier + ) imt_set = batch_prediction_imt(quizzes.to(local_device)) result = predict(model, imt_set, local_device=local_device).to("cpu") masks = imt_set[:, 1].to("cpu") @@ -623,6 +625,8 @@ def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device): correct_parts=correct_parts[:128], ) + # Compute the test accuracy + nb_correct, nb_total = correct.sum(), quizzes.size(0) model.test_accuracy = nb_correct / nb_total @@ -630,7 +634,7 @@ def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device): f"test_accuracy {n_epoch} model {model.id} nb_correct {nb_correct} / {nb_total} ({model.test_accuracy:.02f}%)" ) - # generate + # Save some images of the ex nihilo generation of the four grids result = generate(model, 150, local_device=local_device).to("cpu") quiz_machine.problem.save_quizzes_as_image( -- 2.39.5