Update.
author François Fleuret <francois@fleuret.org>
Sat, 14 Sep 2024 07:57:14 +0000 (09:57 +0200)
committer François Fleuret <francois@fleuret.org>
Sat, 14 Sep 2024 07:57:14 +0000 (09:57 +0200)
attae.py
grids.py
main.py

diff --git a/attae.py b/attae.py
index 05084ba..bb2d87f 100755 (executable)
--- a/attae.py
+++ b/attae.py
@@ -106,7 +106,8 @@ class AttentionAE(nn.Module):
         assert dim_model % nb_heads == 0
 
         self.embedding = nn.Sequential(
-            nn.Embedding(2 * vocabulary_size, dim_model), nn.Dropout(dropout)
+            nn.Embedding(2 * vocabulary_size, dim_model),
+            nn.Dropout(dropout),
         )
 
         self.positional_encoding = VaswaniPositionalEncoding(len_max)
@@ -157,6 +158,38 @@ class AttentionAE(nn.Module):
 ######################################################################
 
 
+class MaskedAttentionAE(nn.Module):
+    def __init__(
+        self,
+        vocabulary_size,
+        dim_model,
+        dim_keys,
+        dim_hidden,
+        nb_heads,
+        nb_blocks,
+        dropout=0.0,
+        len_max=1e5,
+    ):
+        super().__init__()
+        self.core = AttentionAE(
+            vocabulary_size * 2,
+            dim_model,
+            dim_keys,
+            dim_hidden,
+            nb_heads,
+            nb_blocks,
+            dropout=0.0,
+            len_max=1e5,
+        )
+
+    def forward(self, x):
+        x = x[:, :, 0] * 2 + x[:, :, 1]
+        return self.core(x)
+
+
+######################################################################
+
+
 if __name__ == "__main__":
     model = AttentionAE(
         vocabulary_size=100,
diff --git a/grids.py b/grids.py
index 054ba35..7754c43 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -406,6 +406,7 @@ class Grids(problem.Problem):
         comments=None,
         comment_height=48,
         nrow=4,
+        grids=True,
         margin=8,
         delta=False,
     ):
diff --git a/main.py b/main.py
index fb8f8cf..dede204 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -25,6 +25,8 @@ import threading, subprocess
 
 import torch.multiprocessing as mp
 
+torch.set_float32_matmul_precision("high")
+
 ######################################################################
 
 parser = argparse.ArgumentParser(
@@ -494,7 +496,6 @@ def ae_batches(
     local_device,
     c_quizzes=None,
     alien_quiz_machine=None,
-    nb_aliens=None,
     desc=None,
     batch_size=args.batch_size,
 ):
@@ -895,8 +896,8 @@ def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_devi
         args.nb_train_samples,
         data_structures,
         local_device,
-        c_quizzes,
-        "training",
+        c_quizzes=c_quizzes,
+        desc="training",
     ):
         x_0 = x_0.to(local_device)
         mask_generate = mask_generate.to(local_device)
@@ -938,13 +939,13 @@ def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_devi
 
 ######################################################################
 
-import attae
+import attae
 
 models = []
 
 for i in range(args.nb_models):
-    model = MyAttentionAE(
-        # model = attae.AttentionAE(
+    model = MyAttentionAE(
+    model = attae.MaskedAttentionAE(
         vocabulary_size=vocabulary_size,
         dim_model=args.dim_model,
         dim_keys=args.dim_keys,
@@ -1307,6 +1308,7 @@ def multithread_execution(fun, arguments):
 def save_models(models, suffix=""):
     if suffix != "":
         suffix = "_" + suffix
+
     for model in models:
         filename = f"ae_{model.id:03d}{suffix}.pth"
         torch.save(
@@ -1392,16 +1394,12 @@ for n_epoch in range(current_epoch, args.nb_epochs):
 
     start_time = time.perf_counter()
 
+    # None if c_quizzes is None else c_quizzes[agreements[:, model.id]],
+
     multithread_execution(
         one_ae_epoch,
         [
-            (
-                model,
-                quiz_machine,
-                n_epoch,
-                None if c_quizzes is None else c_quizzes[agreements[:, model.id]],
-                gpu,
-            )
+            (model, quiz_machine, n_epoch, c_quizzes, gpu)
             for model, gpu in zip(weakest_models, gpus)
         ],
     )