From: François Fleuret Date: Thu, 22 Aug 2024 16:06:38 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=215996db6be389e0a3847e6845f9eadf705f1c32;p=culture.git Update. --- diff --git a/grids.py b/grids.py index 0564f3b..35b3cff 100755 --- a/grids.py +++ b/grids.py @@ -167,11 +167,11 @@ class Grids(problem.Problem): self.check_structure(quizzes, struct) return struct - def inject_noise(self, quizzes, noise, struct, mask): + def inject_noise(self, quizzes, noise, struct, quad): assert self.check_structure(quizzes, struct=struct) S = self.height * self.width - mask = torch.tensor(mask, device=quizzes.device) + mask = torch.tensor(quad, device=quizzes.device) mask = mask[None, :, None].expand(1, 4, S + 1).clone() mask[:, :, 0] = 0 mask = mask.reshape(1, -1).expand_as(quizzes) @@ -219,7 +219,7 @@ class Grids(problem.Problem): ).values def make_quiz_mask( - self, quizzes, struct=("A", "f_A", "B", "f_B"), mask=(0, 0, 0, 1) + self, quizzes, struct=("A", "f_A", "B", "f_B"), quad=(0, 0, 0, 1) ): assert self.check_structure(quizzes, struct) @@ -227,10 +227,10 @@ class Grids(problem.Problem): S = self.height * self.width a = ar_mask.reshape(ar_mask.size(0), 4, S + 1)[:, :, 1:] - a[:, 0, :] = mask[0] - a[:, 1, :] = mask[1] - a[:, 2, :] = mask[2] - a[:, 3, :] = mask[3] + a[:, 0, :] = quad[0] + a[:, 1, :] = quad[1] + a[:, 2, :] = quad[2] + a[:, 3, :] = quad[3] return ar_mask diff --git a/main.py b/main.py index 148a917..35ba763 100755 --- a/main.py +++ b/main.py @@ -542,7 +542,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test): model, solved_c_quizzes[:, model.id], struct=("A", "f_A", "B", "f_B"), - mask=(0, 0, 0, 1), + quad=(0, 0, 0, 1), ) proba_own_solution[:, model.id] = model_proba_solutions( @@ -740,6 +740,207 @@ class Thinker(nn.Module): ###################################################################### +from mygpt import ( + WithResidual, + CacheWrapper, + AddPositionalEncoding, + QKVAttention, + BracketedSequence, +) + + +class MyAttentionVAE(nn.Module): + def __init__( + self, + vocabulary_size, + dim_model, + dim_keys, + dim_hidden, + nb_heads, + nb_blocks, + dropout=0.0, + len_max=1e5, + ): + super().__init__() + + assert dim_model % nb_heads == 0 + + self.embedding = nn.Sequential( + CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)), + ) + + self.positional_encoding = AddPositionalEncoding(len_max) + + trunk_blocks = [] + + for b in range(nb_blocks): + trunk_blocks += [ + WithResidual( + CacheWrapper( + nn.LayerNorm((dim_model,)), + ), + QKVAttention( + dim_in=dim_model, + dim_qk=dim_keys, + dim_v=dim_model // nb_heads, + nb_heads=nb_heads, + attention_dropout=dropout, + ), + ), + WithResidual( + CacheWrapper( + nn.LayerNorm((dim_model,)), + nn.Linear(in_features=dim_model, out_features=dim_hidden), + nn.ReLU(), + nn.Linear(in_features=dim_hidden, out_features=dim_model), + nn.Dropout(dropout), + ), + ), + ] + + self.trunk = nn.Sequential(*trunk_blocks) + + self.readout = CacheWrapper( + nn.Linear(in_features=dim_model, out_features=vocabulary_size) + ) + + with torch.no_grad(): + for m in self.modules(): + if isinstance(m, nn.Embedding): + m.weight.normal_(mean=0, std=2e-2) + elif isinstance(m, nn.LayerNorm): + m.bias.zero_() + m.weight.fill_(1.0) + + def forward(self, bs): + bs = self.embedding(bs) + bs = self.positional_encoding(bs) + bs = self.trunk(bs) + bs = self.readout(bs) + return bs + + +def test_ae(local_device=main_device): + model = MyAttentionVAE( + vocabulary_size=vocabulary_size, + dim_model=args.dim_model, + dim_keys=args.dim_keys, + dim_hidden=args.dim_hidden, + nb_heads=args.nb_heads, + nb_blocks=args.nb_blocks, + dropout=args.dropout, + ).to(main_device) + + model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) + + model.to(local_device).train() + optimizer_to(model.optimizer, local_device) + + if args.schedule_free: + model.optimizer.train() + + for n_epoch in range(args.nb_epochs): + # ---------------------- + # Train + + model.train() + nb_train_samples, acc_train_loss = 0, 0.0 + + full_input, full_mask_loss = quiz_machine.data_input(args.nb_train_samples) + + src = zip( + full_input.split(args.batch_size), full_mask_loss.split(args.batch_size) + ) + + for input, mask_loss in tqdm.tqdm( + src, + dynamic_ncols=True, + desc="training", + total=full_input.size(0) // args.batch_size, + ): + input = input.to(local_device) + mask_loss = mask_loss.to(local_device) + + if nb_train_samples % args.batch_size == 0: + model.optimizer.zero_grad() + + targets = input + input = (mask_loss == 0).long() * input + output = model(mygpt.BracketedSequence(input)).x + loss = F.cross_entropy(output.transpose(1, 2), targets) + acc_train_loss += loss.item() * input.size(0) + nb_train_samples += input.size(0) + loss.backward() + + if nb_train_samples % args.batch_size == 0: + model.optimizer.step() + + train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples)) + + log_string(f"train_loss {n_epoch} model AE {acc_train_loss/nb_train_samples}") + + # ---------------------- + # Test + + with torch.autograd.no_grad(): + model.eval() + + nb_test_samples, acc_test_loss = 0, 0.0 + + full_input, full_mask_loss = quiz_machine.data_input(args.nb_test_samples) + + src = zip( + full_input.split(args.batch_size), full_mask_loss.split(args.batch_size) + ) + + for input, mask_loss in tqdm.tqdm( + src, + dynamic_ncols=True, + desc="testing", + total=full_input.size(0) // args.batch_size, + ): + input = input.to(local_device) + mask_loss = mask_loss.to(local_device) + targets = input + input = (mask_loss == 0).long() * input + output = model(mygpt.BracketedSequence(input)).x + loss = F.cross_entropy(output.transpose(1, 2), targets) + acc_test_loss += loss.item() * input.size(0) + nb_test_samples += input.size(0) + + log_string(f"test_loss {n_epoch} model AE {acc_test_loss/nb_test_samples}") + + input, mask_loss = quiz_machine.data_input(128) + input = input.to(local_device) + mask_loss = mask_loss.to(local_device) + targets = input + input = (mask_loss == 0).long() * input + logits = model(mygpt.BracketedSequence(input)).x + dist = torch.distributions.categorical.Categorical(logits=logits) + result = dist.sample() + L = input.size(1) // 4 + result[:, 0 * L] = input[:, 0 * L] + result[:, 1 * L] = input[:, 1 * L] + result[:, 2 * L] = input[:, 2 * L] + result[:, 3 * L] = input[:, 3 * L] + filename = f"prediction_ae_{n_epoch:04d}.png" + + quiz_machine.problem.save_quizzes_as_image( + args.result_dir, + filename, + quizzes=result, + ) + + log_string(f"wrote {filename}") + + +if args.test == "ae": + test_ae(local_device=main_device) + exit(0) + +###################################################################### + + def create_models(): models = [] @@ -1018,9 +1219,11 @@ if args.test == "entropy": procedure=c_quizzes_procedure, ) + filename = f"test_{n_epoch:04d}.png" + quiz_machine.problem.save_quizzes_as_image( args.result_dir, - f"test_{n_epoch:04d}.png", + filename, quizzes=input, ) @@ -1119,7 +1322,7 @@ for n_epoch in range(current_epoch, args.nb_epochs): m = max(nb_c_quizzes_per_model) - if m >= args.nb_train_samples: + if m * args.c_quiz_multiplier >= args.nb_train_samples: break model = models[nb_c_quizzes_per_model.index(m)] diff --git a/quiz_machine.py b/quiz_machine.py index a0b007a..ceb527a 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -81,7 +81,7 @@ class QuizMachine: self.answer_len = None self.prompt_noise = prompt_noise - # struct, mask_generate, mask_noise, mask_loss + # struct, quad_generate, quad_noise, quad_loss self.train_structures = [ (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)), (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)), @@ -140,7 +140,7 @@ class QuizMachine: ###################################################################### - def data_input(self, nb_samples, c_quiz_bags, c_quiz_multiplier=1): + def data_input(self, nb_samples, c_quiz_bags=[], c_quiz_multiplier=1): if len(c_quiz_bags) > 0: c_quizzes = torch.cat(c_quiz_bags, dim=0) @@ -176,29 +176,29 @@ class QuizMachine: quiz_mask_loss = quizzes.new_full(quizzes.size(), 1) if self.prompt_noise > 0.0: - for struct, _, mask_noise, mask_loss in self.train_structures: + for struct, _, quad_noise, quad_loss in self.train_structures: i = self.problem.indices_select(quizzes=quizzes, struct=struct) if i.any(): quizzes[i] = self.problem.inject_noise( - quizzes[i], self.prompt_noise, struct=struct, mask=mask_noise + quizzes[i], self.prompt_noise, struct=struct, quad=quad_noise ) quiz_mask_loss[i] = self.make_quiz_mask( - quizzes=quizzes[i], struct=struct, mask=mask_loss + quizzes=quizzes[i], struct=struct, quad=quad_loss ) return quizzes, quiz_mask_loss ###################################################################### - def make_quiz_mask(self, quizzes, struct, mask): + def make_quiz_mask(self, quizzes, struct, quad): assert struct in [s for s, _, _, _ in self.train_structures] - return self.problem.make_quiz_mask(quizzes, struct=struct, mask=mask) + return self.problem.make_quiz_mask(quizzes, struct=struct, quad=quad) ###################################################################### - def predict(self, model, quizzes, struct, mask): + def predict(self, model, quizzes, struct, quad): quizzes = quizzes.to(self.device) - ar_mask = self.make_quiz_mask(quizzes=quizzes, struct=struct, mask=mask) + ar_mask = self.make_quiz_mask(quizzes=quizzes, struct=struct, quad=quad) result = quizzes * (1 - ar_mask) seq_logprobas = torch.zeros(quizzes.size(0), device=self.device) @@ -230,14 +230,14 @@ class QuizMachine: nb = 0 # We consider all the configurations that we train for - for struct, mask_generate, _, _ in self.test_structures: + for struct, quad_generate, _, _ in self.test_structures: i = self.problem.indices_select(quizzes=input, struct=struct) nb += i.long().sum() result[i], correct[i], _ = self.predict( - model=model, quizzes=input[i], struct=struct, mask=mask_generate + model=model, quizzes=input[i], struct=struct, quad=quad_generate ) - predicted_parts[i] = torch.tensor(mask_generate, device=self.device)[ + predicted_parts[i] = torch.tensor(quad_generate, device=self.device)[ None, : ] solution_is_deterministic = predicted_parts[i].sum(dim=-1) == 1 @@ -302,8 +302,8 @@ class QuizMachine: model, c_quizzes, struct, - mask_loss, - mask_noise=None, + quad_loss, + quad_noise=None, temperature=1.0, device=None, ): @@ -317,9 +317,9 @@ class QuizMachine: device=device, ) - # if self.prompt_noise > 0.0 and mask_noise is not None: + # if self.prompt_noise > 0.0 and quad_noise is not None: # c_quizzes = self.problem.inject_noise( - # c_quizzes, self.prompt_noise, struct=struct, mask=mask_noise + # c_quizzes, self.prompt_noise, struct=struct, quad=quad_noise # ) with torch.autograd.no_grad(): @@ -332,7 +332,7 @@ class QuizMachine: ): input = input.to(device) quiz_mask_loss = self.make_quiz_mask( - input, struct=struct, mask=mask_loss + input, struct=struct, quad=quad_loss ) output = model(mygpt.BracketedSequence(input)).x / temperature l[...] = ( @@ -352,21 +352,21 @@ class QuizMachine: c_quizzes = None for n_step, setup in enumerate(procedure): - s, m, mt = setup + struct, quad_generate, model_modifier = setup if c_quizzes is None: - c_quizzes = self.problem.create_empty_quizzes(nb, s) + c_quizzes = self.problem.create_empty_quizzes(nb, struct) c_quizzes = c_quizzes.to(self.device) - elif s != pred_s: - c_quizzes = self.problem.reconfigure(c_quizzes, s) - pred_s = s + elif struct != pred_struct: + c_quizzes = self.problem.reconfigure(c_quizzes, struct) + pred_struct = struct - if mt is not None: - mt(model_for_generation) + if model_modifier is not None: + model_modifier(model_for_generation) self.autoregression( model=model_for_generation, input=c_quizzes, - ar_mask=self.make_quiz_mask(c_quizzes, s, m), + ar_mask=self.make_quiz_mask(c_quizzes, struct, quad_generate), seq_logprobas=seq_logprobas, progress_bar_desc=f"autoregression {n_step+1}/{len(procedure)}", ) @@ -375,7 +375,9 @@ class QuizMachine: if recorder is not None: x = c_quizzes.clone() - t = torch.tensor(m, device=x.device)[None, :].expand(x.size(0), -1) + t = torch.tensor(quad_generate, device=x.device)[None, :].expand( + x.size(0), -1 + ) recorder.append( self.problem.reconfigure([x, t], ("A", "f_A", "B", "f_B")) )