From c11fc1c9f29aac2af7df409400a8045a66affb16 Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Sun, 8 Sep 2024 11:59:58 +0200
Subject: [PATCH] Update.

---
 attae.py | 10 ++++++++--
 main.py  | 13 ++++++++-----
 2 files changed, 16 insertions(+), 7 deletions(-)

diff --git a/attae.py b/attae.py
index 3a9f105..7bd4a44 100755
--- a/attae.py
+++ b/attae.py
@@ -102,7 +102,7 @@ class AttentionAE(nn.Module):
         assert dim_model % nb_heads == 0
 
         self.embedding = nn.Sequential(
-            nn.Embedding(2 * vocabulary_size, dim_model),
+            nn.Embedding(vocabulary_size, dim_model),
             nn.Dropout(dropout),
         )
 
@@ -166,5 +166,11 @@ if __name__ == "__main__":
     )
 
     x = torch.randint(100, (10, 50))
-    y = model(x)
+
+    with torch.no_grad():
+        model.eval()
+        x = torch.randint(100, (10, 50))
+        y = model(x)
+
+    print(y)
 
diff --git a/main.py b/main.py
index a4030ff..d90a3df 100755
--- a/main.py
+++ b/main.py
@@ -16,6 +16,8 @@ from torch.nn import functional as F
 
 import ffutils
 
+import attae
+
 import mygpt
 import sky, grids, quiz_machine
 
@@ -373,7 +375,7 @@ def optimizer_to(optim, device):
 
 
 from mygpt import (
-    WithResidual,
+    CachedWithResidual,
     CacheWrapper,
     CachedVaswaniPositionalEncoding,
     QKVAttention,
@@ -394,7 +396,7 @@ class MultiEmbedding(nn.Module):
 
 
 def attention_block(dim_model, dim_keys, nb_heads, dropout):
-    return WithResidual(
+    return CachedWithResidual(
         CacheWrapper(
             nn.LayerNorm((dim_model,)),
         ),
@@ -409,7 +411,7 @@ def attention_block(dim_model, dim_keys, nb_heads, dropout):
 
 
 def ffw_block(dim_model, dim_hidden, nb_heads, dropout):
-    return WithResidual(
+    return CachedWithResidual(
         CacheWrapper(
             nn.LayerNorm((dim_model,)),
             nn.Linear(in_features=dim_model, out_features=dim_hidden),
@@ -438,7 +440,8 @@ class MyAttentionAE(nn.Module):
 
         self.embedding = CacheWrapper(
             nn.Sequential(
-                MultiEmbedding((vocabulary_size, 2), dim_model), nn.Dropout(dropout)
+                MultiEmbedding((vocabulary_size, 2), dim_model),
+                nn.Dropout(dropout),
             ),
         )
 
@@ -997,7 +1000,7 @@ models = []
 
 for i in range(args.nb_models):
     model = MyAttentionAE(
-        # model = FunctionalAE(
+        # model = attae.AttentionAE(
         vocabulary_size=vocabulary_size,
        dim_model=args.dim_model,
         dim_keys=args.dim_keys,
-- 
2.39.5
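
Side note on the new __main__ block in attae.py: the smoke test is now run under torch.no_grad() with model.eval(), so dropout is disabled and no autograd graph is built during the forward pass. Below is a minimal standalone sketch of that same pattern; TinyAE is a hypothetical stand-in used only for illustration, not the repository's attae.AttentionAE.

# Sketch of the inference pattern introduced in attae.py's __main__ block.
# TinyAE is a hypothetical stand-in, not the repository's AttentionAE.
import torch
from torch import nn


class TinyAE(nn.Module):
    def __init__(self, vocabulary_size=100, dim_model=16, dropout=0.1):
        super().__init__()
        self.embedding = nn.Sequential(
            nn.Embedding(vocabulary_size, dim_model),
            nn.Dropout(dropout),
        )
        self.readout = nn.Linear(dim_model, vocabulary_size)

    def forward(self, x):
        return self.readout(self.embedding(x))


model = TinyAE()

with torch.no_grad():
    model.eval()  # disable dropout for the smoke test
    x = torch.randint(100, (10, 50))
    y = model(x)

print(y.shape)  # torch.Size([10, 50, 100])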