Update.
authorFrançois Fleuret <francois@fleuret.org>
Fri, 13 Sep 2024 09:36:26 +0000 (11:36 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Fri, 13 Sep 2024 09:36:26 +0000 (11:36 +0200)
attae.py
main.py

index e201f60..05084ba 100755 (executable)
--- a/attae.py
+++ b/attae.py
@@ -6,7 +6,8 @@ import torch
 
 from torch import nn
 from torch.nn import functional as F
-from torch.nn.attention.flex_attention import flex_attention
+
+# from torch.nn.attention.flex_attention import flex_attention
 
 ######################################################################
 
@@ -105,8 +106,7 @@ class AttentionAE(nn.Module):
         assert dim_model % nb_heads == 0
 
         self.embedding = nn.Sequential(
-            nn.Embedding(2 * vocabulary_size, dim_model),
-            nn.Dropout(dropout),
+            nn.Embedding(2 * vocabulary_size, dim_model), nn.Dropout(dropout)
         )
 
         self.positional_encoding = VaswaniPositionalEncoding(len_max)
diff --git a/main.py b/main.py
index e090f86..92a34f1 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -729,7 +729,7 @@ def logits_hat_x_0_from_random_iteration(model, x_0, mask_generate, prompt_noise
 def ae_generate(model, x_0, mask_generate, nb_iterations_max=50, mask_hints=None):
     noise = torch.randint(quiz_machine.problem.nb_colors, x_0.size(), device=x_0.device)
 
-    one_iteration_prediction = deterministic(mask_generate)[:, None]
+    single_iteration = deterministic(mask_generate)[:, None]
 
     if mask_hints is not None:
         mask_generate = mask_generate * (1 - mask_hints)
@@ -746,12 +746,11 @@ def ae_generate(model, x_0, mask_generate, nb_iterations_max=50, mask_hints=None
 
         hat_x_0 = (1 - mask_generate) * x_0 + mask_generate * dist.sample()
 
-        hat_x_t_minus_1 = one_iteration_prediction * hat_x_0 + (
-            1 - one_iteration_prediction
+        hat_x_t_minus_1 = single_iteration * hat_x_0 + (
+            1 - single_iteration
         ) * sample_x_t_minus_1_given_x_0_x_t(hat_x_0, x_t)
 
         if hat_x_t_minus_1.equal(x_t):
-            # log_string(f"exit after {it+1} iterations")
             break
         else:
             changed = changed & (hat_x_t_minus_1 != x_t).max(dim=1).values
@@ -794,54 +793,6 @@ def model_ae_proba_solutions(model, input, log_probas=False, reduce=True):
         return (-loss).exp()
 
 
-def model_ae_argmax_nb_mistakes(model, input):
-    record = []
-
-    for x_0 in input.split(args.batch_size):
-        nb_mistakes = 0
-        for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
-            mask_generate = quiz_machine.make_quiz_mask(
-                quizzes=x_0, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
-            )
-
-            logits = logits_hat_x_0_from_random_iteration(
-                model, x_0, mask_generate, prompt_noise=args.prompt_noise
-            )
-
-            predicted = logits.argmax(dim=-1)
-
-            nb_mistakes = nb_mistakes + (
-                mask_generate * predicted != mask_generate * x_0
-            ).long().sum(dim=1)
-
-        record.append(nb_mistakes)
-
-    return torch.cat(record, dim=0)
-
-
-######################################################################
-
-
-def model_ae_argmax_predictions(model, input):
-    result = input.clone()
-    # result[...] = 0
-
-    for r, x_0 in zip(result.split(args.batch_size), input.split(args.batch_size)):
-        for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
-            mask_generate = quiz_machine.make_quiz_mask(
-                quizzes=x_0, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
-            )
-            logits = logits_hat_x_0_from_random_iteration(
-                model, x_0, mask_generate, prompt_noise=args.prompt_noise
-            )
-
-            hat_x_0 = logits.argmax(dim=-1)
-
-            r[...] = (1 - mask_generate) * r + mask_generate * hat_x_0
-
-    return result
-
-
 ######################################################################
 
 
@@ -1013,39 +964,6 @@ for i in range(args.nb_models):
 ######################################################################
 
 
-def save_badness_statistics(
-    n_epoch, models, c_quizzes, suffix=None, local_device=main_device
-):
-    for model in models:
-        model.eval().to(local_device)
-    c_quizzes = c_quizzes.to(local_device)
-    with torch.autograd.no_grad():
-        log_probas = sum(
-            [model_ae_proba_solutions(model, c_quizzes) for model in models]
-        )
-        i = log_probas.sort().indices
-
-    suffix = "" if suffix is None else "_" + suffix
-
-    filename = f"culture_badness_{n_epoch:04d}{suffix}.png"
-
-    quiz_machine.problem.save_quizzes_as_image(
-        args.result_dir,
-        filename,
-        quizzes=c_quizzes[i[:128]],
-        # predicted_parts=predicted_parts,
-        # correct_parts=correct_parts,
-        # comments=comments,
-        delta=True,
-        nrow=8,
-    )
-
-    log_string(f"wrote {filename}")
-
-
-######################################################################
-
-
 def quiz_validation(
     models,
     c_quizzes,