From 4664060a00e406f21d8daa265986b0a418ecc737 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 15 Sep 2024 13:07:23 +0200 Subject: [PATCH] Update. --- main.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/main.py b/main.py index 49799e4..1461ab1 100755 --- a/main.py +++ b/main.py @@ -59,7 +59,7 @@ parser.add_argument("--physical_batch_size", type=int, default=None) parser.add_argument("--inference_batch_size", type=int, default=25) -parser.add_argument("--nb_train_samples", type=int, default=100000) +parser.add_argument("--nb_train_samples", type=int, default=50000) parser.add_argument("--nb_test_samples", type=int, default=1000) @@ -740,7 +740,7 @@ def quiz_validation( wrong = torch.cat(record_wrong, dim=1) - return to_keep, wrong + return to_keep, nb_correct, nb_wrong, wrong ###################################################################### @@ -782,7 +782,7 @@ def generate_ae_c_quizzes(models, nb, local_device=main_device): c_quizzes = c_quizzes[to_keep] if c_quizzes.size(0) > 0: - to_keep, record_wrong = quiz_validation( + to_keep, nb_correct, nb_wrong, record_wrong = quiz_validation( models, c_quizzes, local_device, @@ -840,7 +840,7 @@ def save_c_quizzes_with_scores(models, c_quizzes, filename, solvable_only=False) with torch.autograd.no_grad(): if solvable_only: - to_keep, _ = quiz_validation( + to_keep, nb_correct, nb_wrong, record_wrong = quiz_validation( models, c_quizzes, main_device, @@ -850,16 +850,10 @@ def save_c_quizzes_with_scores(models, c_quizzes, filename, solvable_only=False) ) c_quizzes = c_quizzes[to_keep] - for model in models: - model = copy.deepcopy(model).to(main_device).eval() - l.append(model_ae_proba_solutions(model, c_quizzes)) - - probas = torch.cat([x[:, None] for x in l], dim=1) - comments = [] - for l in probas: - comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l])) + for c, w in zip(nb_correct, nb_wrong): + comments.append("nb_correct {c} nb_wrong {w}") quiz_machine.problem.save_quizzes_as_image( args.result_dir, -- 2.39.5