From b7c4cba73ed743c9ba43c3f06e40da7304fca9c9 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 30 Jul 2024 21:34:17 +0200 Subject: [PATCH] Update. --- main.py | 242 ++++++++++++++++++++++++++++++------------------ quiz_machine.py | 2 +- 2 files changed, 152 insertions(+), 92 deletions(-) diff --git a/main.py b/main.py index 19a3c29..7aeae98 100755 --- a/main.py +++ b/main.py @@ -19,6 +19,8 @@ import ffutils import mygpt import sky, grids, quiz_machine +from quiz_machine import one_batch_masked_inplace_autoregression + import threading, subprocess import torch.multiprocessing as mp @@ -773,7 +775,11 @@ def train_complexifier(model_gen, model_pred1, model_pred2): ###################################################################### -def train_autoencoder(): +models = [] + +for k in range(args.nb_gpts): + log_string(f"creating model {k} and its w_quizzes") + model = mygpt.MyGPT( vocabulary_size=vocabulary_size, dim_model=args.dim_model, @@ -781,130 +787,184 @@ def train_autoencoder(): dim_hidden=args.dim_hidden, nb_heads=args.nb_heads, nb_blocks=args.nb_blocks, - causal=False, + causal=True, dropout=args.dropout, - autoencoder_dim=args.autoencoder_dim, ).to(main_device) - test_w_quizzes = quiz_machine.problem.generate_w_quizzes(args.nb_test_samples) + model.main_test_accuracy = 0.0 + model.id = k - optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) + model.train_w_quizzes = quiz_machine.problem.generate_w_quizzes( + args.nb_train_samples + ) - nb_train_samples, acc_train_loss = 0, 0.0 + model.test_w_quizzes = quiz_machine.problem.generate_w_quizzes(args.nb_test_samples) - for n_epoch in range(args.nb_epochs): - train_w_quizzes = quiz_machine.problem.generate_w_quizzes(args.nb_train_samples) - for input in tqdm.tqdm( - train_w_quizzes.split(args.batch_size), - dynamic_ncols=True, - desc="training AE", - total=train_w_quizzes.size(0) // args.batch_size, - ): - model.train() - l = input.size(1) // 4 - input = input[:, -l:].to(main_device) + models.append(model) - if nb_train_samples % args.batch_size == 0: - optimizer.zero_grad() +###################################################################### - z_shape = model.encode(mygpt.BracketedSequence(input.to(main_device))) - output = model.decode(z_shape).x - loss = F.cross_entropy(output.transpose(1, 2), input) - acc_train_loss += loss.item() * input.size(0) +token_prolog_0 = vocabulary_size + 0 +token_prolog_1 = vocabulary_size + 1 +token_prolog_2 = vocabulary_size + 2 +generator_vocabulary_size = vocabulary_size + 3 - nb_train_samples += input.size(0) +generator = mygpt.MyGPT( + vocabulary_size=generator_vocabulary_size, + dim_model=args.dim_model, + dim_keys=args.dim_keys, + dim_hidden=args.dim_hidden, + nb_heads=args.nb_heads, + nb_blocks=args.nb_blocks, + causal=True, + dropout=args.dropout, +).to(main_device) - loss.backward() +generator.main_test_accuracy = 0.0 - if nb_train_samples % args.batch_size == 0: - optimizer.step() - train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples)) +###################################################################### - log_string(f"train_perplexity {n_epoch} model ae {train_perplexity}") - filename = f"autoencoder.pth" - torch.save( - model.state_dict(), - os.path.join(args.result_dir, filename), - ) - log_string(f"wrote {filename}") +def generate_c_quizz_with_generator(generator, quiz_machine): + c_quizzes = quiz_machine.problem.create_empty_quizzes( + args.batch_size, struct=("A", "f_A", "B", "f_B") + ) + i = F.one_hot( + torch.randint(args.nb_gpts, (c_quizzes.size(0),)), + num_classes=args.nb_gpts, + ) + prolog = token_prolog_0 * i + token_prolog_2 * (1 - i) + c_quizzes = torch.cat([prolog, c_quizzes], dim=1) + ar_mask = ( + torch.arange(c_quizzes.size(1), device=c_quizzes.device)[None, :] + >= args.nb_gpts + ).long() + + one_batch_masked_inplace_autoregression( + generator, + c_quizzes, + ar_mask, + seq_logproba, + deterministic_synthesis=False, + ) - with torch.autograd.no_grad(): - model.eval() - input = test_w_quizzes[0 * 128 : 1 * 128, -l:] - z_shape = model.encode(mygpt.BracketedSequence(input.to(main_device))) - logits = model.decode(z_shape).x + return c_quizzes[:, args.nb_gpts :] - # dist = torch.distributions.categorical.Categorical(logits=logits) - # q = dist.sample() - q = logits.argmax(dim=-1) - q = q.reshape(q.size(0) // 2, 2, -1) - input = input.reshape(input.size(0) // 2, 2, -1) - q = torch.cat([input.to("cpu"), q.to("cpu")], dim=1).reshape(q.size(0), -1) - quiz_machine.problem.save_quizzes_as_image( - args.result_dir, - f"culture_ae_{n_epoch:04d}.png", - q, - ) +def batches_for_generator(generator=None, quiz_machine=None, device=main_device): + samples = [] - input1 = test_w_quizzes[1 * 128 : 2 * 128, -l:] - input2 = test_w_quizzes[2 * 128 : 3 * 128, -l:] - z_shape1 = model.encode(mygpt.BracketedSequence(input1.to(main_device))) - z_shape2 = model.encode(mygpt.BracketedSequence(input2.to(main_device))) - z_shape = ((z_shape1[0] + z_shape2[0]) * 0.5, z_shape1[1]) - logits = model.decode(z_shape).x + for _ in range(args.nb_train_samples // args.batch_size): + while sum([x.size(0) for x in samples]) < args.batch_size: + # Generate a bunch of quizzes - q = logits.argmax(dim=-1) - # q = q.reshape(q.size(0) // 2, 2, -1) - # input = input.reshape(input.size(0) // 2, 2, -1) - # q = torch.cat([input.to("cpu"), q.to("cpu")], dim=1).reshape(q.size(0), -1) + if generator is None: + # Either we start with the world quizzes + c_quizzes = quiz_machine.problem.generate_w_quizzes(args.batch_size) + else: + # Or we use the generator itself to generate them + c_quizzes = generate_c_quizz_with_generator(generator, quiz_machine) - q = q.reshape(q.size(0) // 4, -1) + # We remove the trivial ones + to_keep = quiz_machine.problem.trivial(c_quizzes) == False + c_quizzes = c_quizzes[to_keep] - quiz_machine.problem.save_quizzes_as_image( - args.result_dir, - f"culture_mix_ae_{n_epoch:04d}.png", - q, - ) + # If there are remaining ones, we compute the true prolog + # that indicates how the GPTs solve it - return model + if c_quizzes.size(0) > 0: + seq_logproba = quiz_machine.models_logprobas( + models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1) + ) + quiz_machine.models_logprobas( + models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1) + ) + probas = seq_logproba.exp() -# if args.autoencoder_dim > 0: -# ae = train_autoencoder() -# exit(0) + nu = probas <= args.proba_not_understands + u = probas >= args.proba_understands -###################################################################### + prolog = ( + (nu.long() * token_prolog_0) + + (u.long() * token_prolog_2) + + ((nu == False & u == False).long() * token_prolog_1) + ) + samples.append(torch.cat([prolog, c_quizzes], dim=1)) -models = [] + # Now we yield a batch -for k in range(args.nb_gpts): - log_string(f"creating model {k} and its w_quizzes") + x = torch.cat(samples, dim=0) + samples = [x[args.batch_size :]] - model = mygpt.MyGPT( - vocabulary_size=vocabulary_size, - dim_model=args.dim_model, - dim_keys=args.dim_keys, - dim_hidden=args.dim_hidden, - nb_heads=args.nb_heads, - nb_blocks=args.nb_blocks, - causal=True, - dropout=args.dropout, - ).to(main_device) + yield x[: args.batch_size] - model.main_test_accuracy = 0.0 - model.id = k - model.train_w_quizzes = quiz_machine.problem.generate_w_quizzes( - args.nb_train_samples +def one_generator_epoch( + generator, quiz_machine=None, models=None, local_device=main_device +): + model.to(local_device).train() + + optimizer = torch.optim.Adam(generator.parameters(), lr=args.learning_rate) + + nb_train_samples, acc_train_loss = 0, 0.0 + + hard_w_quizzes = [] + + full_input, full_from_w = quiz_machine.data_input(generator, split="train") + src = zip(full_input.split(args.batch_size), full_from_w.split(args.batch_size)) + + for input, from_w in tqdm.tqdm( + src, + dynamic_ncols=True, + desc="training", + total=full_input.size(0) // args.batch_size, + ): + input = input.to(local_device) + + if nb_train_samples % args.batch_size == 0: + optimizer.zero_grad() + + targets = input + + output = generator(mygpt.BracketedSequence(input)).x + loss_per_token = F.cross_entropy( + output.transpose(1, 2), targets, reduction="none" + ) + loss = loss_per_token.mean() + acc_train_loss += loss.item() * input.size(0) + + loss_per_samples = loss_per_token.detach().flatten(1).mean(dim=1) + if from_w.any(): + hard_w_quizzes.append( + (input[from_w].to("cpu"), loss_per_samples[from_w].to("cpu")) + ) + + nb_train_samples += input.size(0) + + loss.backward() + + if nb_train_samples % args.batch_size == 0: + optimizer.step() + + train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples)) + + log_string( + f"train_perplexity {n_epoch} generator {generator.id} {train_perplexity}" ) - model.test_w_quizzes = quiz_machine.problem.generate_w_quizzes(args.nb_test_samples) + run_tests(generator, quiz_machine) + + threshold = torch.cat([l for _, l in hard_w_quizzes], dim=0).sort().values + threshold = threshold[threshold.size(0) // 2] + + generator.hard_w_quizzes = torch.cat( + [x[l >= threshold] for x, l in hard_w_quizzes], dim=0 + ) + + generator.to(main_device) - models.append(model) ###################################################################### diff --git a/quiz_machine.py b/quiz_machine.py index 90879ce..bad05ec 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -84,7 +84,7 @@ class QuizMachine: (("f_A", "A", "f_B", "B"), (0, 0, 0, 1)), (("B", "f_B", "A", "f_A"), (0, 0, 0, 1)), (("f_B", "B", "f_A", "A"), (0, 0, 0, 1)), - (("f_B", "f_A", "A", "B"), (0, 1, 1, 1)), + # (("f_B", "f_A", "A", "B"), (0, 1, 1, 1)), ] self.LOCK_C_QUIZZES = threading.Lock() -- 2.39.5