Rename --physical_batch_size to --train_batch_size and --inference_batch_size to --eval_batch_size; log the number of changed sequences at each diffusion iteration.
author	François Fleuret <francois@fleuret.org>
	Wed, 18 Sep 2024 14:55:47 +0000 (16:55 +0200)
committer	François Fleuret <francois@fleuret.org>
	Wed, 18 Sep 2024 14:55:47 +0000 (16:55 +0200)
main.py

diff --git a/main.py b/main.py
index a357687..704707d 100755
--- a/main.py
+++ b/main.py
@@ -50,9 +50,9 @@ parser.add_argument("--nb_epochs", type=int, default=10000)
 
 parser.add_argument("--batch_size", type=int, default=25)
 
-parser.add_argument("--physical_batch_size", type=int, default=None)
+parser.add_argument("--train_batch_size", type=int, default=None)
 
-parser.add_argument("--inference_batch_size", type=int, default=25)
+parser.add_argument("--eval_batch_size", type=int, default=25)
 
 parser.add_argument("--nb_train_samples", type=int, default=50000)
 
@@ -273,10 +273,10 @@ else:
     assert len(gpus) == 0
     main_device = torch.device("cpu")
 
-if args.physical_batch_size is None:
-    args.physical_batch_size = args.batch_size
+if args.train_batch_size is None:
+    args.train_batch_size = args.batch_size
 else:
-    assert args.batch_size % args.physical_batch_size == 0
+    assert args.batch_size % args.train_batch_size == 0
 
 assert args.nb_train_samples % args.batch_size == 0
 assert args.nb_test_samples % args.batch_size == 0
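The rename keeps the old --physical_batch_size semantics: --batch_size is the logical batch and --train_batch_size is the chunk that actually goes through the model, and the assert above guarantees the split is uniform. A minimal standalone sketch of that relation (the values and tensor shape are illustrative, not the script's defaults):

import torch

batch_size = 25        # logical batch, optimizer-step granularity
train_batch_size = 5   # physical chunk fed to the model
assert batch_size % train_batch_size == 0

quizzes = torch.zeros(batch_size, 404)  # illustrative shape
chunks = quizzes.split(train_batch_size)
assert all(c.size(0) == train_batch_size for c in chunks)  # no ragged last chunk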
@@ -294,7 +294,7 @@ alien_problem = grids.Grids(
 
 alien_quiz_machine = quiz_machine.QuizMachine(
     problem=alien_problem,
-    batch_size=args.inference_batch_size,
+    batch_size=args.eval_batch_size,
     result_dir=args.result_dir,
     logger=log_string,
     device=main_device,
@@ -315,7 +315,7 @@ if not args.resume:
 
 quiz_machine = quiz_machine.QuizMachine(
     problem=problem,
-    batch_size=args.inference_batch_size,
+    batch_size=args.eval_batch_size,
     result_dir=args.result_dir,
     logger=log_string,
     device=main_device,
@@ -403,14 +403,14 @@ def ae_predict(model, imt_set, local_device=main_device, desc="predict"):
 
     record = []
 
-    src = imt_set.split(args.physical_batch_size)
+    src = imt_set.split(args.train_batch_size)
 
     if desc is not None:
         src = tqdm.tqdm(
             src,
             dynamic_ncols=True,
             desc=desc,
-            total=imt_set.size(0) // args.physical_batch_size,
+            total=imt_set.size(0) // args.train_batch_size,
         )
 
     for imt in src:
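For context, Tensor.split(n) returns views of at most n rows, so the tqdm wrapper above needs an explicit total; because the earlier asserts make the set size a multiple of the chunk size, imt_set.size(0) // args.train_batch_size is exact. A self-contained sketch of the same pattern (shape and names are illustrative):

import torch, tqdm

train_batch_size = 25
imt_set = torch.zeros(100, 3, 404)  # illustrative shape

src = tqdm.tqdm(
    imt_set.split(train_batch_size),            # tuple of (25, 3, 404) views
    dynamic_ncols=True,
    desc="predict",
    total=imt_set.size(0) // train_batch_size,  # 4 steps
)

for imt in src:
    pass  # the forward pass would go here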
@@ -492,6 +492,8 @@ def ae_generate(model, nb, local_device=main_device):
     all_changed = torch.full((all_input.size(0),), True, device=all_input.device)
 
     for it in range(args.diffusion_nb_iterations):
+        log_string(f"nb_changed {all_changed.long().sum().item()}")
+
         if not all_changed.any():
             break
 
@@ -500,9 +502,9 @@ def ae_generate(model, nb, local_device=main_device):
         sub_changed = all_changed[all_changed].clone()
 
         src = zip(
-            sub_input.split(args.physical_batch_size),
-            sub_masks.split(args.physical_batch_size),
-            sub_changed.split(args.physical_batch_size),
+            sub_input.split(args.train_batch_size),
+            sub_masks.split(args.train_batch_size),
+            sub_changed.split(args.train_batch_size),
         )
 
         for input, masks, changed in src:
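Splitting several same-length tensors with the same chunk size and zipping the results, as above, keeps the k-th chunk of each tensor aligned on the same rows. A self-contained sketch (shapes are illustrative):

import torch

n, chunk = 100, 25
sub_input = torch.zeros(n, 404, dtype=torch.long)
sub_masks = torch.zeros(n, 404, dtype=torch.long)
sub_changed = torch.ones(n, dtype=torch.bool)

for input, masks, changed in zip(
    sub_input.split(chunk), sub_masks.split(chunk), sub_changed.split(chunk)
):
    assert input.size(0) == masks.size(0) == changed.size(0)  # rows stay aligned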
@@ -554,10 +556,10 @@ def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True):
     nb_samples, acc_loss = 0, 0.0
 
     for imt in tqdm.tqdm(
-        imt_set.split(args.physical_batch_size),
+        imt_set.split(args.train_batch_size),
         dynamic_ncols=True,
         desc=label,
-        total=quizzes.size(0) // args.physical_batch_size,
+        total=quizzes.size(0) // args.train_batch_size,
     ):
         input, masks, targets = imt[:, 0], imt[:, 1], imt[:, 2]
         if train and nb_samples % args.batch_size == 0:
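The training path above walks through physical chunks but resets and steps the optimizer only when nb_samples crosses a logical-batch boundary, i.e. standard gradient accumulation. A minimal sketch of that pattern, assuming the usual zero_grad/backward/step sequence (model, optimizer and data here are illustrative placeholders, not the script's objects):

import torch

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
batch_size, train_batch_size = 25, 5

data = torch.randn(100, 8)
nb_samples = 0
for chunk in data.split(train_batch_size):
    if nb_samples % batch_size == 0:
        optimizer.zero_grad()
    loss = model(chunk).pow(2).mean()
    loss.backward()              # gradients accumulate across chunks
    nb_samples += chunk.size(0)
    if nb_samples % batch_size == 0:
        optimizer.step()         # one step per logical batch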
@@ -714,7 +716,7 @@ def generate_c_quizzes(models, nb_to_generate, local_device=main_device):
         generator_id = model.id
 
         c_quizzes = ae_generate(
-            model=model, nb=args.physical_batch_size * 10, local_device=local_device
+            model=model, nb=args.train_batch_size * 10, local_device=local_device
         )
 
         # Select the ones that are solved properly by some models and