From 64fdce6ab9ca5f9ec214782e24f4ccdb976336c7 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 18 Sep 2024 17:03:33 +0200 Subject: [PATCH] Update. --- main.py | 37 ++++++++++++++++++++++--------------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/main.py b/main.py index 704707d..182b907 100755 --- a/main.py +++ b/main.py @@ -114,9 +114,11 @@ parser.add_argument("--min_succeed_to_validate", type=int, default=2) parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.95) -parser.add_argument("--prompt_noise", type=float, default=0.05) +parser.add_argument("--prompt_noise_proba", type=float, default=0.05) -parser.add_argument("--nb_hints", type=int, default=25) +parser.add_argument("--hint_proba", type=float, default=0.01) + +# parser.add_argument("--nb_hints", type=int, default=25) parser.add_argument("--nb_runs", type=int, default=1) @@ -358,23 +360,26 @@ def optimizer_to(optim, device): def add_hints(imt_set): input, masks, targets = imt_set[:, 0], imt_set[:, 1], imt_set[:, 2] - h = torch.rand(masks.size(), device=masks.device) - masks - t = h.sort(dim=1).values[:, args.nb_hints, None] - mask_hints = (h < t).long() + # h = torch.rand(masks.size(), device=masks.device) - masks + # t = h.sort(dim=1).values[:, args.nb_hints, None] + # mask_hints = (h < t).long() + mask_hints = ( + torch.rand(input.size(), device=input.device) < args.hint_proba + ).long() * masks masks = (1 - mask_hints) * masks input = (1 - mask_hints) * input + mask_hints * targets return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1) # Make pixels from the available input (mask=0) noise with probability -# args.prompt_noise +# args.prompt_noise_proba def add_noise(imt_set): input, masks, targets = imt_set[:, 0], imt_set[:, 1], imt_set[:, 2] noise = quiz_machine.pure_noise(input.size(0), input.device) change = (1 - masks) * ( - torch.rand(input.size(), device=input.device) < args.prompt_noise + torch.rand(input.size(), device=input.device) < args.prompt_noise_proba ).long() input = (1 - change) * input + change * noise return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1) @@ -403,14 +408,14 @@ def ae_predict(model, imt_set, local_device=main_device, desc="predict"): record = [] - src = imt_set.split(args.train_batch_size) + src = imt_set.split(args.eval_batch_size) if desc is not None: src = tqdm.tqdm( src, dynamic_ncols=True, desc=desc, - total=imt_set.size(0) // args.train_batch_size, + total=imt_set.size(0) // args.eval_batch_size, ) for imt in src: @@ -502,9 +507,9 @@ def ae_generate(model, nb, local_device=main_device): sub_changed = all_changed[all_changed].clone() src = zip( - sub_input.split(args.train_batch_size), - sub_masks.split(args.train_batch_size), - sub_changed.split(args.train_batch_size), + sub_input.split(args.eval_batch_size), + sub_masks.split(args.eval_batch_size), + sub_changed.split(args.eval_batch_size), ) for input, masks, changed in src: @@ -549,17 +554,19 @@ def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True): label = "train" model.train().to(local_device) optimizer_to(model.optimizer, local_device) + batch_size = args.train_batch_size else: label = "test" model.eval().to(local_device) + batch_size = args.eval_batch_size nb_samples, acc_loss = 0, 0.0 for imt in tqdm.tqdm( - imt_set.split(args.train_batch_size), + imt_set.split(batch_size), dynamic_ncols=True, desc=label, - total=quizzes.size(0) // args.train_batch_size, + total=quizzes.size(0) // batch_size, ): input, masks, targets = imt[:, 0], imt[:, 1], imt[:, 2] if train and nb_samples % args.batch_size == 0: @@ -716,7 +723,7 @@ def generate_c_quizzes(models, nb_to_generate, local_device=main_device): generator_id = model.id c_quizzes = ae_generate( - model=model, nb=args.train_batch_size * 10, local_device=local_device + model=model, nb=args.eval_batch_size * 10, local_device=local_device ) # Select the ones that are solved properly by some models and -- 2.39.5