From de99e48d5c2dfb72e811f0bb1c2c09aa154af8b6 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 18 Jan 2024 13:06:27 +0100 Subject: [PATCH] Update. --- mygpt.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/mygpt.py b/mygpt.py index 492a9bb..5451584 100755 --- a/mygpt.py +++ b/mygpt.py @@ -617,8 +617,6 @@ class Caterpillar(nn.Module): init_rec_V = self.rec_V[:, :, t0 - L : t0] init_rec_K = self.rec_K[:, :, t0 - L : t0] - # Associative scan - # Here there is a trick: Since the stack at position t is # computed by updating that at position t-L, the parallel # scan operates with a period of L. To do so we split the @@ -646,9 +644,16 @@ class Caterpillar(nn.Module): warnings.warn("gate dropout", RuntimeWarning) + # kill = ( + # torch.rand(G.size(), device=G.device) <= self.proba_gate_dropout + # ).float() + kill = ( - torch.rand(G.size(), device=G.device) <= self.proba_gate_dropout - ).float() + torch.rand(N, H, R, t1 - t0, device=G.device).sort(dim=3).indices == 0 + ).cumsum(dim=3) + kill = kill * ( + torch.rand(N, H, R, 1, device=G.device) <= self.proba_gate_dropout + ) mask = 1 - kill -- 2.39.5