From 03100792df9e52b739bbe4692bed6c4f6b575242 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 6 Jan 2024 14:49:59 +0100 Subject: [PATCH] Update. --- mygpt.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mygpt.py b/mygpt.py index 9bacaff..c061eb4 100755 --- a/mygpt.py +++ b/mygpt.py @@ -545,6 +545,8 @@ class Caterpillar(nn.Module): torch.einsum("ntc,hec->nhet", X, self.w_G) + self.b_G[None, :, :, None] ).sigmoid() + G = F.dropout(G, self.attention_dropout, self.training) + V = torch.einsum("ntc,hdc->nhtd", X, self.w_V) K = torch.einsum("ntc,hdc->nhtd", X, self.w_K) -- 2.39.5