From 84b4312c224c75a6bd8355286d634c52ace22600 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 3 Sep 2024 20:20:30 +0200 Subject: [PATCH] Update. --- main.py | 287 +++----------------------------------------------------- 1 file changed, 14 insertions(+), 273 deletions(-) diff --git a/main.py b/main.py index 6113813..61fc090 100755 --- a/main.py +++ b/main.py @@ -68,6 +68,7 @@ parser.add_argument("--learning_rate", type=float, default=5e-4) parser.add_argument("--reboot", action="store_true", default=False) # ---------------------------------- + parser.add_argument("--model", type=str, default="37M") parser.add_argument("--dim_model", type=int, default=None) @@ -83,6 +84,7 @@ parser.add_argument("--nb_blocks", type=int, default=None) parser.add_argument("--dropout", type=float, default=0.5) # ---------------------------------- + parser.add_argument("--deterministic_synthesis", action="store_true", default=False) parser.add_argument("--problem", type=str, default="grids") @@ -103,7 +105,7 @@ parser.add_argument("--min_succeed_to_validate", type=int, default=2) parser.add_argument("--max_fail_to_validate", type=int, default=3) -parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.95) +parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.98) parser.add_argument("--proba_understands", type=float, default=0.95) @@ -119,8 +121,6 @@ parser.add_argument("--dirty_debug", action="store_true", default=False) parser.add_argument("--test", type=str, default=None) -parser.add_argument("--logit_std_max", type=float, default=-1) - ###################################################################### grids_tasks = ", ".join( @@ -341,275 +341,6 @@ def optimizer_to(optim, device): subparam._grad.data = subparam._grad.data.to(device) -###################################################################### - - -def run_tests(model, quiz_machine, local_device=main_device): - with torch.autograd.no_grad(): - model.to(local_device).eval() - - nb_test_samples, acc_test_loss = 0, 0.0 - nb_samples_accumulated = 0 - - full_input, _, full_mask_loss = quiz_machine.data_input( - args.nb_test_samples, model.test_c_quiz_bags, args.c_quiz_multiplier - ) - src = zip( - full_input.split(args.batch_size), full_mask_loss.split(args.batch_size) - ) - - for input, mask_loss in tqdm.tqdm( - src, - dynamic_ncols=True, - desc="test", - total=full_input.size(0) // args.batch_size, - ): - input = input.to(local_device) - mask_loss = mask_loss.to(local_device) - targets = input - - output = model(input) - loss_per_token = F.cross_entropy( - output.transpose(1, 2), targets, reduction="none" - ) - loss = (loss_per_token * mask_loss).mean() - acc_test_loss += loss.item() * input.size(0) - nb_test_samples += input.size(0) - - test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples)) - - log_string(f"test_perplexity {n_epoch} model {model.id} {test_perplexity}") - - input, _, _ = quiz_machine.data_input( - 2000, model.test_c_quiz_bags, args.c_quiz_multiplier - ) - - model.test_accuracy = quiz_machine.produce_results( - n_epoch=n_epoch, - model=model, - input=input, - result_dir=args.result_dir, - ) - - -###################################################################### - - -def one_epoch(model, quiz_machine, local_device=main_device): - model.to(local_device).train() - optimizer_to(model.optimizer, local_device) - - nb_train_samples, acc_train_loss = 0, 0.0 - - full_input, _, full_mask_loss = quiz_machine.data_input( - args.nb_train_samples, - model.train_c_quiz_bags + common_c_quiz_bags, - args.c_quiz_multiplier, - ) - src = zip(full_input.split(args.batch_size), full_mask_loss.split(args.batch_size)) - - for input, mask_loss in tqdm.tqdm( - src, - dynamic_ncols=True, - desc="training", - total=full_input.size(0) // args.batch_size, - ): - input = input.to(local_device) - mask_loss = mask_loss.to(local_device) - - if nb_train_samples % args.batch_size == 0: - model.optimizer.zero_grad() - - targets = input - output = model(input) - loss = F.cross_entropy(output.transpose(1, 2), targets, reduction="none") - loss = (loss * mask_loss).mean() + model.loss - - acc_train_loss += loss.item() * input.size(0) - nb_train_samples += input.size(0) - - loss.backward() - - if nb_train_samples % args.batch_size == 0: - model.optimizer.step() - - train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples)) - - log_string(f"train_perplexity {n_epoch} model {model.id} {train_perplexity}") - - run_tests(model, quiz_machine) - - model.to(main_device) - optimizer_to(model.optimizer, main_device) - - -###################################################################### - - -def model_modifier_hot(model): - model.temperature = args.temperature_hot - # model.set_noise_injection(1.0, ("ffw", args.nb_blocks // 2)) - - -def model_modifier_cold(model): - model.temperature = args.temperature_cold - # pass - - -c_quizzes_procedure = [ - (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_modifier_hot), - (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), model_modifier_cold), - (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_modifier_cold), - # (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), model_modifier_cold), -] - -# quad_order, quad_generate, quad_noise, quad_loss - -data_structures = [ - (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (0, 0, 0, 1)), - (("A", "f_A", "B", "f_B"), (0, 0, 1, 0), (0, 0, 0, 1), (0, 0, 1, 0)), - (("A", "f_A", "B", "f_B"), (0, 1, 0, 0), (1, 0, 0, 0), (0, 1, 0, 0)), - (("A", "f_A", "B", "f_B"), (1, 0, 0, 0), (0, 1, 0, 0), (1, 0, 0, 0)), - (("A", "f_A", "B", "f_B"), (1, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)), -] - -###################################################################### - - -def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test): - nb_validated, nb_to_validate = 0, (nb_for_train + nb_for_test) * len(models) - nb_generated, nb_to_generate_per_iteration = 0, nb_to_validate - - start_time = time.perf_counter() - - for model in models: - model.recorded_c_quizzes = [] - - teaching_count = torch.zeros(len(models), len(models), dtype=torch.int64) - - while nb_validated < nb_to_validate: - model_for_generation = models[torch.randint(len(models), (1,)).item()] - - # We generate quizzes with a procedure that injects some - # structured noise - - c_quizzes = quiz_machine.generate_c_quizzes( - nb_to_generate_per_iteration, - model_for_generation=model, - procedure=c_quizzes_procedure, - ) - - nb_generated += c_quizzes.size(0) - - # We discard the trivial ones, according to a criterion - # specific to the world quizzes (e.g. B=f(B)) - - to_keep = quiz_machine.problem.trivial(c_quizzes) == False - - c_quizzes = c_quizzes[to_keep] - - # Compute the responses of all the models on the c_quizzes, - # and their proba estimates of their responses - - solved_c_quizzes = c_quizzes[:, None, :].expand(-1, len(models), -1).clone() - - proba_own_solution = torch.zeros( - c_quizzes.size(0), len(models), device=solved_c_quizzes.device - ) - - for model in models: - (solved_c_quizzes[:, model.id], _, _) = quiz_machine.predict( - model, - solved_c_quizzes[:, model.id], - quad_orders=("A", "f_A", "B", "f_B"), - quad=(0, 0, 0, 1), - ) - - proba_own_solution[:, model.id] = model_proba_solutions( - model, solved_c_quizzes[:, model.id] - ) - - # Now for every model not confident of its response, we pick - # the most consistent from a model which is confident - - for s in range(proba_own_solution.size(0)): - # At least one GPT does not understand at all - if proba_own_solution[s, :].min() < args.proba_not_understands: - dont_get_this_quiz = proba_own_solution[s, :] < args.proba_understands - nb_fails = dont_get_this_quiz.long().sum() - # At most max_fail_to_validate do not understand (default 3/5) - if nb_fails >= 1 and nb_fails <= args.max_fail_to_validate: - for model in models: - # If a GPT does not get that quiz - if dont_get_this_quiz[model.id]: - assert ( - proba_own_solution[s, model.id] < args.proba_understands - ) - # Look at its estimate of the others'solutions - proba_other_solutions = model_proba_solutions( - model, solved_c_quizzes[s] - ) - # Randomize a bit the orders for the frequent P=1 - proba_other_solutions += ( - torch.rand(proba_other_solutions.size()) * 1e-6 - ) - # Remove the under threshold confidence solutions - proba_other_solutions[dont_get_this_quiz] = -1 - i = proba_other_solutions.argmax() - model.recorded_c_quizzes.append(solved_c_quizzes[s, i]) - teaching_count[i, model.id] += 1 - nb_validated += 1 - - duration = time.perf_counter() - start_time - - if nb_validated > 0: - if nb_validated < nb_to_validate: - d = (nb_to_validate - nb_validated) * duration / nb_validated - e = (datetime.datetime.now() + datetime.timedelta(seconds=d)).strftime( - "%a %H:%M" - ) - else: - e = "now!" - else: - e = "???" - - log_string( - f"keep c_quizzes model {model_for_generation.id} validated nb_validated {nb_validated} / {nb_to_validate} (finishes {e} -- {int((nb_validated * 3600)/duration)}/h) proportion_kept {nb_validated * 100 / nb_generated:.02f}%" - ) - - for s in range(teaching_count.size(0)): - o = [x.item() for x in teaching_count[s]] - log_string(f"teacher model {s} to {o}") - - for model in models: - new_bag = torch.cat([q[None, :] for q in model.recorded_c_quizzes], dim=0) - - if new_bag.size(0) > 0: - n = (new_bag.size(0) * nb_for_train) // (nb_for_train + nb_for_test) - if n > 0: - model.train_c_quiz_bags.append(new_bag[:n]) - if n < new_bag.size(0): - model.test_c_quiz_bags.append(new_bag[n:]) - - c_quizzes = new_bag[:128] - - l = [model_proba_solutions(model, c_quizzes) for model in models] - probas = torch.cat([x[:, None] for x in l], dim=1) - comments = [] - - for l in probas: - comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l])) - - filename = f"culture_c_quiz_{n_epoch:04d}_{model.id:02d}.png" - quiz_machine.problem.save_quizzes_as_image( - args.result_dir, filename, c_quizzes, comments=comments - ) - - log_string( - f"nb_c_quizzes model {model.id} train {sum([q.size(0) for q in model.train_c_quiz_bags ])} test {sum([q.size(0) for q in model.test_c_quiz_bags ])}" - ) - - ###################################################################### from mygpt import ( @@ -959,6 +690,16 @@ class FunctionalAE(nn.Module): ###################################################################### +# quad_order, quad_generate, quad_noise, quad_loss + +data_structures = [ + (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (0, 0, 0, 1)), + (("A", "f_A", "B", "f_B"), (0, 0, 1, 0), (0, 0, 0, 1), (0, 0, 1, 0)), + (("A", "f_A", "B", "f_B"), (0, 1, 0, 0), (1, 0, 0, 0), (0, 1, 0, 0)), + (("A", "f_A", "B", "f_B"), (1, 0, 0, 0), (0, 1, 0, 0), (1, 0, 0, 0)), + (("A", "f_A", "B", "f_B"), (1, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)), +] + def ae_batches( quiz_machine, @@ -1305,7 +1046,7 @@ def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_devi f"train_loss {n_epoch} model {model.id} {acc_train_loss/nb_train_samples}" ) - run_ae_test(model, quiz_machine, n_epoch, c_quizzes, local_device=local_device) + run_ae_test(model, quiz_machine, n_epoch, local_device=local_device) ###################################################################### -- 2.39.5