Update.
authorFrançois Fleuret <francois@fleuret.org>
Sun, 8 Sep 2024 20:51:44 +0000 (22:51 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sun, 8 Sep 2024 20:51:44 +0000 (22:51 +0200)
attae.py
main.py

index e9e4bff..069772b 100755 (executable)
--- a/attae.py
+++ b/attae.py
@@ -45,6 +45,21 @@ class WithResidual(nn.Module):
 ######################################################################
 
 
+class vanilla_attention(q, k, v):
+    a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3))
+    a = a.softmax(dim=3)
+    y = torch.einsum("nhts,nhsd->nhtd", a, v)
+
+    # y = flex_attention(q, k, v, score_mod=noop)
+
+    y = torch.einsum("nhtd,hdc->ntc", y, self.w_o)
+
+    return y
+
+
+vanilla_attention = torch.compille(vanilla_attention)
+
+
 class MHAttention(nn.Module):
     def __init__(
         self,
@@ -72,10 +87,14 @@ class MHAttention(nn.Module):
             x_kv = x_q
 
         q = torch.einsum("ntc,hdc->nhtd", x_q, self.w_q)
-        k = torch.einsum("ntc,hdc->nhtd", x_kv, self.w_k)
-        v = torch.einsum("ntc,hdc->nhtd", x_kv, self.w_v)
+        k = torch.einsum("nsc,hdc->nhsd", x_kv, self.w_k)
+        v = torch.einsum("nsc,hdc->nhsd", x_kv, self.w_v)
+
+        def noop(score, b, h, q_idx, kv_idx):
+            return score
 
-        y = flex_attention(q, k, v)
+        y = vanilla_attention(q, k, v, score_mod=noop)
+        # y = flex_attention(q, k, v, score_mod=noop)
 
         y = torch.einsum("nhtd,hdc->ntc", y, self.w_o)
 
diff --git a/main.py b/main.py
index 9285337..301c4f8 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -16,8 +16,6 @@ from torch.nn import functional as F
 
 import ffutils
 
-import attae
-
 import mygpt
 import sky, grids, quiz_machine
 
@@ -775,7 +773,7 @@ def ae_generate(model, x_0, mask_generate, nb_iterations_max=50):
 ######################################################################
 
 
-def model_ae_proba_solutions(model, input, log_proba=False):
+def model_ae_proba_solutions(model, input, log_probas=False, reduce=True):
     record = []
 
     for x_0 in input.split(args.batch_size):
@@ -791,12 +789,16 @@ def model_ae_proba_solutions(model, input, log_proba=False):
             loss_per_token = F.cross_entropy(
                 logits.transpose(1, 2), x_0, reduction="none"
             )
-            loss += (loss_per_token * mask_generate).sum(dim=1)
+            if reduce:
+                loss += (loss_per_token * mask_generate).sum(dim=1)
+            else:
+                loss += loss_per_token * mask_generate
+
         record.append(loss)
 
     loss = torch.cat(record, dim=0)
 
-    if log_proba:
+    if log_probas:
         return -loss
     else:
         return (-loss).exp()
@@ -811,6 +813,7 @@ def model_ae_argmax_nb_mistakes(model, input):
             mask_generate = quiz_machine.make_quiz_mask(
                 quizzes=x_0, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
             )
+
             logits = logits_hat_x_0_from_random_iteration(
                 model, x_0, mask_generate, prompt_noise=args.prompt_noise
             )
@@ -996,11 +999,13 @@ def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_devi
 
 ######################################################################
 
+# import attae
+
 models = []
 
 for i in range(args.nb_models):
-    model = MyAttentionAE(
-    model = attae.AttentionAE(
+    model = MyAttentionAE(
+        # model = attae.AttentionAE(
         vocabulary_size=vocabulary_size,
         dim_model=args.dim_model,
         dim_keys=args.dim_keys,
@@ -1055,11 +1060,11 @@ def save_badness_statistics(
 ######################################################################
 
 
-def quiz_validation(models, c_quizzes, local_device):
+def quiz_validation_1(models, c_quizzes, local_device):
     nb_have_to_be_correct = args.nb_models // 2
-    nb_have_to_be_wrong = args.nb_models // 5
+    nb_have_to_be_wrong = 1
 
-    nb_runs = 3
+    nb_runs = 1
     nb_mistakes_to_be_wrong = 5
 
     record_wrong = []
@@ -1088,6 +1093,78 @@ def quiz_validation(models, c_quizzes, local_device):
     return to_keep, wrong
 
 
+def quiz_validation_2(models, c_quizzes, local_device):
+    nb_have_to_be_correct = 3
+    nb_have_to_be_wrong = 1
+    nb_runs = 3
+
+    record_wrong = []
+    nb_correct, nb_wrong = 0, 0
+
+    for i, model in enumerate(models):
+        assert i == model.id  # a bit of paranoia
+        model = copy.deepcopy(model).to(local_device).eval()
+        log_probas_max, log_probas_min = None, None
+        for _ in range(nb_runs):
+            log_probas = model_ae_proba_solutions(
+                model, c_quizzes, log_probas=True, reduce=False
+            )
+            log_probas_max = (
+                log_probas
+                if log_probas_max is None
+                else log_probas.maximum(log_probas_max)
+            )
+            log_probas_min = (
+                log_probas
+                if log_probas_min is None
+                else log_probas.minimum(log_probas_min)
+            )
+        probas = log_probas.sum(dim=1).exp()
+        correct = (log_probas_min.exp() <= 0.75).long().sum(dim=1) == 0
+        wrong = (log_probas_min.exp() <= 0.1).long().sum(dim=1) >= 3
+        record_wrong.append(wrong[:, None])
+        nb_correct += correct.long()
+        nb_wrong += wrong.long()
+
+    to_keep = (nb_correct >= nb_have_to_be_correct) & (nb_wrong >= nb_have_to_be_wrong)
+
+    wrong = torch.cat(record_wrong, dim=1)
+
+    return to_keep, wrong
+
+
+def quiz_validation(models, c_quizzes, local_device):
+    nb_have_to_be_correct = 3
+    nb_have_to_be_wrong = 1
+    nb_runs = 3
+
+    record_wrong = []
+    nb_correct, nb_wrong = 0, 0
+
+    for i, model in enumerate(models):
+        assert i == model.id  # a bit of paranoia
+        model = copy.deepcopy(model).to(local_device).eval()
+        log_probas = 0
+        for _ in range(nb_runs):
+            log_probas += model_ae_proba_solutions(
+                model, c_quizzes, log_probas=True, reduce=False
+            )
+        probas = log_probas.exp()
+        correct = (probas <= 0.75).long().sum(dim=1) == 0
+        wrong = ((probas <= 0.125).long().sum(dim=1) >= 5) & (
+            log_probas.sum(dim=1).div(nb_runs).exp() <= 0.5
+        )
+        record_wrong.append(wrong[:, None])
+        nb_correct += correct.long()
+        nb_wrong += wrong.long()
+
+    to_keep = (nb_correct >= nb_have_to_be_correct) & (nb_wrong >= nb_have_to_be_wrong)
+
+    wrong = torch.cat(record_wrong, dim=1)
+
+    return to_keep, wrong
+
+
 def generate_ae_c_quizzes(models, nb, local_device=main_device):
     # To be thread-safe we must make copies
 
@@ -1305,20 +1382,28 @@ for n_epoch in range(current_epoch, args.nb_epochs):
 
         start_time = time.perf_counter()
 
-        for gpu in gpus:
-            t = threading.Thread(
-                target=thread_generate_ae_c_quizzes,
-                daemon=True,
-                args=(models, nb_c_quizzes_to_generate, records, gpu),
-            )
+        if len(gpus) > 1:
+            for gpu in gpus:
+                t = threading.Thread(
+                    target=thread_generate_ae_c_quizzes,
+                    daemon=True,
+                    args=(models, nb_c_quizzes_to_generate, records, gpu),
+                )
 
-            # To get a different sequence between threads
-            log_string(f"dummy {torch.rand(1)}")
-            threads.append(t)
-            t.start()
+                # To get a different sequence between threads
+                log_string(f"dummy {torch.rand(1)}")
+                threads.append(t)
+                t.start()
 
-        for t in threads:
-            t.join()
+            for t in threads:
+                t.join()
+
+        else:
+            records.append(
+                generate_ae_c_quizzes(
+                    models, nb_c_quizzes_to_generate, records, gpus[0]
+                )
+            )
 
         time_c_quizzes = int(time.perf_counter() - start_time)
 
@@ -1350,25 +1435,36 @@ for n_epoch in range(current_epoch, args.nb_epochs):
 
     start_time = time.perf_counter()
 
-    for gpu, model in zip(gpus, weakest_models):
+    if len(gpus) > 1:
+        for gpu, model in zip(gpus, weakest_models):
+            log_string(f"training model {model.id} (accuracy {model.test_accuracy})")
+            if c_quizzes is None:
+                c_quizzes_for_this_model = None
+            else:
+                c_quizzes_for_this_model = c_quizzes[agreements[:, model.id]]
+
+            t = threading.Thread(
+                target=one_ae_epoch,
+                daemon=True,
+                args=(model, quiz_machine, n_epoch, c_quizzes_for_this_model, gpu),
+            )
+
+            threads.append(t)
+
+            t.start()
+
+        for t in threads:
+            t.join()
+
+    else:
+        model = weakest_models[0]
         log_string(f"training model {model.id} (accuracy {model.test_accuracy})")
         if c_quizzes is None:
             c_quizzes_for_this_model = None
         else:
             c_quizzes_for_this_model = c_quizzes[agreements[:, model.id]]
 
-        t = threading.Thread(
-            target=one_ae_epoch,
-            daemon=True,
-            args=(model, quiz_machine, n_epoch, c_quizzes_for_this_model, gpu),
-        )
-
-        threads.append(t)
-
-        t.start()
-
-    for t in threads:
-        t.join()
+        one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes_for_this_model, gpus[0])
 
     time_train += int(time.perf_counter() - start_time)