From 039523af4eb069f3a2293d5266a1cac4867567f4 Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Sun, 1 Sep 2024 19:20:47 +0200
Subject: [PATCH] Update.

---
Notes:

 * Lower the default --inference_batch_size from 50 to 25.
 * Re-enable the run_ae_test call at the end of one_ae_epoch.
 * In generate_ae_c_quizzes, compute wanted_nb as
   args.nb_train_samples // args.c_quiz_multiplier and save 256 c-quizzes
   instead of the hard-coded 128 / 128.
 * Log time_train and time_c_quizzes at each iteration of the main epoch loop.

 main.py | 13 ++++++-------
 1 file changed, 6 insertions(+), 7 deletions(-)

diff --git a/main.py b/main.py
index b87518e..9801702 100755
--- a/main.py
+++ b/main.py
@@ -51,7 +51,7 @@ parser.add_argument("--batch_size", type=int, default=25)
 
 parser.add_argument("--physical_batch_size", type=int, default=None)
 
-parser.add_argument("--inference_batch_size", type=int, default=50)
+parser.add_argument("--inference_batch_size", type=int, default=25)
 
 parser.add_argument("--nb_train_samples", type=int, default=40000)
 
@@ -1282,7 +1282,7 @@ def one_ae_epoch(
         f"train_loss {n_epoch} model {model.id} {acc_train_loss/nb_train_samples}"
     )
 
-    # run_ae_test(model, other_models, quiz_machine, n_epoch, local_device=local_device)
+    run_ae_test(model, other_models, quiz_machine, n_epoch, local_device=local_device)
 
 
 ######################################################################
@@ -1360,11 +1360,8 @@ def generate_ae_c_quizzes(models, local_device=main_device):
 
     duration_max = 4 * 3600
 
-    wanted_nb = 128
-    nb_to_save = 128
-
-    # wanted_nb = args.nb_train_samples // args.c_quiz_multiplier
-    # nb_to_save = 256
+    wanted_nb = args.nb_train_samples // args.c_quiz_multiplier
+    nb_to_save = 256
 
     with torch.autograd.no_grad():
         records = [[] for _ in criteria]
@@ -1524,6 +1521,8 @@ for n_epoch in range(current_epoch, args.nb_epochs):
     # one_ae_epoch(models[0], models, quiz_machine, n_epoch, main_device)
     # exit(0)
 
+    log_string(f"{time_train=} {time_c_quizzes=}")
+
     if (
         n_epoch >= 200
        and min([m.test_accuracy for m in models]) > args.accuracy_to_make_c_quizzes
-- 
2.39.5
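
For reference, the log line added in the last hunk relies on Python 3.8+
self-documenting f-strings. Below is a minimal sketch of the resulting output,
with a plain print standing in for main.py's log_string helper and made-up
timing values:

    # Minimal sketch only: plain print stands in for main.py's log_string
    # helper, and the timing values are hypothetical.

    def log_string(s: str) -> None:
        print(s)

    time_train, time_c_quizzes = 1234.5, 678.9  # hypothetical durations (seconds)

    # f"{x=}" (Python 3.8+) expands to "x=<repr(x)>", so the added line logs
    # both durations together with their variable names.
    log_string(f"{time_train=} {time_c_quizzes=}")
    # -> time_train=1234.5 time_c_quizzes=678.9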