From: Francois Fleuret Date: Thu, 28 Jul 2022 19:53:21 +0000 (+0200) Subject: OCDC X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=7b6a1a4f12459fd18a2006fa8f11589f2b2cd87b;p=mygpt.git OCDC --- diff --git a/mygpt.py b/mygpt.py index 7c4e06d..212e1a5 100755 --- a/mygpt.py +++ b/mygpt.py @@ -37,16 +37,14 @@ class PositionalEncoding(nn.Module): j = torch.arange(x.size(2), dtype = x.dtype, device = x.device)[None, :] k = j%2 pe = torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi/2 * k) - return x + pe # Let broadcasting to its job + return x + pe ############################## class QKVAttention(nn.Module): - def __init__( - self, - dim_in, dim_qk, dim_v, - nb_heads = 1, causal = False, attention_dropout = 0.0 - ): + def __init__(self, + dim_in, dim_qk, dim_v, + nb_heads = 1, causal = False, attention_dropout = 0.0): super().__init__() def randw(*d): @@ -88,7 +86,8 @@ class MyGPT(nn.Module): def __init__(self, vocabulary_size, dim_model, dim_keys, dim_hidden, - nb_heads, nb_blocks, dropout = 0.): + nb_heads, nb_blocks, + dropout = 0.0, len_max = 1e5): super().__init__() @@ -97,7 +96,7 @@ class MyGPT(nn.Module): self.embedding = nn.Sequential( nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout), - PositionalEncoding(len_max = 1e5), + PositionalEncoding(len_max), ) trunk_blocks = [ ]