Update.
author François Fleuret <francois@fleuret.org>
Thu, 19 Sep 2024 06:22:26 +0000 (08:22 +0200)
committer François Fleuret <francois@fleuret.org>
Thu, 19 Sep 2024 06:22:26 +0000 (08:22 +0200)
attae.py
main.py

diff --git a/attae.py b/attae.py
index b4db3ab..a9bdeba 100755 (executable)
--- a/attae.py
+++ b/attae.py
@@ -127,6 +127,7 @@ class AttentionAE(nn.Module):
                         dim_qk=dim_keys,
                         dim_v=dim_model // nb_heads,
                         nb_heads=nb_heads,
+                        attention=attention,
                         attention_dropout=dropout,
                     ),
                 ),
@@ -170,28 +171,20 @@ class FunctionalAttentionAE(AttentionAE):
         dim_keys,
         dim_hidden,
         nb_heads,
-        nb_work_tokens,
         nb_blocks,
+        nb_work_tokens=100,
         dropout=0.0,
         len_max=1e5,
     ):
-        # def functional_mask(b, h, q_idx, kv_idx):
-        # return (
-        # (q_idx < nb_work_tokens)
-        # | (kv_idx < nb_work_tokens)
-        # | ((q_idx - nb_work_tokens) // 200 == (kv_idx - nb_work_tokens) // 200)
-        # )
-
-        # block_mask = create_block_mask(
-        # functional_mask,
-        # B=None,
-        # H=None,
-        # Q_LEN=400 + nb_work_tokens,
-        # KV_LEN=400 + nb_work_tokens,
-        # )
-
-        # def functional_attention(q, k, v):
-        # return flex_attention(q, k, v, block_mask=block_mask)
+        def no_peek_attention(q, k, v):
+            a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3))
+            n = self.nb_work_tokens
+            s = (q.size(2) - n) // 2
+            a[:, :, n + 0 * s : n + 1 * s, n + 0 * s : n + 1 * s] = float("-inf")
+            a[:, :, n + 1 * s : n + 2 * s, n + 1 * s : n + 2 * s] = float("-inf")
+            a = a.softmax(dim=3)
+            y = torch.einsum("nhts,nhsd->nhtd", a, v)
+            return y
 
         AttentionAE.__init__(
             self,
@@ -201,6 +194,7 @@ class FunctionalAttentionAE(AttentionAE):
             dim_hidden,
             nb_heads,
             nb_blocks,
+            attention=no_peek_attention,
             dropout=0.0,
             len_max=1e5,
         )
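Note on the new no_peek_attention in attae.py: it replaces the commented-out flex_attention/block-mask variant with an explicit dense computation in which, after the nb_work_tokens work tokens, each of the two equal halves of the sequence is barred from attending within itself, so positions of the same half can only exchange information through the work tokens or the other half. A minimal standalone sketch of that masking pattern, with illustrative shapes and names rather than the repository's API:

    import math
    import torch

    def no_peek_attention_sketch(q, k, v, nb_work_tokens):
        # q, k, v: (N, H, T, D) with T = nb_work_tokens + 2 * s
        a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3))
        n = nb_work_tokens
        s = (q.size(2) - n) // 2
        # Each non-work half cannot attend inside itself.
        a[:, :, n : n + s, n : n + s] = float("-inf")
        a[:, :, n + s : n + 2 * s, n + s : n + 2 * s] = float("-inf")
        a = a.softmax(dim=3)
        return torch.einsum("nhts,nhsd->nhtd", a, v)

    x = torch.randn(2, 4, 100 + 2 * 50, 16)
    y = no_peek_attention_sketch(x, x, x, nb_work_tokens=100)
    print(y.shape)  # torch.Size([2, 4, 200, 16])
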
diff --git a/main.py b/main.py
index 16edcdc..d903693 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -198,6 +198,8 @@ if args.seed >= 0:
 
 
 def log_string(s):
+    """print the given string prefixed with a time stamps, and log it into log_file is not None"""
+
     t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime())
 
     if log_file is not None:
@@ -352,13 +354,10 @@ def add_noise_imt(imt_set):
 
 
 ######################################################################
-
-# IMT for input / masks / target
-
-# Generate a batch for prediction
+# Prediction
 
 
-def batch_for_prediction_imt(input):
+def samples_for_prediction_imt(input):
     nb = input.size(0)
     masks = input.new_zeros(input.size())
     u = F.one_hot(torch.randint(4, (nb,), device=masks.device), num_classes=4)
@@ -421,7 +420,7 @@ def predict_full(model, input, with_perturbations=False, local_device=main_devic
 ######################################################################
 
 
-def batch_for_generation_imt(input):
+def samples_for_generation_imt(input):
     nb = input.size(0)
     probs_iterations = 0.1 ** torch.linspace(
         0, 1, args.diffusion_nb_iterations, device=input.device
@@ -511,13 +510,13 @@ def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True):
 
     # Half of the samples train the prediction, and we inject noise in
     # all, and hints in half
-    b_p = batch_for_prediction_imt(q_p)
+    b_p = samples_for_prediction_imt(q_p)
     b_p = add_noise_imt(b_p)
     half = torch.rand(b_p.size(0)) < 0.5
     b_p[half] = add_hints_imt(b_p[half])
 
     # The other half are denoising examples for the generation
-    b_g = batch_for_generation_imt(q_g)
+    b_g = samples_for_generation_imt(q_g)
 
     imt_set = torch.cat([b_p, b_g])
     imt_set = imt_set[torch.randperm(imt_set.size(0), device=imt_set.device)]
@@ -590,7 +589,7 @@ def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device):
     quizzes = quiz_machine.quiz_set(
         args.nb_test_samples, c_quizzes, args.c_quiz_multiplier
     )
-    imt_set = batch_for_prediction_imt(quizzes.to(local_device))
+    imt_set = samples_for_prediction_imt(quizzes.to(local_device))
     result = ae_predict(model, imt_set, local_device=local_device).to("cpu")
     masks = imt_set[:, 1].to("cpu")
 
@@ -634,18 +633,8 @@ import attae
 models = []
 
 for i in range(args.nb_models):
-    # model = attae.FunctionalAttentionAE(
-    # vocabulary_size=vocabulary_size * 2,
-    # dim_model=args.dim_model,
-    # dim_keys=args.dim_keys,
-    # dim_hidden=args.dim_hidden,
-    # nb_heads=args.nb_heads,
-    # nb_blocks=args.nb_blocks,
-    # nb_work_tokens=10,
-    # dropout=args.dropout,
-    # )
-
-    model = attae.AttentionAE(
+    model = attae.FunctionalAttentionAE(
+        # model = attae.AttentionAE(
         vocabulary_size=vocabulary_size * 2,
         dim_model=args.dim_model,
         dim_keys=args.dim_keys,
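
The main.py change swaps the instantiated class from attae.AttentionAE to attae.FunctionalAttentionAE, which relies on the attention= hook threaded through the attention blocks in the first attae.py hunk. A minimal sketch of that kind of pluggable attention block, with hypothetical names and projections (the repository's actual attention module is not shown in this diff, and dropout is omitted):

    import math
    import torch
    from torch import nn

    def default_attention(q, k, v):
        # Plain scaled dot-product attention over (N, H, T, D) tensors.
        a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3))
        return torch.einsum("nhts,nhsd->nhtd", a.softmax(dim=3), v)

    class PluggableAttentionBlock(nn.Module):
        # The attention computation itself is a constructor argument, so a
        # subclass such as FunctionalAttentionAE can substitute something like
        # no_peek_attention without touching the linear projections.
        def __init__(self, dim_model, dim_qk, dim_v, nb_heads, attention=default_attention):
            super().__init__()
            self.attention = attention
            self.nb_heads = nb_heads
            self.w_q = nn.Linear(dim_model, nb_heads * dim_qk, bias=False)
            self.w_k = nn.Linear(dim_model, nb_heads * dim_qk, bias=False)
            self.w_v = nn.Linear(dim_model, nb_heads * dim_v, bias=False)
            self.w_o = nn.Linear(nb_heads * dim_v, dim_model, bias=False)

        def forward(self, x):
            n, t, _ = x.size()
            q = self.w_q(x).view(n, t, self.nb_heads, -1).transpose(1, 2)
            k = self.w_k(x).view(n, t, self.nb_heads, -1).transpose(1, 2)
            v = self.w_v(x).view(n, t, self.nb_heads, -1).transpose(1, 2)
            y = self.attention(q, k, v)
            return self.w_o(y.transpose(1, 2).reshape(n, t, -1))

    block = PluggableAttentionBlock(dim_model=64, dim_qk=16, dim_v=16, nb_heads=4)
    print(block(torch.randn(2, 10, 64)).shape)  # torch.Size([2, 10, 64])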