Update.
authorFrançois Fleuret <francois@fleuret.org>
Sun, 1 Sep 2024 17:20:47 +0000 (19:20 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sun, 1 Sep 2024 17:20:47 +0000 (19:20 +0200)
main.py

diff --git a/main.py b/main.py
index b87518e..9801702 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -51,7 +51,7 @@ parser.add_argument("--batch_size", type=int, default=25)
 
 parser.add_argument("--physical_batch_size", type=int, default=None)
 
-parser.add_argument("--inference_batch_size", type=int, default=50)
+parser.add_argument("--inference_batch_size", type=int, default=25)
 
 parser.add_argument("--nb_train_samples", type=int, default=40000)
 
@@ -1282,7 +1282,7 @@ def one_ae_epoch(
         f"train_loss {n_epoch} model {model.id} {acc_train_loss/nb_train_samples}"
     )
 
-    run_ae_test(model, other_models, quiz_machine, n_epoch, local_device=local_device)
+    run_ae_test(model, other_models, quiz_machine, n_epoch, local_device=local_device)
 
 
 ######################################################################
@@ -1360,11 +1360,8 @@ def generate_ae_c_quizzes(models, local_device=main_device):
 
     duration_max = 4 * 3600
 
-    wanted_nb = 128
-    nb_to_save = 128
-
-    # wanted_nb = args.nb_train_samples // args.c_quiz_multiplier
-    # nb_to_save = 256
+    wanted_nb = args.nb_train_samples // args.c_quiz_multiplier
+    nb_to_save = 256
 
     with torch.autograd.no_grad():
         records = [[] for _ in criteria]
@@ -1524,6 +1521,8 @@ for n_epoch in range(current_epoch, args.nb_epochs):
     # one_ae_epoch(models[0], models, quiz_machine, n_epoch, main_device)
     # exit(0)
 
+    log_string(f"{time_train=} {time_c_quizzes=}")
+
     if (
         n_epoch >= 200
         and min([m.test_accuracy for m in models]) > args.accuracy_to_make_c_quizzes