From 9b4c21698ef7461f7e3bcc403868f0158ee2e20b Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 17 Aug 2024 10:42:42 +0200 Subject: [PATCH] Update. --- main.py | 169 +++++++++++++++--------------------------------- quiz_machine.py | 4 +- 2 files changed, 56 insertions(+), 117 deletions(-) diff --git a/main.py b/main.py index 915e10e..92bc05f 100755 --- a/main.py +++ b/main.py @@ -63,6 +63,8 @@ parser.add_argument("--nb_new_c_quizzes_for_test", type=int, default=None) parser.add_argument("--learning_rate", type=float, default=5e-4) +parser.add_argument("--lambda_H", type=float, default=0.0) + parser.add_argument("--schedule_free", action="store_true", default=False) # ---------------------------------- @@ -404,10 +406,20 @@ def one_epoch(model, quiz_machine, local_device=main_device): targets = input output = model(mygpt.BracketedSequence(input)).x + loss_per_token = F.cross_entropy( output.transpose(1, 2), targets, reduction="none" ) + + # warnings.warn("entropy masking", RuntimeWarning) + # l = output.transpose(1, 2).log_softmax(dim=1) + # H = -(l * l.exp()).sum(dim=1) + # M = (H >= -math.log(0.99) / H.size(1)).long() + # print(H, M) + # loss_per_token = loss_per_token * M + loss = (loss_per_token * mask_loss).mean() + model.loss + acc_train_loss += loss.item() * input.size(0) loss_per_samples = loss_per_token.detach().flatten(1).mean(dim=1) @@ -782,98 +794,6 @@ class Thinker(nn.Module): return bs -if args.test == "func": - test_input = quiz_machine.problem.generate_w_quizzes(args.nb_test_samples) - - L = test_input.size(1) // 4 - f_len = 50 - - model = Thinker( - vocabulary_size=vocabulary_size, - dim_model=args.dim_model, - dim_keys=args.dim_keys, - dim_hidden=args.dim_hidden, - nb_heads=args.nb_heads, - nb_blocks=args.nb_blocks, - f_len=f_len, - dropout=args.dropout, - ).to(main_device) - - model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) - - for n_epoch in range(args.nb_epochs): - model.train() - - train_input = quiz_machine.problem.generate_w_quizzes(args.nb_train_samples) - - nb_train_samples, acc_train_loss = 0, 0.0 - - for input in tqdm.tqdm( - train_input.split(args.batch_size), - dynamic_ncols=True, - desc="training", - total=train_input.size(0) // args.batch_size, - ): - input = input.to(main_device) - - if nb_train_samples % args.batch_size == 0: - model.optimizer.zero_grad() - - output = model(mygpt.BracketedSequence(input[:, : 3 * L])).x - targets = input[:, 3 * L :] - loss = F.cross_entropy(output.transpose(1, 2), targets) - acc_train_loss += loss.item() * input.size(0) - - nb_train_samples += input.size(0) - - loss.backward() - - if nb_train_samples % args.batch_size == 0: - model.optimizer.step() - - train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples)) - - log_string(f"train_perplexity {n_epoch} model thinker {train_perplexity}") - - with torch.autograd.no_grad(): - model.eval() - - nb_test_samples, acc_test_loss = 0, 0.0 - - for input in tqdm.tqdm( - test_input.split(args.batch_size), - dynamic_ncols=True, - desc="testing", - total=test_input.size(0) // args.batch_size, - ): - input = input.to(main_device) - - output = model(mygpt.BracketedSequence(input[:, : 3 * L])).x - targets = input[:, 3 * L :] - loss = F.cross_entropy(output.transpose(1, 2), targets) - acc_test_loss += loss.item() * input.size(0) - - nb_test_samples += input.size(0) - - test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples)) - - log_string(f"test_perplexity {n_epoch} model thinker {test_perplexity}") - - input = test_input[:128].clone().to(main_device) - - output = model(mygpt.BracketedSequence(input[:, : 3 * L])).x - dist = torch.distributions.categorical.Categorical(logits=output) - input[:, 3 * L + 1 :] = dist.sample()[:, 1:] - - problem.save_quizzes_as_image( - args.result_dir, - f"thinker_prediction_{n_epoch:04d}.png", - quizzes=input, - # predicted_parts=predicted_parts, - # correct_parts=correct_parts, - ) - - ###################################################################### models = [] @@ -913,7 +833,7 @@ for k in range(args.nb_gpts): model.test_accuracy = 0.0 model.best_test_accuracy = 0.0 - + model.best_dict = copy.deepcopy(model.state_dict()) models.append(model) ###################################################################### @@ -1071,25 +991,10 @@ if args.test == "mlp": exit(0) ###################################################################### -###################################################################### -if args.test == "reject": - record = [] - - c_quizzes_procedure = [ - (("f_B", "f_A", "A", "B"), (1, 1, 1, 1), model_modifier_hot), - (("f_B", "B", "f_A", "A"), (0, 0, 1, 1), model_modifier_cold), - (("f_B", "f_A", "A", "B"), (0, 0, 0, 1), model_modifier_cold), - (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_modifier_cold), - (("f_B", "B", "f_A", "A"), (0, 0, 1, 1), model_modifier_cold), - (("f_B", "f_A", "A", "B"), (0, 0, 0, 1), model_modifier_cold), - (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_modifier_cold), - (("f_B", "B", "f_A", "A"), (0, 0, 1, 1), model_modifier_cold), - (("f_B", "f_A", "A", "B"), (0, 0, 0, 1), model_modifier_cold), - (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_modifier_cold), - ] - while sum([x.size(0) for x in record]) < 64: +def save_generated_c_quizzes(model, filename, nb=64): + while sum([x.size(0) for x in record]) < nb: model = models[torch.randint(len(models), (1,)).item()] c_quizzes = quiz_machine.generate_c_quizzes( 64, @@ -1118,8 +1023,6 @@ if args.test == "reject": print("NB_KEPT", sum([x.size(0) for x in record])) - filename = f"sampling_with_rejection.png" - quiz_machine.problem.save_quizzes_as_image( args.result_dir, filename, @@ -1128,6 +1031,40 @@ if args.test == "reject": log_string(f"wrote {filename}") + +###################################################################### + +if args.test == "entropy": + model = models[0] + model.to(main_device) + + log_string("starting testing entropy maximization") + + train_input = quiz_machine.generate_c_quizzes( + 1000, + model_for_generation=model, + procedure=c_quizzes_procedure, + ) + + for n_epoch in range(10): + nb_train_samples, acc_train_loss = 0, 0.0 + + for input in train_input.split(args.batch_size): + input = input.to(main_device) + output = model(mygpt.BracketedSequence(input)).x + loss = output.log_softmax(dim=1).mean() + + acc_train_loss += loss.item() * input.size(0) + nb_train_samples += input.size(0) + + model.optimizer.zero_grad() + loss.backward() + model.optimizer.step() + + log_string( + f"increase_entropy {n_epoch} entropy {acc_train_loss/nb_train_samples}" + ) + exit(0) ###################################################################### @@ -1187,7 +1124,7 @@ for n_epoch in range(current_epoch, args.nb_epochs): ################################################## # Select, improve, and eval the worst model(s) - if total_time_training_models < total_time_generating_c_quizzes: + if total_time_training_models <= total_time_generating_c_quizzes: ranked_models = sorted(models, key=lambda m: float(m.test_accuracy)) weakest_models = ranked_models[: len(gpus)] @@ -1212,6 +1149,9 @@ for n_epoch in range(current_epoch, args.nb_epochs): total_time_training_models += time.perf_counter() - start_time + for model in weakest_models: + save_additional_results(n_epoch, model, models, c_quizzes_procedure) + # Save the models to disk for model in models: @@ -1230,9 +1170,6 @@ for n_epoch in range(current_epoch, args.nb_epochs): ) log_string(f"wrote {filename}") - for model in weakest_models: - save_additional_results(n_epoch, model, models, c_quizzes_procedure) - ###################################################################### if args.log_command is not None: diff --git a/quiz_machine.py b/quiz_machine.py index 3c4a865..98e0ea5 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -338,7 +338,8 @@ class QuizMachine: c_quizzes = None - for s, m, mt in procedure: + for n_step, setup in enumerate(procedure): + s, m, mt = setup if c_quizzes is None: c_quizzes = self.problem.create_empty_quizzes(nb, s) c_quizzes = c_quizzes.to(self.device) @@ -354,6 +355,7 @@ class QuizMachine: input=c_quizzes, ar_mask=self.make_quiz_mask(c_quizzes, s, m), seq_logprobas=seq_logprobas, + progress_bar_desc=f"autoregression {n_step}/{len(procedure)}", ) model_for_generation.reset_transformations() -- 2.39.5