From 8f17719e388a3800b8fd894cb407394320415a32 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 24 Aug 2024 16:16:14 +0200 Subject: [PATCH] Update. --- main.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/main.py b/main.py index d3d237e..e4bb494 100755 --- a/main.py +++ b/main.py @@ -970,6 +970,7 @@ def test_ae(local_device=main_device): targets, input = degrade_input( input, mask_generate, rho / nb_iterations, (rho + 1) / nb_iterations ) + input_with_mask = NTC_channel_cat(input, mask_generate, rho) output = model(input_with_mask) loss = NTC_masked_cross_entropy(output, targets, mask_loss) @@ -1039,7 +1040,7 @@ def test_ae(local_device=main_device): f"test_accuracy {n_epoch} model AE setup {ns} {nb_correct} / {nb_total} ({(nb_correct*100)/nb_total:.02f}%)" ) - filename = f"prediction_ae_{n_epoch:04d}_structure_{ns}.png" + filename = f"prediction_ae_{n_epoch:04d}_{ns}.png" quiz_machine.problem.save_quizzes_as_image( args.result_dir, @@ -1049,7 +1050,7 @@ def test_ae(local_device=main_device): correct_parts=correct_parts, ) - log_string(f"wrote {filename}") + log_string(f"wrote {filename}") if args.test == "ae": -- 2.39.5