From 66428de44398d154343bd34a40849800e4fe544e Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Tue, 13 Aug 2024 22:35:30 +0200
Subject: [PATCH] Update.

---
 main.py | 213 +++++++++++++++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 197 insertions(+), 16 deletions(-)

diff --git a/main.py b/main.py
index 20acab3..8e06bb2 100755
--- a/main.py
+++ b/main.py
@@ -384,8 +384,6 @@ def one_epoch(model, quiz_machine, local_device=main_device):
 
     nb_train_samples, acc_train_loss = 0, 0.0
 
-    hard_w_quizzes = []
-
     full_input, full_mask_loss = quiz_machine.data_input(
         args.nb_train_samples, model.train_c_quiz_bags
     )
@@ -427,13 +425,6 @@ def one_epoch(model, quiz_machine, local_device=main_device):
 
     run_tests(model, quiz_machine)
 
-    # threshold = torch.cat([l for _, l in hard_w_quizzes], dim=0).sort().values
-    # threshold = threshold[threshold.size(0) // 2]
-
-    # model.hard_w_quizzes = torch.cat(
-    #     [x[l >= threshold] for x, l in hard_w_quizzes], dim=0
-    # )
-
     model.to(main_device)
     optimizer_to(model.optimizer, main_device)
 
@@ -441,28 +432,28 @@ def one_epoch(model, quiz_machine, local_device=main_device):
 ######################################################################
 
 
-def model_transformer_hot(model):
+def model_modifier_hot(model):
     model.temperature = args.temperature_hot
     # model.set_noise_injection(1.0, ("ffw", args.nb_blocks // 2))
 
 
-def model_transformer_cold(model):
+def model_modifier_cold(model):
     model.temperature = args.temperature_cold
     # pass
 
 
 c_quizzes_procedure = [
-    (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_transformer_hot),
-    (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), model_transformer_cold),
-    (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_transformer_cold),
-    (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), model_transformer_cold),
+    (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_modifier_hot),
+    (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), model_modifier_cold),
+    (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_modifier_cold),
+    (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), model_modifier_cold),
 ]
 
 ######################################################################
 
 
 def save_additional_results(model, models):
-    # Save generated quizzes with the successive steps
+    # Save generated quizzes with the successive generation steps
 
     recorder = []
 
@@ -660,6 +651,196 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test):
     )
 
 
+######################################################################
+
+from mygpt import (
+    WithResidual,
+    CacheWrapper,
+    AddPositionalEncoding,
+    QKVAttention,
+    BracketedSequence,
+)
+
+
+class Thinker(nn.Module):
+    def __init__(
+        self,
+        vocabulary_size,
+        dim_model,
+        dim_keys,
+        dim_hidden,
+        nb_heads,
+        nb_blocks,
+        f_len,
+        dropout=0.0,
+        len_max=1e5,
+    ):
+        super().__init__()
+
+        assert dim_model % nb_heads == 0
+
+        self.embedding = nn.Sequential(
+            CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
+            AddPositionalEncoding(len_max),
+        )
+
+        def trunk(depth):
+            trunk_blocks = []
+
+            for b in range(nb_blocks):
+                trunk_blocks += [
+                    WithResidual(
+                        CacheWrapper(
+                            nn.LayerNorm((dim_model,)),
+                        ),
+                        QKVAttention(
+                            dim_in=dim_model,
+                            dim_qk=dim_keys,
+                            dim_v=dim_model // nb_heads,
+                            nb_heads=nb_heads,
+                            attention_dropout=dropout,
+                        ),
+                    ),
+                    WithResidual(
+                        CacheWrapper(
+                            nn.LayerNorm((dim_model,)),
+                            nn.Linear(in_features=dim_model, out_features=dim_hidden),
+                            nn.ReLU(),
+                            nn.Linear(in_features=dim_hidden, out_features=dim_model),
+                            nn.Dropout(dropout),
+                        ),
+                    ),
+                ]
+
+            return nn.Sequential(*trunk_blocks)
+
+        self.bottom_trunk = trunk(nb_blocks // 2)
+
+        self.top_trunk = trunk(nb_blocks // 2)
+
+        self.readout = CacheWrapper(
+            nn.Linear(in_features=dim_model, out_features=vocabulary_size)
+        )
+
+        self.fun_embedding = nn.Parameter(torch.randn(1, f_len, dim_model))
+
+        with torch.no_grad():
+            for m in self.modules():
+                if isinstance(m, nn.Embedding):
+                    m.weight.normal_(mean=0, std=2e-2)
+                elif isinstance(m, nn.LayerNorm):
+                    m.bias.zero_()
+                    m.weight.fill_(1.0)
+
+    def forward(self, bs):
+        for m in self.modules():
+            m.loss = 0
+
+        L = bs.x.size(1) // 3
+
+        bs = self.embedding(bs)
+        A_fA = BracketedSequence(bs.x[:, : 2 * L])
+        B = BracketedSequence(bs.x[:, -L:])
+
+        bs = BracketedSequence(
+            torch.cat([A_fA.x, self.fun_embedding.expand(bs.x.size(0), -1, -1)], dim=1)
+        )
+        bs = self.bottom_trunk(bs)
+        bs = BracketedSequence(torch.cat([bs.x[:, -f_len:, :], B.x], dim=1))
+        bs = self.top_trunk(bs)
+        bs = BracketedSequence(bs.x[:, f_len:, :])
+        bs = self.readout(bs)
+
+        for m in self.modules():
+            if m is not self:
+                self.loss += m.loss
+
+        return bs
+
+
+if args.test == "func":
+    train_input = quiz_machine.problem.generate_w_quizzes(args.nb_train_samples)
+    test_input = quiz_machine.problem.generate_w_quizzes(args.nb_test_samples)
+
+    L = train_input.size(1) // 4
+    f_len = 25
+
+    model = Thinker(
+        vocabulary_size=vocabulary_size,
+        dim_model=args.dim_model,
+        dim_keys=args.dim_keys,
+        dim_hidden=args.dim_hidden,
+        nb_heads=args.nb_heads,
+        nb_blocks=args.nb_blocks,
+        f_len=20,
+        dropout=args.dropout,
+    ).to(main_device)
+
+    model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+
+    for n_epoch in range(args.nb_epochs):
+        model.train()
+
+        nb_train_samples, acc_train_loss = 0, 0.0
+
+        for input in tqdm.tqdm(
+            train_input.split(args.batch_size),
+            dynamic_ncols=True,
+            desc="training",
+            total=train_input.size(0) // args.batch_size,
+        ):
+            input = input.to(main_device)
+
+            if nb_train_samples % args.batch_size == 0:
+                model.optimizer.zero_grad()
+
+            output = model(mygpt.BracketedSequence(input[:, : 3 * L])).x
+            targets = input[:, 3 * L :]
+            loss = F.cross_entropy(output.transpose(1, 2), targets)
+            acc_train_loss += loss.item() * input.size(0)
+
+            nb_train_samples += input.size(0)
+
+            loss.backward()
+
+            if nb_train_samples % args.batch_size == 0:
+                model.optimizer.step()
+
+        train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
+
+        log_string(f"train_perplexity {n_epoch} model thinker {train_perplexity}")
+
+        with torch.autograd.no_grad():
+            model.eval()
+
+            nb_test_samples, acc_test_loss = 0, 0.0
+
+            for input in tqdm.tqdm(
+                test_input.split(args.batch_size),
+                dynamic_ncols=True,
+                desc="testing",
+                total=test_input.size(0) // args.batch_size,
+            ):
+                input = input.to(main_device)
+
+                output = model(mygpt.BracketedSequence(input[:, : 3 * L])).x
+                targets = input[:, 3 * L :]
+                loss = F.cross_entropy(output.transpose(1, 2), targets)
+                acc_test_loss += loss.item() * input.size(0)
+
+                nb_test_samples += input.size(0)
+
+            test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
+
+            log_string(f"test_perplexity {n_epoch} model thinker {test_perplexity}")
+
+            input = test_input[:128].clone().to(main_device)
+
+            output = model(mygpt.BracketedSequence(input[:, : 3 * L])).x
+            dist = torch.distributions.categorical.Categorical(logits=output)
+            input[:, 3 * L :] = dist.sample()
+
+
 ######################################################################
 
 models = []
-- 
2.39.5
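
The Thinker model added by the last hunk runs the embedded (A, f_A) pair together with f_len learned "function" tokens through a bottom trunk, keeps only those f_len positions as a latent description of the transformation, then runs them together with B through a top trunk and a readout to predict f_B. The short sketch below, which is not part of the diff above, only traces the tensor shapes of that forward pass: the trunks are treated as shape-preserving, and the batch size, quarter length L, f_len and model width are illustrative assumptions rather than values taken from the repository.

    # Illustrative shape walk-through of Thinker.forward (sketch only, assumed sizes).
    import torch

    N, L, f_len, dim = 2, 100, 20, 128          # batch, quarter length, function slots, width

    x = torch.randn(N, 3 * L, dim)              # embedded (A, f_A, B), as produced by self.embedding
    A_fA, B = x[:, : 2 * L], x[:, -L:]

    fun = torch.randn(1, f_len, dim).expand(N, -1, -1)   # stands in for self.fun_embedding

    bottom_in = torch.cat([A_fA, fun], dim=1)   # (N, 2L + f_len, dim), fed to the bottom trunk
    program = bottom_in[:, -f_len:, :]          # last f_len positions kept as the latent "function"

    top_in = torch.cat([program, B], dim=1)     # (N, f_len + L, dim), fed to the top trunk
    logits_in = top_in[:, f_len:, :]            # the L positions aligned with B go to the readout
    print(logits_in.shape)                      # torch.Size([2, 100, 128]); readout maps dim -> vocabulary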