From 34d836636e246101a9d7af7b68f9eeb8efa1f39e Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Mon, 19 Aug 2024 17:19:55 +0200
Subject: [PATCH] Update.

---
 main.py  |   8 +++-
 mygpt.py | 120 ++++++++++++++++++++++++++-----------------------------
 2 files changed, 62 insertions(+), 66 deletions(-)

diff --git a/main.py b/main.py
index 046514d..d98031e 100755
--- a/main.py
+++ b/main.py
@@ -342,7 +342,7 @@ def run_tests(model, quiz_machine, local_device=main_device):
         nb_samples_accumulated = 0
 
         full_input, full_mask_loss = quiz_machine.data_input(
-            args.nb_test_samples, model.test_c_quiz_bags
+            args.nb_test_samples, model.test_c_quiz_bags, args.c_quiz_multiplier
         )
         src = zip(
             full_input.split(args.batch_size), full_mask_loss.split(args.batch_size)
@@ -370,10 +370,14 @@ def run_tests(model, quiz_machine, local_device=main_device):
 
         log_string(f"test_perplexity {n_epoch} model {model.id} {test_perplexity}")
 
+        input, _ = quiz_machine.data_input(
+            2000, model.test_c_quiz_bags, args.c_quiz_multiplier
+        )
+
         model.test_accuracy = quiz_machine.produce_results(
             n_epoch=n_epoch,
             model=model,
-            input=full_input[:2000],
+            input=input,
             result_dir=args.result_dir,
         )
 
diff --git a/mygpt.py b/mygpt.py
index 041d28c..f716fe5 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -201,42 +201,48 @@ class QKVAttention(nn.Module):
         self.w_v = randw(nb_heads, dim_v, dim_in)
         self.w_o = randw(dim_v * nb_heads, dim_in)
 
-    def forward(self, bs_q):
+    def forward(self, bs_q, bs_kv=None):
+        if bs_kv is None:
+            bs_kv = bs_q
+
         x_q = bs_q.x
+        x_kv = bs_kv.x
 
-        if bs_q.first == 0:
-            self.cache_k = x_q.new_zeros(
-                x_q.size(0), self.w_k.size(0), x_q.size(1), self.w_k.size(1)
+        if bs_kv.first == 0:
+            self.cache_k = x_kv.new_zeros(
+                x_kv.size(0), self.w_k.size(0), x_kv.size(1), self.w_k.size(1)
             )
-            self.cache_v = x_q.new_zeros(
-                x_q.size(0), self.w_v.size(0), x_q.size(1), self.w_v.size(1)
+            self.cache_v = x_kv.new_zeros(
+                x_kv.size(0), self.w_v.size(0), x_kv.size(1), self.w_v.size(1)
             )
+
+        if bs_q.first == 0:
             self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1))
 
         q = torch.einsum(
             "ntc,hdc->nhtd", x_q[:, bs_q.first : bs_q.first + bs_q.nb], self.w_q
         )
 
-        self.cache_k[:, :, bs_q.first : bs_q.first + bs_q.nb] = torch.einsum(
-            "ntc,hdc->nhtd", x_q[:, bs_q.first : bs_q.first + bs_q.nb], self.w_k
+        self.cache_k[:, :, bs_kv.first : bs_kv.first + bs_kv.nb] = torch.einsum(
+            "ntc,hdc->nhtd", x_kv[:, bs_kv.first : bs_kv.first + bs_kv.nb], self.w_k
         )
-        self.cache_v[:, :, bs_q.first : bs_q.first + bs_q.nb] = torch.einsum(
-            "ntc,hdc->nhtd", x_q[:, bs_q.first : bs_q.first + bs_q.nb], self.w_v
+        self.cache_v[:, :, bs_kv.first : bs_kv.first + bs_kv.nb] = torch.einsum(
+            "ntc,hdc->nhtd", x_kv[:, bs_kv.first : bs_kv.first + bs_kv.nb], self.w_v
         )
 
         a = torch.einsum(
-            "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs_q.first + bs_q.nb]
+            "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs_kv.first + bs_kv.nb]
         ) / math.sqrt(self.w_q.size(1))
 
         if self.compute_attzero is not None:
             if bs_q.first == 0:
                 self.cache_attzero = self.compute_attzero(
                     torch.arange(x_q.size(1), device=q.device)[:, None],
-                    torch.arange(x_q.size(1), device=q.device)[None, :],
+                    torch.arange(x_kv.size(1), device=q.device)[None, :],
                 )[None, None, :, :]
             a = a.masked_fill(
                 self.cache_attzero[
-                    :, :, bs_q.first : bs_q.first + bs_q.nb, : bs_q.first + bs_q.nb
+                    :, :, bs_q.first : bs_q.first + bs_q.nb, : bs_kv.first + bs_kv.nb
                 ],
                 float("-inf"),
             )
@@ -249,7 +255,7 @@ class QKVAttention(nn.Module):
         a = F.dropout(a, self.attention_dropout, self.training)
 
         y = torch.einsum(
-            "nhts,nhsd->nthd", a, self.cache_v[:, :, : bs_q.first + bs_q.nb]
+            "nhts,nhsd->nthd", a, self.cache_v[:, :, : bs_kv.first + bs_kv.nb]
         ).flatten(2)
 
         self.cache_y[:, bs_q.first : bs_q.first + bs_q.nb] = y @ self.w_o
@@ -277,6 +283,36 @@ class NoiseInjector(nn.Module):
 ##############################
 
 
+class BlockSummarizer(nn.Module):
+    def __init__(self, nb_blocks, nb_tokens, dim_keys, dim_model):
+        self.nb_blocks = nb_blocks
+        self.static_q = nn.Parameter(nb_blocks - 1, nb_tokens, dim_keys)
+
+        def compute_block_attzero(t_q, t_k):
+            block_size = t_q.size(0)
+            return (t_q // block_size) <= (t_k // block_size)
+
+        self.qkv = QKVAttention(
+            dim_in=dim_model,
+            dim_qk=dim_keys,
+            dim_v=dim_model // nb_heads,
+            nb_heads=nb_heads,
+            compute_attzero=compute_attzero,
+            attention_dropout=dropout,
+        )
+
+    def forward(self, bs):
+        pass
+
+
+class ShiftByOne(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, bs):
+        return BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
+
+
 class MyGPT(nn.Module):
     def __init__(
         self,
@@ -287,7 +323,6 @@ class MyGPT(nn.Module):
         nb_heads,
         nb_blocks,
         compute_attzero=None,
-        autoencoder_dim=-1,
         dropout=0.0,
         len_max=1e5,
     ):
@@ -297,11 +332,14 @@ class MyGPT(nn.Module):
 
         self.temperature = 1.0
 
+        self.shifter = ShiftByOne()
+
         self.embedding = nn.Sequential(
             CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
-            AddPositionalEncoding(len_max),
         )
 
+        self.positional_encoding = AddPositionalEncoding(len_max)
+
         trunk_blocks = []
 
         for b in range(nb_blocks):
@@ -338,26 +376,6 @@ class MyGPT(nn.Module):
             nn.Linear(in_features=dim_model, out_features=vocabulary_size)
         )
 
-        # -------------------------------------------------------
-        if autoencoder_dim > 0:
-            self.encoder = nn.Sequential(
-                *(
-                    trunk_blocks[: nb_blocks // 2]
-                    + [EncoderHead(dim_model, autoencoder_dim)]
-                )
-            )
-
-            self.decoder = nn.Sequential(
-                *(
-                    [
-                        DecoderBottom(autoencoder_dim, dim_model),
-                        AddPositionalEncoding(len_max),
-                    ]
-                    + trunk_blocks[nb_blocks // 2 :]
-                )
-            )
-        # -------------------------------------------------------
-
         with torch.no_grad():
             for m in self.modules():
                 if isinstance(m, nn.Embedding):
@@ -370,8 +388,9 @@ class MyGPT(nn.Module):
         for m in self.modules():
             m.loss = 0
 
-        bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
+        bs = self.shifter(bs)
         bs = self.embedding(bs)
+        bs = self.positional_encoding(bs)
         bs = self.trunk(bs)
         bs = self.readout(bs)
         bs.x[:, bs.first : bs.first + bs.nb] /= self.temperature
@@ -381,33 +400,6 @@ class MyGPT(nn.Module):
 
         return bs
 
-    def encode(self, bs):
-        bs = self.embedding(bs)
-        z = self.encoder(bs)
-        return z
-
-    def decode(self, z_shape):
-        bs = self.decoder(z_shape)
-        bs = self.readout(bs)
-        return bs
-
-    def partial_forward(self, bs, start_layer=None, end_layer=None):
-        if start_layer is None:
-            # print(f"GENERATE {bs.first} {bs.first+bs.nb}")
-            bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
-            bs = self.embedding(bs)
-            if end_layer is not None:
-                return self.trunk[:end_layer](bs)
-            else:
-                bs = self.trunk(bs)
-                bs = self.readout(bs)
-                return bs
-        else:
-            bs = self.trunk[start_layer:](bs)
-            bs = self.trunk(bs)
-            bs = self.readout(bs)
-            return bs
-
     def reset_transformations(self):
         self.temperature = 1.0
         for m in self.modules():
-- 
2.39.5