From 75c9766f18d6e79422437c147a4db83f95692fe4 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 18 Sep 2024 16:55:47 +0200 Subject: [PATCH] Update. --- main.py | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/main.py b/main.py index a357687..704707d 100755 --- a/main.py +++ b/main.py @@ -50,9 +50,9 @@ parser.add_argument("--nb_epochs", type=int, default=10000) parser.add_argument("--batch_size", type=int, default=25) -parser.add_argument("--physical_batch_size", type=int, default=None) +parser.add_argument("--train_batch_size", type=int, default=None) -parser.add_argument("--inference_batch_size", type=int, default=25) +parser.add_argument("--eval_batch_size", type=int, default=25) parser.add_argument("--nb_train_samples", type=int, default=50000) @@ -273,10 +273,10 @@ else: assert len(gpus) == 0 main_device = torch.device("cpu") -if args.physical_batch_size is None: - args.physical_batch_size = args.batch_size +if args.train_batch_size is None: + args.train_batch_size = args.batch_size else: - assert args.batch_size % args.physical_batch_size == 0 + assert args.batch_size % args.train_batch_size == 0 assert args.nb_train_samples % args.batch_size == 0 assert args.nb_test_samples % args.batch_size == 0 @@ -294,7 +294,7 @@ alien_problem = grids.Grids( alien_quiz_machine = quiz_machine.QuizMachine( problem=alien_problem, - batch_size=args.inference_batch_size, + batch_size=args.eval_batch_size, result_dir=args.result_dir, logger=log_string, device=main_device, @@ -315,7 +315,7 @@ if not args.resume: quiz_machine = quiz_machine.QuizMachine( problem=problem, - batch_size=args.inference_batch_size, + batch_size=args.eval_batch_size, result_dir=args.result_dir, logger=log_string, device=main_device, @@ -403,14 +403,14 @@ def ae_predict(model, imt_set, local_device=main_device, desc="predict"): record = [] - src = imt_set.split(args.physical_batch_size) + src = imt_set.split(args.train_batch_size) if desc is not None: src = tqdm.tqdm( src, dynamic_ncols=True, desc=desc, - total=imt_set.size(0) // args.physical_batch_size, + total=imt_set.size(0) // args.train_batch_size, ) for imt in src: @@ -492,6 +492,8 @@ def ae_generate(model, nb, local_device=main_device): all_changed = torch.full((all_input.size(0),), True, device=all_input.device) for it in range(args.diffusion_nb_iterations): + log_string(f"nb_changed {all_changed.long().sum().item()}") + if not all_changed.any(): break @@ -500,9 +502,9 @@ def ae_generate(model, nb, local_device=main_device): sub_changed = all_changed[all_changed].clone() src = zip( - sub_input.split(args.physical_batch_size), - sub_masks.split(args.physical_batch_size), - sub_changed.split(args.physical_batch_size), + sub_input.split(args.train_batch_size), + sub_masks.split(args.train_batch_size), + sub_changed.split(args.train_batch_size), ) for input, masks, changed in src: @@ -554,10 +556,10 @@ def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True): nb_samples, acc_loss = 0, 0.0 for imt in tqdm.tqdm( - imt_set.split(args.physical_batch_size), + imt_set.split(args.train_batch_size), dynamic_ncols=True, desc=label, - total=quizzes.size(0) // args.physical_batch_size, + total=quizzes.size(0) // args.train_batch_size, ): input, masks, targets = imt[:, 0], imt[:, 1], imt[:, 2] if train and nb_samples % args.batch_size == 0: @@ -714,7 +716,7 @@ def generate_c_quizzes(models, nb_to_generate, local_device=main_device): generator_id = model.id c_quizzes = ae_generate( - model=model, nb=args.physical_batch_size * 10, local_device=local_device + model=model, nb=args.train_batch_size * 10, local_device=local_device ) # Select the ones that are solved properly by some models and -- 2.39.5