From: François Fleuret Date: Wed, 25 Sep 2024 20:21:56 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=0cc6f33e58c9c55509fd0d5cf9f4a487290ee2d9;p=culture.git Update. --- diff --git a/main.py b/main.py index 230453f..7af281c 100755 --- a/main.py +++ b/main.py @@ -99,7 +99,7 @@ parser.add_argument("--proba_plasticity", type=float, default=0.0) parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.95) -parser.add_argument("--proba_prompt_noise", type=float, default=0.05) +parser.add_argument("--proba_input_noise", type=float, default=0.05) parser.add_argument("--proba_hints", type=float, default=0.25) @@ -319,10 +319,10 @@ def generate_quiz_set(nb_samples, c_quizzes, c_quiz_multiplier=1): ###################################################################### -def add_hints_imt(imt_set): - """Set every component of the mask to zero with probability - args.proba_hints, and for each component set to zero, copy the - corresponding value from the target into the input +def add_hints_imt(imt_set, proba_hints): + """Set every component of the mask to zero with probability proba, + and for each component set to zero, copy the corresponding value + from the target into the input """ input, masks, targets = imt_set.unbind(dim=1) @@ -330,7 +330,7 @@ def add_hints_imt(imt_set): # t = h.sort(dim=1).values[:, args.nb_hints, None] # mask_hints = (h < t).long() mask_hints = ( - torch.rand(input.size(), device=input.device) < args.proba_hints + torch.rand(input.size(), device=input.device) < proba_hints ).long() * masks masks = (1 - mask_hints) * masks @@ -338,13 +338,15 @@ def add_hints_imt(imt_set): return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1) -def add_noise_imt(imt_set): - """Replace every component of the input by a random value with - probability args.proba_prompt_noise.""" +def add_input_noise_imt(imt_set, proba_input_noise): + """Replace every component of the non-masked input by a random + value with probability proba. + + """ input, masks, targets = imt_set.unbind(dim=1) noise = problem.pure_noise(input.size(0), input.device) change = (1 - masks) * ( - torch.rand(input.size(), device=input.device) < args.proba_prompt_noise + torch.rand(input.size(), device=input.device) < proba_input_noise ).long() input = (1 - change) * input + change * noise return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1) @@ -393,7 +395,7 @@ def ae_predict(model, imt_set, local_device=main_device): def predict_the_four_grids( - model, input, with_noise=False, with_hints=False, local_device=main_device + model, input, proba_input_noise, proba_hints, local_device=main_device ): input = input[:, None, :].expand(-1, 4, -1).reshape(-1, input.size(1)) nb = input.size(0) @@ -404,11 +406,11 @@ def predict_the_four_grids( input = (1 - masks) * targets imt_set = torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1) - if with_hints: - imt_set = add_hints_imt(imt_set) + if proba_hints > 0: + imt_set = add_hints_imt(imt_set, proba_hints) - if with_noise: - imt_set = add_noise_imt(imt_set) + if proba_input_noise > 0: + imt_set = add_input_noise_imt(imt_set, proba_input_noise) result = ae_predict(model, imt_set, local_device=local_device) result = (result * masks).reshape(-1, 4, result.size(1)).sum(dim=1) @@ -512,9 +514,9 @@ def one_epoch(model, n_epoch, c_quizzes, train=True, local_device=main_device): # complexity, and hints in half to allow dealing with hints when # validating c quizzes b_p = samples_for_prediction_imt(q_p) - b_p = add_noise_imt(b_p) + b_p = add_input_noise_imt(b_p, args.proba_input_noise) half = torch.rand(b_p.size(0)) < 0.5 - b_p[half] = add_hints_imt(b_p[half]) + b_p[half] = add_hints_imt(b_p[half], args.proba_hints) # The other half are denoising examples for the generation b_g = samples_for_generation_imt(q_g) @@ -661,8 +663,8 @@ def evaluate_quizzes(quizzes, models, with_hints, local_device): predicted = predict_the_four_grids( model=model, input=quizzes, - with_noise=False, - with_hints=with_hints, + proba_input_noise=0.0, + proba_hints=args.proba_hints, local_device=local_device, ) nb_mistakes = max_nb_mistakes_on_one_grid(quizzes, predicted) @@ -748,7 +750,6 @@ def generate_c_quizzes(models, nb_to_generate, local_device=main_device): duration = time.perf_counter() - start_time - log_string(f"generate_c_quizz_speed {int(3600 * nb_validated / duration)}/h") log_string( f"validation_rate {nb_validated} / {nb_generated} ({nb_validated*100/nb_generated:.02e}%)" ) @@ -1039,7 +1040,7 @@ for n_epoch in range(current_epoch, args.nb_epochs): models, new_c_quizzes[:256], f"culture_c_quiz_{n_epoch:04d}.png" ) - log_string(f"generated_c_quizzes {new_c_quizzes.size()}") + # log_string(f"generated_c_quizzes {new_c_quizzes.size()}") train_c_quizzes = ( new_c_quizzes