From 2192aea5279309636ca9118aead311f7cebd29ad Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 18 Sep 2024 17:05:20 +0200 Subject: [PATCH] Update. --- main.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/main.py b/main.py index 182b907..4a44fd3 100755 --- a/main.py +++ b/main.py @@ -114,9 +114,9 @@ parser.add_argument("--min_succeed_to_validate", type=int, default=2) parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.95) -parser.add_argument("--prompt_noise_proba", type=float, default=0.05) +parser.add_argument("--proba_prompt_noise", type=float, default=0.05) -parser.add_argument("--hint_proba", type=float, default=0.01) +parser.add_argument("--proba_hint", type=float, default=0.01) # parser.add_argument("--nb_hints", type=int, default=25) @@ -364,7 +364,7 @@ def add_hints(imt_set): # t = h.sort(dim=1).values[:, args.nb_hints, None] # mask_hints = (h < t).long() mask_hints = ( - torch.rand(input.size(), device=input.device) < args.hint_proba + torch.rand(input.size(), device=input.device) < args.proba_hint ).long() * masks masks = (1 - mask_hints) * masks input = (1 - mask_hints) * input + mask_hints * targets @@ -372,14 +372,14 @@ def add_hints(imt_set): # Make pixels from the available input (mask=0) noise with probability -# args.prompt_noise_proba +# args.proba_prompt_noise def add_noise(imt_set): input, masks, targets = imt_set[:, 0], imt_set[:, 1], imt_set[:, 2] noise = quiz_machine.pure_noise(input.size(0), input.device) change = (1 - masks) * ( - torch.rand(input.size(), device=input.device) < args.prompt_noise_proba + torch.rand(input.size(), device=input.device) < args.proba_prompt_noise ).long() input = (1 - change) * input + change * noise return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1) -- 2.39.5