From 5241023b1f8379bc95cecb922be3d7a76165da9e Mon Sep 17 00:00:00 2001
From: François Fleuret
Date: Sat, 14 Sep 2024 09:57:14 +0200
Subject: [PATCH] Update.

---
 attae.py | 35 ++++++++++++++++++++++++++++++++++-
 grids.py |  1 +
 main.py  | 24 +++++++++++-------------
 3 files changed, 46 insertions(+), 14 deletions(-)

diff --git a/attae.py b/attae.py
index 05084ba..bb2d87f 100755
--- a/attae.py
+++ b/attae.py
@@ -106,7 +106,8 @@ class AttentionAE(nn.Module):
         assert dim_model % nb_heads == 0
 
         self.embedding = nn.Sequential(
-            nn.Embedding(2 * vocabulary_size, dim_model), nn.Dropout(dropout)
+            nn.Embedding(2 * vocabulary_size, dim_model),
+            nn.Dropout(dropout),
         )
 
         self.positional_encoding = VaswaniPositionalEncoding(len_max)
@@ -157,6 +158,38 @@ class AttentionAE(nn.Module):
 ######################################################################
 
 
+class MaskedAttentionAE(nn.Module):
+    def __init__(
+        self,
+        vocabulary_size,
+        dim_model,
+        dim_keys,
+        dim_hidden,
+        nb_heads,
+        nb_blocks,
+        dropout=0.0,
+        len_max=1e5,
+    ):
+        super().__init__()
+        self.core = AttentionAE(
+            vocabulary_size * 2,
+            dim_model,
+            dim_keys,
+            dim_hidden,
+            nb_heads,
+            nb_blocks,
+            dropout=0.0,
+            len_max=1e5,
+        )
+
+    def forward(self, x):
+        x = x[:, :, 0] * 2 + x[:, :, 1]
+        return self.core(x)
+
+
+######################################################################
+
+
 if __name__ == "__main__":
     model = AttentionAE(
         vocabulary_size=100,
diff --git a/grids.py b/grids.py
index 054ba35..7754c43 100755
--- a/grids.py
+++ b/grids.py
@@ -406,6 +406,7 @@ class Grids(problem.Problem):
         comments=None,
         comment_height=48,
         nrow=4,
+        grids=True,
         margin=8,
         delta=False,
     ):
diff --git a/main.py b/main.py
index fb8f8cf..dede204 100755
--- a/main.py
+++ b/main.py
@@ -25,6 +25,8 @@ import threading, subprocess
 
 import torch.multiprocessing as mp
 
+torch.set_float32_matmul_precision("high")
+
 ######################################################################
 
 parser = argparse.ArgumentParser(
@@ -494,7 +496,6 @@ def ae_batches(
     local_device,
     c_quizzes=None,
     alien_quiz_machine=None,
-    nb_aliens=None,
     desc=None,
     batch_size=args.batch_size,
 ):
@@ -895,8 +896,8 @@ def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_devi
         args.nb_train_samples,
         data_structures,
         local_device,
-        c_quizzes,
-        "training",
+        c_quizzes=c_quizzes,
+        desc="training",
     ):
         x_0 = x_0.to(local_device)
         mask_generate = mask_generate.to(local_device)
@@ -938,13 +939,13 @@ def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_devi
 ######################################################################
 
 
-# import attae
+import attae
 
 models = []
 
 for i in range(args.nb_models):
-    model = MyAttentionAE(
-        # model = attae.AttentionAE(
+    # model = MyAttentionAE(
+    model = attae.MaskedAttentionAE(
         vocabulary_size=vocabulary_size,
         dim_model=args.dim_model,
         dim_keys=args.dim_keys,
@@ -1307,6 +1308,7 @@ def multithread_execution(fun, arguments):
 def save_models(models, suffix=""):
     if suffix != "":
         suffix = "_" + suffix
+
     for model in models:
         filename = f"ae_{model.id:03d}{suffix}.pth"
         torch.save(
@@ -1392,16 +1394,12 @@ for n_epoch in range(current_epoch, args.nb_epochs):
 
     start_time = time.perf_counter()
 
+    # None if c_quizzes is None else c_quizzes[agreements[:, model.id]],
+
     multithread_execution(
         one_ae_epoch,
         [
-            (
-                model,
-                quiz_machine,
-                n_epoch,
-                None if c_quizzes is None else c_quizzes[agreements[:, model.id]],
-                gpu,
-            )
+            (model, quiz_machine, n_epoch, c_quizzes, gpu)
             for model, gpu in zip(weakest_models, gpus)
         ],
    )
-- 
2.39.5
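
For reference, a minimal usage sketch of the MaskedAttentionAE class introduced
above. This is an illustration only, not code from the repository: the reading
of the second input channel as a binary generation mask and all hyper-parameter
values below are assumptions.

    import torch

    from attae import MaskedAttentionAE

    model = MaskedAttentionAE(
        vocabulary_size=100,  # assumed, matching the __main__ test in attae.py
        dim_model=16,
        dim_keys=64,
        dim_hidden=32,
        nb_heads=4,
        nb_blocks=4,
    )

    tokens = torch.randint(100, (10, 50))    # token channel, values in [0, 100)
    masks = torch.randint(2, (10, 50))       # assumed binary mask channel
    x = torch.stack([tokens, masks], dim=2)  # input of shape (10, 50, 2)

    # forward() fuses the two channels into a single index,
    # x[:, :, 0] * 2 + x[:, :, 1], which is why the inner
    # AttentionAE is built with vocabulary_size * 2.
    y = model(x)

Packing the (token, mask) pair into one index lets the unmodified AttentionAE
core consume masked inputs without any architectural change.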