From a946475528da2b86a1b3d0cd3913ac73d3183ce4 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 16 Sep 2024 22:49:37 +0200 Subject: [PATCH] Update. --- main.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/main.py b/main.py index 649c889..70ca672 100755 --- a/main.py +++ b/main.py @@ -658,11 +658,9 @@ def batch_prediction(input, proba_hints=0.0): return input, targets, mask_generate -def predict(model, quizzes, local_device=main_device): +def predict(model, input, targets, mask, local_device=main_device): model.eval().to(local_device) - input, targets, mask = batch_prediction(quizzes.to(local_device)) - input_batches = input.reshape(-1, args.physical_batch_size, input.size(1)) targets_batches = targets.reshape(-1, args.physical_batch_size, targets.size(1)) mask_batches = mask.reshape(-1, args.physical_batch_size, mask.size(1)) @@ -673,7 +671,7 @@ def predict(model, quizzes, local_device=main_device): zip(input_batches, targets_batches, mask_batches), dynamic_ncols=True, desc="predict", - total=quizzes.size(0) // args.physical_batch_size, + total=input.size(0) // args.physical_batch_size, ): # noise = quiz_machine.problem.pure_noise(input.size(0), input.device) input = (1 - mask) * input # + mask * noise @@ -806,20 +804,24 @@ def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True): def one_train_test_epoch(model, n_epoch, c_quizzes, local_device=main_device): one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True) - one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=False) quizzes = quiz_machine.quiz_set(150, c_quizzes, args.c_quiz_multiplier) - result = predict(model, quizzes).to("cpu") - + input, targets, mask = batch_prediction(quizzes.to(local_device)) + result = predict(model, input, targets, mask).to("cpu") + mask = mask.to("cpu") + correct = (quizzes == result).min(dim=1).values.long() + correct_parts = (2 * correct - 1)[:, None] * mask.reshape(mask.size(0), 4, -1)[ + :, :, 1 + ] quiz_machine.problem.save_quizzes_as_image( args.result_dir, f"culture_prediction_{n_epoch}_{model.id}.png", quizzes=result[:128], + correct_parts=correct_parts[:128], ) - nb_correct = (quizzes == result).min(dim=1).values.long().sum() - model.test_accuracy = nb_correct / quizzes.size(0) + model.test_accuracy = correct.sum() / quizzes.size(0) ###################################################################### -- 2.39.5