From 055ea2c8f91a4ee2a528c3fed09e7ddb35eb0805 Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Wed, 21 Aug 2024 16:40:04 +0200
Subject: [PATCH] Update.

---
 grids.py        |   2 +-
 main.py         | 117 +++++++++++++++++------------------------------
 mygpt.py        |  53 ++++++++++++++++------
 quiz_machine.py |  30 ++++++++++++++-----
 4 files changed, 99 insertions(+), 103 deletions(-)

diff --git a/grids.py b/grids.py
index b12b4d6..c44e527 100755
--- a/grids.py
+++ b/grids.py
@@ -226,7 +226,7 @@ class Grids(problem.Problem):
         mask_ar = quizzes.new_zeros(quizzes.size())

         S = self.height * self.width
-        a = mask_ar.reshape(mask_ar.size(0), 4, S + 1)[:, :, 1:]
+        a = mask_ar.view(mask_ar.size(0), 4, S + 1)[:, :, 1:]
         a[:, 0, :] = quad[0]
         a[:, 1, :] = quad[1]
         a[:, 2, :] = quad[2]
diff --git a/main.py b/main.py
index 78d01ff..033b5f6 100755
--- a/main.py
+++ b/main.py
@@ -360,6 +360,7 @@ def run_tests(model, quiz_machine, local_device=main_device):
             output = model(
                 mygpt.BracketedSequence(input, ranks=mygpt.mask_ar_to_ranks(mask_ar))
             ).x
+
             loss_per_token = F.cross_entropy(
                 output.transpose(1, 2), targets, reduction="none"
             )
@@ -479,6 +480,45 @@ c_quizzes_procedure = [
 ######################################################################


+def model_proba_solutions(model, quizzes):
+    l = (
+        quiz_machine.models_logprobas(
+            model,
+            quizzes,
+            ("A", "f_A", "B", "f_B"),
+            (0, 0, 0, 2),
+            (0, 0, 1, 0),
+            (0, 0, 0, 1),
+        )
+        + quiz_machine.models_logprobas(
+            model,
+            quizzes,
+            ("f_A", "A", "f_B", "B"),
+            (0, 0, 0, 2),
+            (0, 0, 1, 0),
+            (0, 0, 0, 1),
+        )
+        + quiz_machine.models_logprobas(
+            model,
+            quizzes,
+            ("B", "f_B", "A", "f_A"),
+            (0, 0, 0, 2),
+            (0, 0, 1, 0),
+            (0, 0, 0, 1),
+        )
+        + quiz_machine.models_logprobas(
+            model,
+            quizzes,
+            ("f_B", "B", "f_A", "A"),
+            (0, 0, 0, 2),
+            (0, 0, 1, 0),
+            (0, 0, 0, 1),
+        )
+    )
+
+    return l.exp()
+
+
 def save_additional_results(n_epoch, model, models, c_quizzes_procedure):

     # Save generated quizzes with the successive generation steps
@@ -493,21 +533,7 @@ def save_additional_results(n_epoch, model, models, c_quizzes_procedure):

     # This is nb_quizzes x nb_models

-    l = [
-        quiz_machine.models_logprobas(
-            model, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
-        )
-        + quiz_machine.models_logprobas(
-            model, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
-        )
-        + quiz_machine.models_logprobas(
-            model, c_quizzes, ("B", "f_B", "A", "f_A"), (0, 0, 0, 1), (0, 0, 1, 0)
-        )
-        + quiz_machine.models_logprobas(
-            model, c_quizzes, ("f_B", "B", "f_A", "A"), (0, 0, 0, 1), (0, 0, 1, 0)
-        )
-        for model in models
-    ]
+    l = [model_proba_solutions(model, c_quizzes) for model in models]

     seq_logprobas = torch.cat([x[:, None] for x in l], dim=1)
     probas = seq_logprobas.exp()
@@ -549,25 +575,6 @@ def save_additional_results(n_epoch, model, models, c_quizzes_procedure):
 ######################################################################


-def model_proba_solutions(model, quizzes):
-    l = (
-        quiz_machine.models_logprobas(
-            model, quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
-        )
-        + quiz_machine.models_logprobas(
-            model, quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
-        )
-        + quiz_machine.models_logprobas(
-            model, quizzes, ("B", "f_B", "A", "f_A"), (0, 0, 0, 1), (0, 0, 1, 0)
-        )
-        + quiz_machine.models_logprobas(
-            model, quizzes, ("f_B", "B", "f_A", "A"), (0, 0, 0, 1), (0, 0, 1, 0)
-        )
-    )
-
-    return l.exp()
-
-
 def create_c_quizzes(
     main_model,
     other_models,
@@ -822,48 +829,6 @@ class Recorder(nn.Module):
         return input


-######################################################################
-
-
-def save_generated_c_quizzes(model, filename, nb=64):
-    while sum([x.size(0) for x in record]) < nb:
-        model = models[torch.randint(len(models), (1,)).item()]
-        c_quizzes = quiz_machine.generate_c_quizzes(
-            64,
-            model_for_generation=model,
-            procedure=c_quizzes_procedure,
-        )
-
-        p = quiz_machine.models_logprobas(
-            model,
-            c_quizzes,
-            ("A", "f_A", "B", "f_B"),
-            (1, 1, 1, 1),
-            temperature=1,
-        ).exp()
-
-        p_hot = quiz_machine.models_logprobas(
-            model,
-            c_quizzes,
-            ("A", "f_A", "B", "f_B"),
-            (1, 1, 1, 1),
-            temperature=args.temperature_hot,
-        ).exp()
-
-        to_keep = p_hot * torch.rand(p_hot.size(), device=p_hot.device) >= p
-        record.append(c_quizzes[to_keep])
-
-        print("NB_KEPT", sum([x.size(0) for x in record]))
-
-    quiz_machine.problem.save_quizzes_as_image(
-        args.result_dir,
-        filename,
-        quizzes=c_quizzes,
-    )
-
-    log_string(f"wrote {filename}")
-
-
 ######################################################################

 for n_epoch in range(current_epoch, args.nb_epochs):
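A note on model_proba_solutions above: it sums the log-probabilities that
the model assigns to a quiz under its four orderings and exponentiates the
sum, i.e. it returns the product of the four sequence probabilities. The
three quads following the struct tuple line up with the mask_ar, mask_noise
and mask_loss parameters this patch adds to models_logprobas in
quiz_machine.py below. A minimal numeric sketch of the aggregation, with
made-up log-probability values:

    import torch

    # One log-probability per ordering, hypothetical values.
    logp = torch.tensor([-1.2, -0.7, -2.0, -0.3])

    print(logp.sum().exp())   # what model_proba_solutions returns
    print(logp.exp().prod())  # same value: the product of the probabilities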
diff --git a/mygpt.py b/mygpt.py
index cd5b580..dc00423 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -83,6 +83,12 @@ def mask_ar_to_ranks(mask_ar):
     return a


+# mask_ar = torch.tensor([[ 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1]])
+# print(mask_ar)
+# print(mask_ar_to_ranks(mask_ar))
+# exit(0)
+
+
 class BracketedSequence:
     def __init__(self, x, first=None, nb=None, ranks=None):
         self.x = x
@@ -193,6 +199,7 @@ class QKVAttention(nn.Module):
         dim_qk,
         dim_v,
         nb_heads=1,
+        first_one=False,
         attention_dropout=0.0,
     ):
         super().__init__()
@@ -201,6 +208,8 @@ class QKVAttention(nn.Module):
             return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))

         self.attention_dropout = attention_dropout
+        self.first_one = first_one
+
         self.record_attention = False

         self.w_q = randw(nb_heads, dim_qk, dim_in)
@@ -243,19 +252,26 @@ class QKVAttention(nn.Module):

         t = torch.arange(x_q.size(1), device=a.device)

-        if bs_q.ranks is not None:
-            a = a.masked_fill(
-                (
-                    bs_q.ranks[:, None, bs_q.first : bs_q.first + bs_q.nb, None]
-                    <= bs_kv.ranks[:, None, None, : bs_kv.first + bs_kv.nb]
-                )
-                & (
-                    t[None, None, bs_q.first : bs_q.first + bs_q.nb, None]
-                    != t[None, None, None, : bs_kv.first + bs_kv.nb]
-                ),
-                float("-inf"),
+        assert bs_q.ranks is not None
+
+        # rank_forward = (
+        #     bs_q.ranks[:, None, bs_q.first : bs_q.first + bs_q.nb, None]
+        #     >= bs_kv.ranks[:, None, None, : bs_kv.first + bs_kv.nb]
+        # )
+
+        if self.first_one:
+            rank_forward = (
+                t[None, None, bs_q.first : bs_q.first + bs_q.nb, None]
+                <= t[None, None, None, : bs_kv.first + bs_kv.nb]
+            )
+        else:
+            rank_forward = (
+                t[None, None, bs_q.first : bs_q.first + bs_q.nb, None]
+                < t[None, None, None, : bs_kv.first + bs_kv.nb]
             )

+        a = a.masked_fill(rank_forward, float("-inf"))
+
         a = a.softmax(dim=3)

         if self.record_attention:
@@ -269,7 +285,7 @@ class QKVAttention(nn.Module):

         self.cache_y[:, bs_q.first : bs_q.first + bs_q.nb] = y @ self.w_o

-        return BracketedSequence(self.cache_y, bs_q.first, bs_q.nb)
+        return BracketedSequence(self.cache_y, bs_q.first, bs_q.nb, bs_q.ranks)


 ##############################
@@ -347,7 +363,16 @@ class MyGPT(nn.Module):

         self.positional_encoding = AddPositionalEncoding(len_max)

-        trunk_blocks = []
+        trunk_blocks = [
+            QKVAttention(
+                dim_in=dim_model,
+                dim_qk=dim_keys,
+                dim_v=dim_model // nb_heads,
+                nb_heads=nb_heads,
+                first_one=True,
+                attention_dropout=dropout,
+            )
+        ]

         for b in range(nb_blocks):
             trunk_blocks += [
@@ -394,7 +419,7 @@ class MyGPT(nn.Module):
         for m in self.modules():
             m.loss = 0

-        bs = self.shifter(bs)
+        # bs = self.shifter(bs)
         bs = self.embedding(bs)
         bs = self.positional_encoding(bs)
         bs = self.trunk(bs)
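The change in QKVAttention.forward above replaces the rank-based mask with
two purely positional masks (a rank-based variant is left commented out);
they differ only in whether a query may attend to its own position. With
first_one set, coefficients with t_kv >= t_q are filled with -inf, so the
block attends strictly to the past; otherwise only t_kv > t_q is masked,
the usual causal attention. A standalone sketch of the two boolean masks,
with True marking the coefficients set to -inf before the softmax:

    import torch

    t = torch.arange(5)

    # first_one=True: strictly-past attention, a query cannot see its
    # own position (note that the first row is entirely masked).
    print(t[:, None] <= t[None, :])

    # first_one=False: standard causal attention, a query also sees
    # its own position.
    print(t[:, None] < t[None, :])

The first_one=True attention prepended to the trunk thus makes every
position depend only on strictly earlier tokens, which is presumably why
the call to self.shifter can be commented out.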
diff --git a/quiz_machine.py b/quiz_machine.py
index 8cec909..db58461 100755
--- a/quiz_machine.py
+++ b/quiz_machine.py
@@ -39,17 +39,15 @@ def one_batch_masked_inplace_autoregression(

     indices_1 = list(((mask_ar == 1).long().sum(0) > 0).nonzero()) + [mask.size(1)]

+    ranks = mygpt.mask_ar_to_ranks(mask_ar)
+
     if to_generate.min() > 0:
         model(
-            BracketedSequence(input, 0, to_generate.min())
+            BracketedSequence(input, 0, to_generate.min(), ranks=ranks)
         )  # Needed to initialize the model's cache

-    s = to_generate.min()
-
     for s, u in zip(indices_1[:-1], indices_1[1:]):
-        logits = model(
-            BracketedSequence(input, s, u - s, ranks=mygpt.mask_ar_to_ranks(mask_ar))
-        ).x
+        logits = model(BracketedSequence(input, s, u - s, ranks=ranks)).x

         if deterministic_synthesis:
             t_next = logits.argmax(dim=2)
@@ -90,10 +88,10 @@ class QuizMachine:
         # - struct, quad_generate, quad_noise, quad_loss

         self.train_structures = [
-            (("A", "f_A", "B", "f_B"), (0, 0, 0, 2), (0, 0, 1, 0), (1, 1, 0, 1)),
-            (("f_A", "A", "f_B", "B"), (0, 0, 0, 2), (0, 0, 1, 0), (1, 1, 0, 1)),
-            (("B", "f_B", "A", "f_A"), (0, 0, 0, 2), (0, 0, 1, 0), (1, 1, 0, 1)),
-            (("f_B", "B", "f_A", "A"), (0, 0, 0, 2), (0, 0, 1, 0), (1, 1, 0, 1)),
+            (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 0, 1)),
+            (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 0, 1)),
+            (("B", "f_B", "A", "f_A"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 0, 1)),
+            (("f_B", "B", "f_A", "A"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 0, 1)),
             (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
         ]

@@ -296,8 +294,9 @@ class QuizMachine:
         model,
         c_quizzes,
         struct,
+        mask_ar,
+        mask_noise,
         mask_loss,
-        mask_noise=None,
         temperature=1.0,
         device=None,
     ):
@@ -322,13 +321,21 @@ class QuizMachine:

-        for input, l in zip(
+        for input, mask_ar, l in zip(
             c_quizzes.split(self.batch_size),
+            mask_ar.split(self.batch_size),
             seq_logprobas.split(self.batch_size),
         ):
             input = input.to(device)
             quiz_mask_loss = self.make_quiz_mask(
                 input, struct=struct, mask=mask_loss
             )
-            output = model(mygpt.BracketedSequence(input)).x / temperature
+            output = (
+                model(
+                    mygpt.BracketedSequence(
+                        input, ranks=mygpt.mask_ar_to_ranks(mask_ar)
+                    )
+                ).x
+                / temperature
+            )
             l[...] = (
                 -F.cross_entropy(output.transpose(1, 2), input, reduction="none")
                 * quiz_mask_loss
-- 
2.39.5
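A remark on the last hunk: with reduction="none", the quantity
-F.cross_entropy(output.transpose(1, 2), input, ...) is exactly the
per-token log-probability under the model, which the surrounding code then
masks with quiz_mask_loss. A self-contained sketch with toy shapes
(batch 2, sequence 7, vocabulary 5):

    import torch
    import torch.nn.functional as F

    logits = torch.randn(2, 7, 5)
    targets = torch.randint(5, (2, 7))

    # What the patch computes: F.cross_entropy wants the class
    # dimension in position 1, hence the transpose.
    lp = -F.cross_entropy(logits.transpose(1, 2), targets, reduction="none")

    # The same per-token log-probabilities, picked out explicitly.
    ref = logits.log_softmax(dim=2).gather(2, targets[:, :, None])[:, :, 0]

    print(torch.allclose(lp, ref))  # True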