From 12184f604b37f36f07d7dcdd567b1c78f02c74db Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Tue, 26 Jul 2022 17:06:13 +0200 Subject: [PATCH] Fixed the size of w_o. --- mygpt.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mygpt.py b/mygpt.py index 121ad80..ab16e1e 100755 --- a/mygpt.py +++ b/mygpt.py @@ -57,7 +57,7 @@ class QKVAttention(nn.Module): self.w_q = randw(nb_heads, dim_qk, dim_in) self.w_k = randw(nb_heads, dim_qk, dim_in) self.w_v = randw(nb_heads, dim_v, dim_in) - self.w_o = randw(dim_in, dim_v * nb_heads) + self.w_o = randw(dim_v * nb_heads, dim_in) def forward(self, x_q, x_kv = None): if x_kv is None: x_kv = x_q @@ -142,7 +142,7 @@ if __name__ == '__main__': model = MyGPT( vocabulary_size = vocabulary_size, - dim_model = 16, dim_keys = 50, dim_hidden = 100, + dim_model = 18, dim_keys = 50, dim_hidden = 100, nb_heads = 2, nb_blocks = 3, dropout = 0.1 ) -- 2.39.5