From f05fb29063a449c687e3f5c623d7430cde024107 Mon Sep 17 00:00:00 2001
From: François Fleuret
Date: Fri, 13 Sep 2024 11:27:03 +0200
Subject: [PATCH] Update.

---
 attae.py | 40 +++++++++++++---------------------------
 main.py  |  9 ++-------
 2 files changed, 15 insertions(+), 34 deletions(-)

diff --git a/attae.py b/attae.py
index bc90ed0..e201f60 100755
--- a/attae.py
+++ b/attae.py
@@ -16,17 +16,12 @@ class VaswaniPositionalEncoding(nn.Module):
         super().__init__()
         self.len_max = len_max
 
-    # [Vaswani et al 2018] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D}))
-
     def forward(self, x):
         t = torch.arange(x.size(1), dtype=x.dtype, device=x.device)[:, None]
         j = torch.arange(x.size(2), dtype=x.dtype, device=x.device)[None, :]
-        k = j % 2
-
+        k = j % 2  # works with float, weird
         pe = torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi / 2 * k)
-
         y = x + pe
-
         return y
 
 
@@ -45,23 +40,22 @@ class WithResidual(nn.Module):
 ######################################################################
 
 
-def vanilla_attention(q, k, v):
+def attention(q, k, v):
     a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3))
     a = a.softmax(dim=3)
     y = torch.einsum("nhts,nhsd->nhtd", a, v)
-    y = torch.einsum("nhtd,hdc->ntc", y, self.w_o)
     return y
 
 
-vanilla_attention = torch.compile(vanilla_attention)
+attention = torch.compile(attention)
 
-# y = flex_attention(q, k, v, score_mod=noop)
+######################################################################
 
 
 class MHAttention(nn.Module):
     def __init__(
         self,
-        dim_in,
+        dim_model,
         dim_qk,
         dim_v,
         nb_heads=1,
@@ -73,12 +67,10 @@ class MHAttention(nn.Module):
             return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
 
         self.attention_dropout = attention_dropout
-        self.record_attention = False
-
-        self.w_q = randw(nb_heads, dim_qk, dim_in)
-        self.w_k = randw(nb_heads, dim_qk, dim_in)
-        self.w_v = randw(nb_heads, dim_v, dim_in)
-        self.w_o = randw(nb_heads, dim_v, dim_in)
+        self.w_q = randw(nb_heads, dim_qk, dim_model)
+        self.w_k = randw(nb_heads, dim_qk, dim_model)
+        self.w_v = randw(nb_heads, dim_v, dim_model)
+        self.w_o = randw(nb_heads, dim_v, dim_model)
 
     def forward(self, x_q, x_kv=None):
         if x_kv is None:
@@ -87,13 +79,7 @@ class MHAttention(nn.Module):
         q = torch.einsum("ntc,hdc->nhtd", x_q, self.w_q)
         k = torch.einsum("nsc,hdc->nhsd", x_kv, self.w_k)
         v = torch.einsum("nsc,hdc->nhsd", x_kv, self.w_v)
-
-        def noop(score, b, h, q_idx, kv_idx):
-            return score
-
-        y = vanilla_attention(q, k, v)
-        # y = flex_attention(q, k, v, score_mod=noop)
-
+        y = attention(q, k, v)
         y = torch.einsum("nhtd,hdc->ntc", y, self.w_o)
         return y
 
@@ -112,7 +98,7 @@ class AttentionAE(nn.Module):
         nb_heads,
         nb_blocks,
         dropout=0.0,
-        len_max=1024,
+        len_max=1e5,
     ):
         super().__init__()
 
@@ -123,7 +109,7 @@ class AttentionAE(nn.Module):
             nn.Dropout(dropout),
         )
 
-        self.positional_encoding = VaswaniPositionalEncoding(len_max=1e5)
+        self.positional_encoding = VaswaniPositionalEncoding(len_max)
 
         trunk_blocks = []
 
@@ -132,7 +118,7 @@ class AttentionAE(nn.Module):
             WithResidual(
                 nn.LayerNorm((dim_model,)),
                 MHAttention(
-                    dim_in=dim_model,
+                    dim_model=dim_model,
                     dim_qk=dim_keys,
                     dim_v=dim_model // nb_heads,
                     nb_heads=nb_heads,
diff --git a/main.py b/main.py
index 0fea318..e090f86 100755
--- a/main.py
+++ b/main.py
@@ -1215,11 +1215,7 @@ def generate_ae_c_quizzes(models, nb, local_device=main_device):
     c_quizzes = torch.cat(record_c_quizzes, dim=0)
     agreements = torch.cat(record_agreements, dim=0)
 
-    return c_quizzes, agreements
-
-
-def thread_generate_ae_c_quizzes(models, nb, record, local_device=main_device):
-    record.append(generate_ae_c_quizzes(models, nb, local_device))
+    return c_quizzes.to("cpu"), agreements.to("cpu")
 
 
 ######################################################################
@@ -1381,8 +1377,7 @@ def multithread_execution(fun, arguments):
 
    else:
        return [
-            torch.cat([x[k].to("cpu") for x in records], dim=0)
-            for k in range(len(records[0]))
+            torch.cat([x[k] for x in records], dim=0) for k in range(len(records[0]))
         ]
 
-- 
2.39.5
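
For readers skimming the diff: the renamed helper computes plain scaled dot-product attention over (batch, head, position, dim) tensors, and the patch drops the stray projection through self.w_o inside it (a free function has no self), leaving the output projection to MHAttention.forward. Below is a minimal standalone sketch of that helper; the tensor sizes are purely illustrative and not taken from the repository's configuration.

    import math
    import torch


    def attention(q, k, v):
        # q, k, v: (batch, heads, positions, head dim)
        a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3))
        a = a.softmax(dim=3)  # normalize over the key positions s
        return torch.einsum("nhts,nhsd->nhtd", a, v)


    # illustrative sizes only: batch 2, 4 heads, 16 tokens, 8 dims per head
    q = torch.randn(2, 4, 16, 8)
    k = torch.randn(2, 4, 16, 8)
    v = torch.randn(2, 4, 16, 8)
    y = attention(q, k, v)  # shape (2, 4, 16, 8)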