From 3ea461d38320c526a38022f05d69fa266c57afb5 Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Wed, 11 Sep 2024 08:52:17 +0200
Subject: [PATCH] Update.

---
 main.py | 124 +++++++++++++++++++++++++++++++-------------------------
 1 file changed, 69 insertions(+), 55 deletions(-)

diff --git a/main.py b/main.py
index c1ef5bc..ed83a5c 100755
--- a/main.py
+++ b/main.py
@@ -111,20 +111,12 @@ parser.add_argument("--diffusion_epsilon", type=float, default=0.05)
 
 parser.add_argument("--min_succeed_to_validate", type=int, default=2)
 
-parser.add_argument("--max_fail_to_validate", type=int, default=3)
-
 parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.95)
 
-parser.add_argument("--proba_understands", type=float, default=0.95)
-
-parser.add_argument("--proba_not_understands", type=float, default=0.1)
-
-parser.add_argument("--temperature_hot", type=float, default=1.5)
-
-parser.add_argument("--temperature_cold", type=float, default=1)
-
 parser.add_argument("--prompt_noise", type=float, default=0.05)
 
+parser.add_argument("--nb_hints", type=int, default=5)
+
 parser.add_argument("--dirty_debug", action="store_true", default=False)
 
 parser.add_argument("--test", type=str, default=None)
@@ -705,19 +697,6 @@ def logits_hat_x_0_from_random_iteration(model, x_0, mask_generate, prompt_noise
 
     x_t = (1 - mask_generate) * x_0 + mask_generate * x_t
 
-    #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-    # filename = f"debug.png"
-
-    # quiz_machine.problem.save_quizzes_as_image(
-    # args.result_dir,
-    # filename,
-    # quizzes=x_t,
-    # )
-
-    # log_string(f"wrote {filename}")
-    # exit(0)
-    #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-
     # We may inject noise to prevent high-complexity non-structure
     # signal to be generated as a way of "increasing reasoning
     # complexity"
@@ -741,10 +720,14 @@ def logits_hat_x_0_from_random_iteration(model, x_0, mask_generate, prompt_noise
 ######################################################################
 
 
-def ae_generate(model, x_0, mask_generate, nb_iterations_max=50):
+def ae_generate(model, x_0, mask_generate, nb_iterations_max=50, mask_hints=None):
     noise = torch.randint(quiz_machine.problem.nb_colors, x_0.size(), device=x_0.device)
 
-    x_t = (1 - mask_generate) * x_0 + mask_generate * noise
+    if mask_hints is None:
+        x_t = (1 - mask_generate) * x_0 + mask_generate * noise
+    else:
+        mask = mask_generate * (1 - mask_hints)
+        x_t = (1 - mask) * x_0 + mask * noise
 
     one_iteration_prediction = deterministic(mask_generate)[:, None]
 
@@ -925,7 +908,7 @@ def run_ae_test(
 
     # Save some images
 
-    if n_epoch < 50:
+    if n_epoch < 100:
         for f, record in [("prediction", record_d), ("generation", record_nd)]:
             result, predicted_parts, correct_parts = bag_to_tensors(record)
 
@@ -1062,28 +1045,52 @@ def save_badness_statistics(
 ######################################################################
 
 
-def quiz_validation(models, c_quizzes, local_device):
-    nb_have_to_be_correct = 3
-    nb_have_to_be_wrong = 1
-    nb_mistakes_to_be_wrong = 5
-
+def quiz_validation(
+    models,
+    c_quizzes,
+    local_device,
+    nb_have_to_be_correct=3,
+    nb_have_to_be_not_correct=0,
+    nb_have_to_be_wrong=1,
+    nb_mistakes_to_be_wrong=5,
+    nb_hints=0,
+    nb_runs=1,
+):
     record_wrong = []
     nb_correct, nb_wrong = 0, 0
 
     for i, model in enumerate(models):
         assert i == model.id  # a bit of paranoia
         model = copy.deepcopy(model).to(local_device).eval()
 
-        correct, wrong = True, False
-
         for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
             mask_generate = quiz_machine.make_quiz_mask(
-                quizzes=c_quizzes, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
+                quizzes=c_quizzes,
+                quad_order=("A", "f_A", "B", "f_B"),
+                quad_mask=quad,
             )
-            result = ae_generate(model, (1 - mask_generate) * c_quizzes, mask_generate)
-            nb_mistakes = (result != c_quizzes).long().sum(dim=1)
-            correct = correct & (nb_mistakes == 0)
-            wrong = wrong | (nb_mistakes >= nb_mistakes_to_be_wrong)
+            for _ in range(nb_runs):
+                if nb_hints == 0:
+                    mask_hints = None
+                else:
+                    u = (
+                        torch.rand(mask_generate.size(), device=mask_generate.device)
+                        * mask_generate
+                    )
+                    mask_hints = (
+                        u > u.sort(dim=1, descending=True).values[:, nb_hints, None]
+                    ).long()
+
+                result = ae_generate(
+                    model=model,
+                    x_0=(1 - mask_generate) * c_quizzes,
+                    mask_generate=mask_generate,
+                    mask_hints=mask_hints,
+                )
+
+                nb_mistakes = (result != c_quizzes).long().sum(dim=1)
+                correct = correct & (nb_mistakes == 0)
+                wrong = wrong | (nb_mistakes >= nb_mistakes_to_be_wrong)
 
         record_wrong.append(wrong[:, None])
         nb_correct += correct.long()
@@ -1131,25 +1138,13 @@ def generate_ae_c_quizzes(models, nb, local_device=main_device):
 
         c_quizzes = ae_generate(model, template, mask_generate)
 
-        #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-        ## for quad in [(0, 1, 0, 0), (0, 0, 0, 1)]:
-        ## mask_generate = quiz_machine.make_quiz_mask(
-        ## quizzes=c_quizzes,
-        ## quad_order=("A", "f_A", "B", "f_B"),
-        ## quad_mask=quad,
-        ## )
-        ## c_quizzes = ae_generate(
-        ## model,
-        ## (1 - mask_generate) * c_quizzes,
-        ## mask_generate,
-        ## )
-        #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-
        to_keep = quiz_machine.problem.trivial(c_quizzes) == False
         c_quizzes = c_quizzes[to_keep]
 
         if c_quizzes.size(0) > 0:
-            to_keep, record_wrong = quiz_validation(models, c_quizzes, local_device)
+            to_keep, record_wrong = quiz_validation(
+                models, c_quizzes, local_device, nb_hints=args.nb_hints
+            )
             q = c_quizzes[to_keep]
 
             if q.size(0) > 0:
@@ -1195,9 +1190,24 @@ def thread_generate_ae_c_quizzes(models, nb, record, local_device=main_device):
 ######################################################################
 
 
-def save_c_quizzes_with_scores(models, c_quizzes, filename):
+def save_c_quizzes_with_scores(models, c_quizzes, nb, filename, solvable_only=False):
     l = []
 
+    if solvable_only:
+        to_keep, _ = quiz_validation(
+            models,
+            c_quizzes,
+            main_device,
+            nb_have_to_be_correct=1,
+            nb_have_to_be_wrong=0,
+            nb_hints=0,
+        )
+        c_quizzes = c_quizzes[to_keep]
+
+    c_quizzes = c_quizzes[
+        torch.randperm(c_quizzes.size(0), device=c_quizzes.device)[:nb]
+    ]
+
     with torch.autograd.no_grad():
         for model in models:
             model = copy.deepcopy(model).to(main_device).eval()
@@ -1389,7 +1399,11 @@ for n_epoch in range(current_epoch, args.nb_epochs):
         # --------------------------------------------------------------------
 
         filename = f"culture_c_quiz_{n_epoch:04d}.png"
-        save_c_quizzes_with_scores(models, c_quizzes[:128], filename)
+        save_c_quizzes_with_scores(
+            models, c_quizzes, 256, filename, solvable_only=False
+        )
+        filename = f"culture_c_quiz_{n_epoch:04d}_solvable.png"
+        save_c_quizzes_with_scores(models, c_quizzes, 256, filename, solvable_only=True)
 
         log_string(f"generated_c_quizzes {c_quizzes.size()=}")
 
-- 
2.39.5