From: François Fleuret Date: Fri, 20 Sep 2024 11:40:39 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=aea004528da985523c70f660f3d4afd9814abc18;p=culture.git Update. --- diff --git a/attae.py b/attae.py index 1e5e122..c04c5d3 100755 --- a/attae.py +++ b/attae.py @@ -101,7 +101,6 @@ class AttentionAE(nn.Module): dim_hidden, nb_heads, nb_blocks, - attention=vanilla_attention, dropout=0.0, len_max=1e5, ): @@ -127,7 +126,7 @@ class AttentionAE(nn.Module): dim_qk=dim_keys, dim_v=dim_model // nb_heads, nb_heads=nb_heads, - attention=attention, + attention=vanilla_attention, attention_dropout=dropout, ), ), @@ -163,7 +162,23 @@ class AttentionAE(nn.Module): ###################################################################### -class FunctionalAttentionAE(AttentionAE): +class WithMaskedResidual(nn.Module): + def __init__(self, masker, *f): + super().__init__() + self.f = f[0] if len(f) == 1 else nn.Sequential(*f) + self.masker = masker + self.mask = None + + def forward(self, x): + if self.mask is None: + self.mask = self.masker(x) + return self.mask * x + self.f(x) + + +###################################################################### + + +class FunctionalAttentionAE(nn.Module): def __init__( self, vocabulary_size, @@ -176,6 +191,21 @@ class FunctionalAttentionAE(AttentionAE): dropout=0.0, len_max=1e5, ): + super().__init__() + + assert dim_model % nb_heads == 0 + + self.nb_work_tokens = nb_work_tokens + + self.embedding = nn.Sequential( + nn.Embedding(2 * vocabulary_size, dim_model), + nn.Dropout(dropout), + ) + + self.positional_encoding = VaswaniPositionalEncoding(len_max) + + trunk_blocks = [] + def no_peek_attention(q, k, v): a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3)) n = self.nb_work_tokens @@ -186,23 +216,54 @@ class FunctionalAttentionAE(AttentionAE): y = torch.einsum("nhts,nhsd->nhtd", a, v) return y - AttentionAE.__init__( - self, - vocabulary_size, - dim_model, - dim_keys, - dim_hidden, - nb_heads, - nb_blocks, - attention=no_peek_attention, - dropout=0.0, - len_max=1e5, - ) - self.nb_work_tokens = nb_work_tokens + def masker(x): + m = torch.arange(x.size(1), device=x.device) >= self.nb_work_tokens + return m[None, :, None] + + for b in range(nb_blocks): + trunk_blocks += [ + WithMaskedResidual( + masker, + nn.LayerNorm((dim_model,)), + MHAttention( + dim_model=dim_model, + dim_qk=dim_keys, + dim_v=dim_model // nb_heads, + nb_heads=nb_heads, + attention=no_peek_attention, + attention_dropout=dropout, + ), + ), + WithMaskedResidual( + masker, + nn.LayerNorm((dim_model,)), + nn.Linear(in_features=dim_model, out_features=dim_hidden), + nn.ReLU(), + nn.Linear(in_features=dim_hidden, out_features=dim_model), + nn.Dropout(dropout), + ), + ] + + self.trunk = nn.Sequential(*trunk_blocks) + + self.readout = nn.Linear(in_features=dim_model, out_features=vocabulary_size) + + with torch.no_grad(): + for m in self.modules(): + if isinstance(m, nn.Embedding): + m.weight.normal_(mean=0, std=2e-2) + elif isinstance(m, nn.LayerNorm): + m.bias.zero_() + m.weight.fill_(1.0) def forward(self, x): - x = torch.cat([x.new_zeros(x.size(0), self.nb_work_tokens), x], dim=1) - return AttentionAE.forward(self, x)[:, self.nb_work_tokens :] + x = self.embedding(x) + x = F.pad(x, (0, 0, self.nb_work_tokens, 0)) + x = self.positional_encoding(x) + x = self.trunk(x) + x = F.pad(x, (0, 0, -self.nb_work_tokens, 0)) + x = self.readout(x) + return x ###################################################################### diff --git a/grids.py b/grids.py index 197eb5a..e5890ca 100755 --- a/grids.py +++ b/grids.py @@ -134,16 +134,16 @@ def grow_islands(nb, height, width, nb_seeds, nb_iterations): class Grids(problem.Problem): - grid_gray = 64 - thickness = 1 - background_gray = 255 - dots = False - - # grid_gray=240 - # thickness=1 - # background_gray=240 + # grid_gray = 64 + # thickness = 1 + # background_gray = 255 # dots = False + grid_gray = 240 + thickness = 0 + background_gray = 240 + dots = False + # grid_gray = 192 # thickness = 0 # background_gray = 255 @@ -288,7 +288,7 @@ class Grids(problem.Problem): def vocabulary_size(self): warnings.warn("hack +4 to keep the vocabulary size unchanged", RuntimeWarning) - return self.nb_colors + 4 + return self.nb_colors def grid2img(self, x, scale=15, grids=True): m = torch.logical_and(x >= 0, x < self.nb_colors).long() @@ -369,6 +369,7 @@ class Grids(problem.Problem): grids=True, margin=12, delta=False, + delta_highlight=False, ): quizzes = quizzes.to("cpu") @@ -422,6 +423,10 @@ class Grids(problem.Problem): self.grid2img(f_B, grids=grids), frame[None, :], thickness=thickness ) + if delta_highlight: + q = (img_B == img_f_B).min(dim=1, keepdim=True).values.long() + img_f_B = q * (img_f_B // 4 + 192) + (1 - q) * img_f_B + # predicted_parts Nx4 # correct_parts Nx4 @@ -1847,6 +1852,7 @@ if __name__ == "__main__": "/tmp", t.__name__ + ".png", w_quizzes, + delta=True, # grids=False # comments=[f"{t.__name__} #{k}" for k in range(w_quizzes.size(0))], ) diff --git a/main.py b/main.py index 06dfc5e..10e6bc0 100755 --- a/main.py +++ b/main.py @@ -63,6 +63,8 @@ parser.add_argument("--nb_mistakes_to_be_wrong", type=int, default=5) # ---------------------------------- +parser.add_argument("--model_type", type=str, default="standard") + parser.add_argument("--model", type=str, default="37M") parser.add_argument("--dim_model", type=int, default=None) @@ -843,9 +845,16 @@ log_string(f"vocabulary_size {vocabulary_size}") models = [] +if args.model_type == "standard": + model_constructor = attae.AttentionAE +elif args.model_type == "functional": + model_constructor = attae.FunctionalAttentionAE +else: + raise ValueError(f"Unknown model type {args.model_type}") + + for i in range(args.nb_models): - # model = attae.FunctionalAttentionAE( - model = attae.AttentionAE( + model = model_constructor( vocabulary_size=vocabulary_size * 2, dim_model=args.dim_model, dim_keys=args.dim_keys,