From 89f5507c3da45f1a75e5af1d4762145798d2420b Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 19 Aug 2024 19:50:13 +0200 Subject: [PATCH] Update. --- main.py | 475 ++++++++++-------------------------------------- quiz_machine.py | 17 +- 2 files changed, 98 insertions(+), 394 deletions(-) diff --git a/main.py b/main.py index d98031e..1cbff39 100755 --- a/main.py +++ b/main.py @@ -61,8 +61,6 @@ parser.add_argument("--nb_new_c_quizzes_for_train", type=int, default=None) parser.add_argument("--nb_new_c_quizzes_for_test", type=int, default=None) -parser.add_argument("--c_quiz_multiplier", type=int, default=1) - parser.add_argument("--learning_rate", type=float, default=5e-4) parser.add_argument("--lambda_H", type=float, default=0.0) @@ -342,7 +340,7 @@ def run_tests(model, quiz_machine, local_device=main_device): nb_samples_accumulated = 0 full_input, full_mask_loss = quiz_machine.data_input( - args.nb_test_samples, model.test_c_quiz_bags, args.c_quiz_multiplier + args.nb_test_samples, model.test_c_quiz_bags ) src = zip( full_input.split(args.batch_size), full_mask_loss.split(args.batch_size) @@ -370,9 +368,7 @@ def run_tests(model, quiz_machine, local_device=main_device): log_string(f"test_perplexity {n_epoch} model {model.id} {test_perplexity}") - input, _ = quiz_machine.data_input( - 2000, model.test_c_quiz_bags, args.c_quiz_multiplier - ) + input, _ = quiz_machine.data_input(2000, model.test_c_quiz_bags) model.test_accuracy = quiz_machine.produce_results( n_epoch=n_epoch, @@ -395,7 +391,7 @@ def one_epoch(model, quiz_machine, local_device=main_device): nb_train_samples, acc_train_loss = 0, 0.0 full_input, full_mask_loss = quiz_machine.data_input( - args.nb_train_samples, model.train_c_quiz_bags, args.c_quiz_multiplier + args.nb_train_samples, model.train_c_quiz_bags ) src = zip(full_input.split(args.batch_size), full_mask_loss.split(args.batch_size)) @@ -561,26 +557,21 @@ def model_proba_solutions(model, quizzes): return l.exp() -def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test): +def create_c_quizzes(main_model, other_models, quiz_machine, nb_for_train, nb_for_test): nb_validated, nb_to_validate = 0, (nb_for_train + nb_for_test) * len(models) nb_generated, nb_to_generate_per_iteration = 0, nb_to_validate start_time = time.perf_counter() - for model in models: - model.recorded_c_quizzes = [] - - teaching_count = torch.zeros(len(models), len(models), dtype=torch.int64) + recorded = [] while nb_validated < nb_to_validate: - model_for_generation = models[torch.randint(len(models), (1,)).item()] - # We generate quizzes with a procedure that injects some # structured noise c_quizzes = quiz_machine.generate_c_quizzes( nb_to_generate_per_iteration, - model_for_generation=model, + model_for_generation=main_model, procedure=c_quizzes_procedure, ) @@ -593,57 +584,48 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test): c_quizzes = c_quizzes[to_keep] - # Compute the responses of all the models on the c_quizzes, - # and their proba estimates of their responses + # Keep only the quizzes that the main model cannot solve - solved_c_quizzes = c_quizzes[:, None, :].expand(-1, len(models), -1).clone() + solved_c_quizzes = c_quizzes.clone() - proba_own_solution = torch.zeros( - c_quizzes.size(0), len(models), device=solved_c_quizzes.device + main_solution, _, _ = quiz_machine.predict( + main_model, + solved_c_quizzes, + struct=("A", "f_A", "B", "f_B"), + mask=(0, 0, 0, 1), ) - for model in models: - (solved_c_quizzes[:, model.id], _, _) = quiz_machine.predict( - model, - solved_c_quizzes[:, model.id], - struct=("A", "f_A", "B", "f_B"), - mask=(0, 0, 0, 1), - ) + keep = ( + model_proba_solutions(main_model, main_solution) + < args.proba_not_understands + ) + c_quizzes = c_quizzes[keep] + + # If there are some quizzes that the main model cannot solve, + # pick the most confident solution - proba_own_solution[:, model.id] = model_proba_solutions( - model, solved_c_quizzes[:, model.id] + if c_quizzes.size(0) > 0: + solution = c_quizzes.clone() + c_quizzes_proba = torch.zeros( + solution.size(0), dtype=torch.float32, device=solution.device ) - # Now for every model not confident of its response, we pick - # the most consistent from a model which is confident - - for s in range(proba_own_solution.size(0)): - # At least one GPT does not understand at all - if proba_own_solution[s, :].min() < args.proba_not_understands: - dont_get_this_quiz = proba_own_solution[s, :] < args.proba_understands - nb_fails = dont_get_this_quiz.long().sum() - # At most max_fail_to_validate do not understand (default 3/5) - if nb_fails >= 1 and nb_fails <= args.max_fail_to_validate: - for model in models: - # If a GPT does not get that quiz - if dont_get_this_quiz[model.id]: - assert ( - proba_own_solution[s, model.id] < args.proba_understands - ) - # Look at its estimate of the others'solutions - proba_other_solutions = model_proba_solutions( - model, solved_c_quizzes[s] - ) - # Randomize a bit the orders for the frequent P=1 - proba_other_solutions += ( - torch.rand(proba_other_solutions.size()) * 1e-6 - ) - # Remove the under threshold confidence solutions - proba_other_solutions[dont_get_this_quiz] = -1 - i = proba_other_solutions.argmax() - model.recorded_c_quizzes.append(solved_c_quizzes[s, i]) - teaching_count[i, model.id] += 1 - nb_validated += 1 + for model in other_models: + solution, _, _ = quiz_machine.predict( + model, + solution, + struct=("A", "f_A", "B", "f_B"), + mask=(0, 0, 0, 1), + ) + + probas = model_proba_solutions(model, solution) + keep = probas >= c_quizzes_proba + c_quizzes = solution[keep] + c_quizzes_proba[keep] = probas[keep] + + keep = c_quizzes_proba >= args.proba_understands + recorded.append(c_quizzes_proba[keep]) + nb_validated += keep.long().sum() duration = time.perf_counter() - start_time @@ -662,146 +644,29 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test): f"keep c_quizzes model {model_for_generation.id} validated nb_validated {nb_validated} / {nb_to_validate} (finishes {e} -- {int((nb_validated * 3600)/duration)}/h) proportion_kept {nb_validated * 100 / nb_generated:.02f}%" ) - for s in range(teaching_count.size(0)): - o = [x.item() for x in teaching_count[s]] - log_string(f"teacher model {s} to {o}") + # Save some images - for model in models: - new_bag = torch.cat([q[None, :] for q in model.recorded_c_quizzes], dim=0) - - if new_bag.size(0) > 0: - n = (new_bag.size(0) * nb_for_train) // (nb_for_train + nb_for_test) - if n > 0: - model.train_c_quiz_bags.append(new_bag[:n]) - if n < new_bag.size(0): - model.test_c_quiz_bags.append(new_bag[n:]) - - c_quizzes = new_bag[:128] - - l = [model_proba_solutions(model, c_quizzes) for model in models] - probas = torch.cat([x[:, None] for x in l], dim=1) - comments = [] - - for l in probas: - comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l])) - - filename = f"culture_c_quiz_{n_epoch:04d}_{model.id:02d}.png" - quiz_machine.problem.save_quizzes_as_image( - args.result_dir, filename, c_quizzes, comments=comments - ) + c_quizzes = torch.cat(recorded, dim=0) - log_string( - f"nb_c_quizzes model {model.id} train {sum([q.size(0) for q in model.train_c_quiz_bags ])} test {sum([q.size(0) for q in model.test_c_quiz_bags ])}" - ) + l = [ + model_proba_solutions(model, c_quizzes) for model in [main_model] + other_models + ] + probas = torch.cat([x[:, None] for x in l], dim=1) + comments = [] + for l in probas: + comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l])) + filename = f"culture_c_quiz_{n_epoch:04d}_{model.id:02d}.png" + quiz_machine.problem.save_quizzes_as_image( + args.result_dir, filename, c_quizzes[:128], comments=comments + ) -###################################################################### -from mygpt import ( - WithResidual, - CacheWrapper, - AddPositionalEncoding, - QKVAttention, - BracketedSequence, +log_string( + f"nb_c_quizzes model {model.id} train {sum([q.size(0) for q in model.train_c_quiz_bags ])} test {sum([q.size(0) for q in model.test_c_quiz_bags ])}" ) -class Thinker(nn.Module): - def __init__( - self, - vocabulary_size, - dim_model, - dim_keys, - dim_hidden, - nb_heads, - nb_blocks, - f_len, - dropout=0.0, - len_max=1e5, - ): - super().__init__() - - assert dim_model % nb_heads == 0 - - self.embedding = nn.Sequential( - CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)), - AddPositionalEncoding(len_max), - ) - - def trunk(depth): - trunk_blocks = [] - - for b in range(nb_blocks): - trunk_blocks += [ - WithResidual( - CacheWrapper( - nn.LayerNorm((dim_model,)), - ), - QKVAttention( - dim_in=dim_model, - dim_qk=dim_keys, - dim_v=dim_model // nb_heads, - nb_heads=nb_heads, - attention_dropout=dropout, - ), - ), - WithResidual( - CacheWrapper( - nn.LayerNorm((dim_model,)), - nn.Linear(in_features=dim_model, out_features=dim_hidden), - nn.ReLU(), - nn.Linear(in_features=dim_hidden, out_features=dim_model), - nn.Dropout(dropout), - ), - ), - ] - - return nn.Sequential(*trunk_blocks) - - self.bottom_trunk = trunk(nb_blocks // 2) - - self.top_trunk = trunk(nb_blocks // 2) - - self.readout = CacheWrapper( - nn.Linear(in_features=dim_model, out_features=vocabulary_size) - ) - - self.fun_embedding = nn.Parameter(torch.randn(1, f_len, dim_model)) - - with torch.no_grad(): - for m in self.modules(): - if isinstance(m, nn.Embedding): - m.weight.normal_(mean=0, std=2e-2) - elif isinstance(m, nn.LayerNorm): - m.bias.zero_() - m.weight.fill_(1.0) - - def forward(self, bs): - for m in self.modules(): - m.loss = 0 - - L = bs.x.size(1) // 3 - - bs = self.embedding(bs) - A_fA = BracketedSequence(bs.x[:, : 2 * L]) - B = BracketedSequence(bs.x[:, -L:]) - - bs = BracketedSequence( - torch.cat([A_fA.x, self.fun_embedding.expand(bs.x.size(0), -1, -1)], dim=1) - ) - bs = self.bottom_trunk(bs) - bs = BracketedSequence(torch.cat([bs.x[:, -f_len:, :], B.x], dim=1)) - bs = self.top_trunk(bs) - bs = BracketedSequence(bs.x[:, f_len:, :]) - bs = self.readout(bs) - - for m in self.modules(): - if m is not self: - self.loss += m.loss - - return bs - - ###################################################################### models = [] @@ -855,20 +720,12 @@ for k in range(args.nb_gpts): model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) model.test_accuracy = 0.0 - model.best_test_accuracy = 0.0 - model.best_dict = copy.deepcopy(model.state_dict()) models.append(model) ###################################################################### current_epoch = 0 -# We balance the computing time between training the models and -# generating c_quizzes - -total_time_generating_c_quizzes = 0 -total_time_training_models = 0 - if args.resume: for model in models: filename = f"gpt_{model.id:03d}.pth" @@ -878,8 +735,6 @@ if args.resume: model.load_state_dict(d["state_dict"]) model.optimizer.load_state_dict(d["optimizer_state_dict"]) model.test_accuracy = d["test_accuracy"] - model.best_test_accuracy = d["best_test_accuracy"] - model.best_dict = d["best_dict"] model.train_c_quiz_bags = d["train_c_quiz_bags"] model.test_c_quiz_bags = d["test_c_quiz_bags"] log_string(f"successfully loaded {filename}") @@ -892,8 +747,6 @@ if args.resume: state = torch.load(os.path.join(args.result_dir, filename)) log_string(f"successfully loaded {filename}") current_epoch = state["current_epoch"] - total_time_generating_c_quizzes = state["total_time_generating_c_quizzes"] - total_time_training_models = state["total_time_training_models"] except FileNotFoundError: log_string(f"cannot find {filename}") pass @@ -950,69 +803,6 @@ class Recorder(nn.Module): return input -if args.test == "mlp": - model = models[0] - tape_input, tape_output = [], [] - L = len(model.trunk) - model.trunk.insert(L // 2 + 1, Recorder(tape_output)) - model.trunk.insert(L // 2, Recorder(tape_input)) - - mlp = nn.Sequential( - nn.Linear(args.dim_model, args.dim_model), - nn.ReLU(), - nn.Linear(args.dim_model, args.dim_model), - nn.ReLU(), - nn.Linear(args.dim_model, 8 * args.dim_model), - Folder(), - Unfolder(404, 8 * args.dim_model), - nn.Linear(8 * args.dim_model, args.dim_model), - nn.ReLU(), - nn.Linear(args.dim_model, args.dim_model), - nn.ReLU(), - nn.Linear(args.dim_model, args.dim_model), - ).to(main_device) - - mlp.optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-4) - - for n_epoch in range(args.nb_epochs): - train_input = quiz_machine.problem.generate_w_quizzes(args.nb_train_samples) - - tape_input.clear() - tape_output.clear() - - with torch.autograd.no_grad(): - model.to(main_device).eval() - for input in train_input.split(args.batch_size): - input = input.to(main_device) - output = model(mygpt.BracketedSequence(input)).x - - train_input = torch.cat([bs.x for bs in tape_input], dim=0) - train_targets = torch.cat([bs.x for bs in tape_output], dim=0) - - nb_train_samples, acc_train_loss = 0, 0.0 - src = zip( - train_input.split(args.batch_size), train_targets.split(args.batch_size) - ) - for input, targets in tqdm.tqdm( - src, - dynamic_ncols=True, - desc="train", - total=train_input.size(0) // args.batch_size, - ): - input = input.to(main_device) - output = mlp(input) - loss = F.mse_loss(output, targets) + output.abs().sum() - acc_train_loss += loss.item() * input.size(0) - nb_train_samples += input.size(0) - - mlp.optimizer.zero_grad() - loss.backward() - mlp.optimizer.step() - - log_string(f"mlp_loss {n_epoch} train {acc_train_loss/nb_train_samples}") - - exit(0) - ###################################################################### @@ -1057,67 +847,9 @@ def save_generated_c_quizzes(model, filename, nb=64): ###################################################################### - -if args.test == "entropy": - model = models[0] - model.to(main_device) - - optimizer = torch.optim.Adam(model.parameters(), lr=1e-5) - - log_string("starting testing entropy maximization") - - for n_epoch in range(100): - input = quiz_machine.generate_c_quizzes( - 128, - model_for_generation=model, - procedure=c_quizzes_procedure, - ) - - quiz_machine.problem.save_quizzes_as_image( - args.result_dir, - f"test_{n_epoch:04d}.png", - quizzes=input, - ) - - log_string(f"wrote {filename}") - - with torch.no_grad(): - for p in model.parameters(): - p += torch.randn(p.size(), device=p.device) * 1e-3 - - # nb_train_samples, acc_train_loss = 0, 0.0 - - # for k in range(1000 // args.batch_size): - # input = quiz_machine.generate_c_quizzes( - # args.batch_size, - # model_for_generation=model, - # procedure=[(("f_B", "f_A", "A", "B"), (1, 1, 1, 1), None)], - # ) - - # input = input.to(main_device) - # targets = input - # output = model(mygpt.BracketedSequence(input)).x - # loss = -F.cross_entropy(output.transpose(1, 2), targets) - # acc_train_loss += loss.item() * input.size(0) - # nb_train_samples += input.size(0) - - # optimizer.zero_grad() - # loss.backward() - # optimizer.step() - - # log_string( - # f"increase_entropy {n_epoch} entropy {acc_train_loss/nb_train_samples}" - # ) - - exit(0) - -###################################################################### - for n_epoch in range(current_epoch, args.nb_epochs): state = { "current_epoch": n_epoch, - "total_time_training_models": total_time_training_models, - "total_time_generating_c_quizzes": total_time_generating_c_quizzes, } filename = "state.pth" torch.save(state, os.path.join(args.result_dir, filename)) @@ -1128,84 +860,71 @@ for n_epoch in range(current_epoch, args.nb_epochs): cta = " ".join([f"{float(m.test_accuracy):.04f}" for m in models]) log_string(f"current_test_accuracies {cta}") - cta = " ".join([f"{float(m.best_test_accuracy):.04f}" for m in models]) - log_string(f"current_best_test_accuracies {cta}") - ################################################## - for model in models: - if model.test_accuracy >= args.accuracy_to_make_c_quizzes: - log_string( - f"storing_best model {model.id} accuracy {model.best_test_accuracy} -> {model.test_accuracy}" - ) - model.best_dict = copy.deepcopy(model.state_dict()) - model.best_test_accuracy = model.test_accuracy - - # we restart - if total_time_generating_c_quizzes == 0: - total_time_training_models = 0 - - if ( - min([m.best_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes - and total_time_training_models >= total_time_generating_c_quizzes - ): - for model in models: - model.current_dict = copy.deepcopy(model.state_dict()) - model.load_state_dict(model.best_dict) - - start_time = time.perf_counter() + if min([m.test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes: record_new_c_quizzes( models, quiz_machine, args.nb_new_c_quizzes_for_train, args.nb_new_c_quizzes_for_test, ) - total_time_generating_c_quizzes += time.perf_counter() - start_time - # Force one epoch of training for model in models: - model.load_state_dict(model.current_dict) + new_model = mygpt.MyGPT( + vocabulary_size=vocabulary_size, + dim_model=args.dim_model, + dim_keys=args.dim_keys, + dim_hidden=args.dim_hidden, + nb_heads=args.nb_heads, + nb_blocks=args.nb_blocks, + compute_attzero=compute_causal_attzero, + dropout=args.dropout, + ).to(main_device) + model.load_state_dict(new_model.state_dict()) + model.test_accuracy = 0.0 + model.best_test_accuracy = 0.0 + model.best_dict = copy.deepcopy(model.state_dict()) ################################################## # Select, improve, and eval the worst model(s) - if total_time_training_models <= total_time_generating_c_quizzes: - ranked_models = sorted( - models, - # This ugly recipe will pick the worst if there some below - # args.accuracy_to_make_c_quizzes or one at random if they - # are all above - key=lambda m: float( - m.test_accuracy - if m.test_accuracy < args.accuracy_to_make_c_quizzes - else args.accuracy_to_make_c_quizzes + torch.rand(1).item() - ), - ) + ranked_models = sorted( + models, + # This ugly recipe will pick the worst if there some below + # args.accuracy_to_make_c_quizzes or one at random if they + # are all above + key=lambda m: float( + m.test_accuracy + if m.test_accuracy < args.accuracy_to_make_c_quizzes + else args.accuracy_to_make_c_quizzes + torch.rand(1).item() + ), + ) - weakest_models = ranked_models[: len(gpus)] + weakest_models = ranked_models[: len(gpus)] - threads = [] + threads = [] - start_time = time.perf_counter() + start_time = time.perf_counter() - for gpu, model in zip(gpus, weakest_models): - log_string(f"training model {model.id}") + for gpu, model in zip(gpus, weakest_models): + log_string(f"training model {model.id}") - t = threading.Thread( - target=one_epoch, daemon=True, args=(model, quiz_machine, gpu) - ) + t = threading.Thread( + target=one_epoch, daemon=True, args=(model, quiz_machine, gpu) + ) - threads.append(t) + threads.append(t) - t.start() + t.start() - for t in threads: - t.join() + for t in threads: + t.join() - total_time_training_models += time.perf_counter() - start_time + total_time_training_models += time.perf_counter() - start_time - for model in weakest_models: - save_additional_results(n_epoch, model, models, c_quizzes_procedure) + for model in weakest_models: + save_additional_results(n_epoch, model, models, c_quizzes_procedure) # Save the models to disk diff --git a/quiz_machine.py b/quiz_machine.py index a0b007a..1acd7ad 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -87,8 +87,6 @@ class QuizMachine: (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)), (("B", "f_B", "A", "f_A"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)), (("f_B", "B", "f_A", "A"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)), - # (("B", "f_B", "A", "f_A"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 1)), - # (("f_B", "B", "f_A", "A"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 1)), (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)), ] @@ -140,23 +138,10 @@ class QuizMachine: ###################################################################### - def data_input(self, nb_samples, c_quiz_bags, c_quiz_multiplier=1): + def data_input(self, nb_samples, c_quiz_bags): if len(c_quiz_bags) > 0: c_quizzes = torch.cat(c_quiz_bags, dim=0) - if c_quiz_multiplier > 1: - n = min(c_quiz_multiplier, (nb_samples // 2) // c_quizzes.size(0)) - body = c_quizzes.repeat(n, 1) - if n < c_quiz_multiplier: - tail = c_quizzes[ - torch.randperm(c_quizzes.size(0))[ - : nb_samples // 2 - body.size(0) - ] - ] - c_quizzes = torch.cat([body, tail], dim=0) - else: - c_quizzes = body - if c_quizzes.size(0) > nb_samples // 2: i = torch.randperm(c_quizzes.size(0))[: nb_samples // 2] c_quizzes = c_quizzes[i] -- 2.39.5