Update.
authorFrançois Fleuret <francois@fleuret.org>
Tue, 20 Aug 2024 20:37:04 +0000 (22:37 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 20 Aug 2024 20:37:04 +0000 (22:37 +0200)
grids.py
main.py
mygpt.py
quiz_machine.py

index 0564f3b..b12b4d6 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -167,11 +167,11 @@ class Grids(problem.Problem):
         self.check_structure(quizzes, struct)
         return struct
 
-    def inject_noise(self, quizzes, noise, struct, mask):
+    def inject_noise(self, quizzes, noise, struct, quad):
         assert self.check_structure(quizzes, struct=struct)
         S = self.height * self.width
 
-        mask = torch.tensor(mask, device=quizzes.device)
+        mask = torch.tensor(quad, device=quizzes.device)
         mask = mask[None, :, None].expand(1, 4, S + 1).clone()
         mask[:, :, 0] = 0
         mask = mask.reshape(1, -1).expand_as(quizzes)
@@ -219,20 +219,20 @@ class Grids(problem.Problem):
         ).values
 
     def make_quiz_mask(
-        self, quizzes, struct=("A", "f_A", "B", "f_B"), mask=(0, 0, 0, 1)
+        self, quizzes, struct=("A", "f_A", "B", "f_B"), quad=(0, 0, 0, 1)
     ):
         assert self.check_structure(quizzes, struct)
 
-        ar_mask = quizzes.new_zeros(quizzes.size())
+        mask_ar = quizzes.new_zeros(quizzes.size())
 
         S = self.height * self.width
-        a = ar_mask.reshape(ar_mask.size(0), 4, S + 1)[:, :, 1:]
-        a[:, 0, :] = mask[0]
-        a[:, 1, :] = mask[1]
-        a[:, 2, :] = mask[2]
-        a[:, 3, :] = mask[3]
+        a = mask_ar.reshape(mask_ar.size(0), 4, S + 1)[:, :, 1:]
+        a[:, 0, :] = quad[0]
+        a[:, 1, :] = quad[1]
+        a[:, 2, :] = quad[2]
+        a[:, 3, :] = quad[3]
 
-        return ar_mask
+        return mask_ar
 
     def indices_select(self, quizzes, struct=("A", "f_A", "B", "f_B")):
         S = self.height * self.width
diff --git a/main.py b/main.py
index 19c8394..8908613 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -214,6 +214,7 @@ if args.seed >= 0:
     torch.manual_seed(args.seed)
     if torch.cuda.is_available():
         torch.cuda.manual_seed_all(args.seed)
+        torch.set_float32_matmul_precision("high")
 
 ######################################################################
 
@@ -326,6 +327,14 @@ def optimizer_to(optim, device):
 ######################################################################
 
 
+def mask_ar_to_ranks(mask_ar):
+    a = (mask_ar < 2).long()
+    a = a.cumsum(dim=1) - a
+    b = ((mask_ar[:, :-1] == 2) & (mask_ar[:, 1:] != 2)).long().cumsum(dim=1)
+    a[:, 1:] += b
+    return a
+
+
 def run_tests(model, quiz_machine, local_device=main_device):
     with torch.autograd.no_grad():
         model.to(local_device).eval()
@@ -335,25 +344,30 @@ def run_tests(model, quiz_machine, local_device=main_device):
         nb_test_samples, acc_test_loss = 0, 0.0
         nb_samples_accumulated = 0
 
-        full_input, full_mask_loss = quiz_machine.data_input(
+        full_input, full_mask_ar, full_mask_loss = quiz_machine.data_input(
             args.nb_test_samples, test_c_quiz_bags
         )
 
         src = zip(
-            full_input.split(args.batch_size), full_mask_loss.split(args.batch_size)
+            full_input.split(args.batch_size),
+            full_mask_ar.split(args.batch_size),
+            full_mask_loss.split(args.batch_size),
         )
 
-        for input, mask_loss in tqdm.tqdm(
+        for input, mask_ar, mask_loss in tqdm.tqdm(
             src,
             dynamic_ncols=True,
             desc="test",
             total=full_input.size(0) // args.batch_size,
         ):
             input = input.to(local_device)
+            mask_ar = mask_ar.to(local_device)
             mask_loss = mask_loss.to(local_device)
             targets = input
 
-            output = model(mygpt.BracketedSequence(input)).x
+            output = model(
+                mygpt.BracketedSequence(input, ranks=mask_ar_to_ranks(mask_ar))
+            ).x
             loss_per_token = F.cross_entropy(
                 output.transpose(1, 2), targets, reduction="none"
             )
@@ -365,7 +379,7 @@ def run_tests(model, quiz_machine, local_device=main_device):
 
         log_string(f"test_perplexity {n_epoch} model {model.id} {test_perplexity}")
 
-        input, _ = quiz_machine.data_input(1000, test_c_quiz_bags)
+        input, _, _ = quiz_machine.data_input(1000, test_c_quiz_bags)
 
         model.test_accuracy = quiz_machine.produce_results(
             n_epoch=n_epoch,
@@ -387,18 +401,24 @@ def one_epoch(model, quiz_machine, local_device=main_device):
 
     nb_train_samples, acc_train_loss = 0, 0.0
 
-    full_input, full_mask_loss = quiz_machine.data_input(
+    full_input, full_mask_ar, full_mask_loss = quiz_machine.data_input(
         args.nb_train_samples, train_c_quiz_bags
     )
-    src = zip(full_input.split(args.batch_size), full_mask_loss.split(args.batch_size))
 
-    for input, mask_loss in tqdm.tqdm(
+    src = zip(
+        full_input.split(args.batch_size),
+        full_mask_ar.split(args.batch_size),
+        full_mask_loss.split(args.batch_size),
+    )
+
+    for input, mask_ar, mask_loss in tqdm.tqdm(
         src,
         dynamic_ncols=True,
         desc="training",
         total=full_input.size(0) // args.batch_size,
     ):
         input = input.to(local_device)
+        mask_ar = mask_ar.to(local_device)
         mask_loss = mask_loss.to(local_device)
 
         if nb_train_samples % args.batch_size == 0:
@@ -406,7 +426,9 @@ def one_epoch(model, quiz_machine, local_device=main_device):
 
         targets = input
 
-        output = model(mygpt.BracketedSequence(input)).x
+        output = model(
+            mygpt.BracketedSequence(input, ranks=mask_ar_to_ranks(mask_ar))
+        ).x
 
         loss_per_token = F.cross_entropy(
             output.transpose(1, 2), targets, reduction="none"
@@ -456,10 +478,10 @@ def model_modifier_cold(model):
 
 
 c_quizzes_procedure = [
-    # (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_modifier_hot),
     (("f_B", "f_A", "A", "B"), (1, 1, 1, 1), model_modifier_hot),
-    # (("A", "f_A", "B", "f_B"), (1, 1, 1, 1), model_modifier_hot),
-    # (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), model_modifier_cold),
+    # (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_modifier_hot),
+    # (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), model_modifier_cold),
+    # (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_modifier_cold),
 ]
 
 ######################################################################
@@ -580,6 +602,8 @@ def create_c_quizzes(
             procedure=c_quizzes_procedure,
         )
 
+        log_string(f"nb_generated_quizzes {c_quizzes.size(0)}")
+
         nb_generated += c_quizzes.size(0)
 
         # We discard the trivial ones, according to a criterion
@@ -589,6 +613,8 @@ def create_c_quizzes(
 
         c_quizzes = c_quizzes[to_keep]
 
+        log_string(f"nb_non_trivial_quizzes {c_quizzes.size(0)}")
+
         # Keep only the quizzes that the main model cannot solve
 
         solved_c_quizzes = c_quizzes.clone()
@@ -600,11 +626,15 @@ def create_c_quizzes(
             mask=(0, 0, 0, 1),
         )
 
+        log_string(f"nb_generated_quizzes {c_quizzes.size(0)}")
+
         main_probas = model_proba_solutions(main_model, main_solution)
-        log_string(f"main_probas {main_probas}")
+        log_string(f"main_probas {main_probas}")
         keep = main_probas < args.proba_not_understands
         c_quizzes = c_quizzes[keep]
 
+        log_string(f"nb_not_understood_quizzes {c_quizzes.size(0)}")
+
         # If there are some quizzes that the main model cannot solve,
         # pick the most confident solution
 
@@ -623,14 +653,17 @@ def create_c_quizzes(
                 )
 
                 probas = model_proba_solutions(model, solution)
-                log_string(f"probas {probas}")
+                log_string(f"probas {probas}")
                 keep = probas >= c_quizzes_proba
                 c_quizzes = solution[keep]
                 c_quizzes_proba[keep] = probas[keep]
 
             keep = c_quizzes_proba >= args.proba_understands
-            recorded.append(c_quizzes_proba[keep])
-            nb_validated += keep.long().sum()
+            c_quizzes = c_quizzes[keep]
+
+            log_string(f"nb_kept {c_quizzes.size(0)} total nb_validated {nb_validated}")
+            recorded.append(c_quizzes.clone().to("cpu"))
+            nb_validated += c_quizzes.size(0)
 
         duration = time.perf_counter() - start_time
 
@@ -683,7 +716,6 @@ for k in range(args.nb_gpts):
         dim_hidden=args.dim_hidden,
         nb_heads=args.nb_heads,
         nb_blocks=args.nb_blocks,
-        compute_attzero=compute_causal_attzero,
         dropout=args.dropout,
     ).to(main_device)
 
@@ -889,7 +921,6 @@ for n_epoch in range(current_epoch, args.nb_epochs):
                 dim_hidden=args.dim_hidden,
                 nb_heads=args.nb_heads,
                 nb_blocks=args.nb_blocks,
-                compute_attzero=compute_causal_attzero,
                 dropout=args.dropout,
             ).to(main_device)
             model.load_state_dict(new_model.state_dict())
index f716fe5..c69c899 100755 (executable)
--- a/mygpt.py
+++ b/mygpt.py
@@ -76,10 +76,11 @@ class RandomBypass(nn.Module):
 
 
 class BracketedSequence:
-    def __init__(self, x, first=None, nb=None):
+    def __init__(self, x, first=None, nb=None, ranks=None):
         self.x = x
         self.first = 0 if first is None else first
         self.nb = x.size(1) if nb is None else nb
+        self.ranks = ranks
 
     def slice(self):
         return self.x[:, self.first : self.first + self.nb]
@@ -104,7 +105,7 @@ class CacheWrapper(nn.Module):
         else:
             self.cache_y[:, bs.first : bs.first + bs.nb] = self.f(bs.slice())
 
-        return BracketedSequence(self.cache_y, bs.first, bs.nb)
+        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.ranks)
 
 
 ##############################
@@ -116,7 +117,7 @@ class WithResidual(nn.Module):
         self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
 
     def forward(self, bs):
-        return BracketedSequence(bs.x + self.f(bs).x, bs.first, bs.nb)
+        return BracketedSequence(bs.x + self.f(bs).x, bs.first, bs.nb, bs.ranks)
 
 
 ##############################
@@ -147,7 +148,7 @@ class AddPositionalEncoding(nn.Module):
             bs.slice() + self.pe[bs.first : bs.first + bs.nb]
         )
 
-        return BracketedSequence(self.cache_y, bs.first, bs.nb)
+        return BracketedSequence(self.cache_y, bs.first, bs.nb, bs.ranks)
 
 
 ##############################
@@ -184,7 +185,6 @@ class QKVAttention(nn.Module):
         dim_qk,
         dim_v,
         nb_heads=1,
-        compute_attzero=None,
         attention_dropout=0.0,
     ):
         super().__init__()
@@ -192,7 +192,6 @@ class QKVAttention(nn.Module):
         def randw(*d):
             return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1]))
 
-        self.compute_attzero = compute_attzero
         self.attention_dropout = attention_dropout
         self.record_attention = False
 
@@ -234,16 +233,18 @@ class QKVAttention(nn.Module):
             "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs_kv.first + bs_kv.nb]
         ) / math.sqrt(self.w_q.size(1))
 
-        if self.compute_attzero is not None:
-            if bs_q.first == 0:
-                self.cache_attzero = self.compute_attzero(
-                    torch.arange(x_q.size(1), device=q.device)[:, None],
-                    torch.arange(x_kv.size(1), device=q.device)[None, :],
-                )[None, None, :, :]
+        t = torch.arange(x_q.size(1), device=a.device)
+
+        if bs_q.ranks is not None:
             a = a.masked_fill(
-                self.cache_attzero[
-                    :, :, bs_q.first : bs_q.first + bs_q.nb, : bs_kv.first + bs_kv.nb
-                ],
+                (
+                    bs_q.ranks[:, None, bs_q.first : bs_q.first + bs_q.nb, None]
+                    <= bs_kv.ranks[:, None, None, : bs_kv.first + bs_kv.nb]
+                )
+                & (
+                    t[None, None, bs_q.first : bs_q.first + bs_q.nb, None]
+                    != t[None, None, None, : bs_kv.first + bs_kv.nb]
+                ),
                 float("-inf"),
             )
 
@@ -297,7 +298,6 @@ class BlockSummarizer(nn.Module):
             dim_qk=dim_keys,
             dim_v=dim_model // nb_heads,
             nb_heads=nb_heads,
-            compute_attzero=compute_attzero,
             attention_dropout=dropout,
         )
 
@@ -310,7 +310,7 @@ class ShiftByOne(nn.Module):
         super().__init__()
 
     def forward(self, bs):
-        return BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
+        return BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb, bs.ranks)
 
 
 class MyGPT(nn.Module):
@@ -322,7 +322,6 @@ class MyGPT(nn.Module):
         dim_hidden,
         nb_heads,
         nb_blocks,
-        compute_attzero=None,
         dropout=0.0,
         len_max=1e5,
     ):
@@ -354,7 +353,6 @@ class MyGPT(nn.Module):
                         dim_qk=dim_keys,
                         dim_v=dim_model // nb_heads,
                         nb_heads=nb_heads,
-                        compute_attzero=compute_attzero,
                         attention_dropout=dropout,
                     ),
                 ),
index 1acd7ad..d209a07 100755 (executable)
@@ -19,7 +19,7 @@ import threading
 
 ######################################################################
 
-# ar_mask is a tensor with 0s and 1s, of same shape as input, with
+# mask_ar is a tensor with 0s and 1s, of same shape as input, with
 # 1s where tokens should be generated. The others are kept
 # unchanged.
 
@@ -27,35 +27,40 @@ import threading
 def one_batch_masked_inplace_autoregression(
     model,
     input,
-    ar_mask,
+    mask_ar,
     acc_seq_logprobas,
     deterministic_synthesis=False,
 ):
     if input.size(0) == 0:
         return
 
-    to_generate = (ar_mask.sum(0) > 0).nonzero()
+    mask = (mask_ar > 0).long()
+    to_generate = (mask.sum(0) > 0).nonzero()
+
+    indices_1 = list(((mask_ar == 1).long().sum(0) > 0).nonzero()) + [mask.size(1)]
 
     if to_generate.min() > 0:
         model(
             BracketedSequence(input, 0, to_generate.min())
         )  # Needed to initialize the model's cache
-    for s in range(to_generate.min(), to_generate.max() + 1):
-        output = model(BracketedSequence(input, s, 1)).x
 
-        logits = output[:, s]
+    s = to_generate.min()
+
+    for s, u in zip(indices_1[:-1], indices_1[1:]):
+        logits = model(BracketedSequence(input, s, u - s)).x
 
         if deterministic_synthesis:
-            t_next = logits.argmax(-1)
+            t_next = logits.argmax(dim=2)
         else:
             dist = torch.distributions.categorical.Categorical(logits=logits)
             t_next = dist.sample()
 
-        all_n = torch.arange(t_next.size(0))
-
-        acc_seq_logprobas += ar_mask[:, s] * logits.log_softmax(dim=1)[all_n, t_next]
+        acc_seq_logprobas += (
+            mask
+            * logits.log_softmax(dim=1).gather(dim=2, index=t_next[:, :, None])[:, :, 0]
+        ).sum(dim=1)
 
-        input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
+        input[...] = mask * t_next + (1 - mask) * input
 
 
 ######################################################################
@@ -81,12 +86,12 @@ class QuizMachine:
         self.answer_len = None
         self.prompt_noise = prompt_noise
 
-        # struct, mask_generate, mask_noise, mask_loss
+        #  - struct, quad_generate, quad_noise, quad_loss
         self.train_structures = [
-            (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)),
-            (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)),
-            (("B", "f_B", "A", "f_A"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)),
-            (("f_B", "B", "f_A", "A"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)),
+            (("A", "f_A", "B", "f_B"), (0, 0, 0, 2), (0, 0, 1, 0), (1, 1, 0, 1)),
+            (("f_A", "A", "f_B", "B"), (0, 0, 0, 2), (0, 0, 1, 0), (1, 1, 0, 1)),
+            (("B", "f_B", "A", "f_A"), (0, 0, 0, 2), (0, 0, 1, 0), (1, 1, 0, 1)),
+            (("f_B", "B", "f_A", "A"), (0, 0, 0, 2), (0, 0, 1, 0), (1, 1, 0, 1)),
             (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
         ]
 
@@ -101,15 +106,15 @@ class QuizMachine:
         self,
         model,
         input,
-        ar_mask,
+        mask_ar,
         seq_logprobas,
         progress_bar_desc=None,
     ):
-        assert input.size() == ar_mask.size()
+        assert input.size() == mask_ar.size()
 
         batches = zip(
             input.split(self.batch_size),
-            ar_mask.split(self.batch_size),
+            mask_ar.split(self.batch_size),
             seq_logprobas.split(self.batch_size),
         )
 
@@ -125,11 +130,11 @@ class QuizMachine:
             t = model.training
             model.eval()
 
-            for input, ar_mask, seq_logprobas in batches:
+            for input, mask_ar, seq_logprobas in batches:
                 one_batch_masked_inplace_autoregression(
                     model=model,
                     input=input,
-                    ar_mask=ar_mask,
+                    mask_ar=mask_ar,
                     acc_seq_logprobas=seq_logprobas,
                     deterministic_synthesis=False,
                 )
@@ -158,40 +163,44 @@ class QuizMachine:
             quizzes, structs=[s for s, _, _, _ in self.train_structures]
         )
 
+        quiz_mask_ar = quizzes.new_full(quizzes.size(), 1)
         quiz_mask_loss = quizzes.new_full(quizzes.size(), 1)
 
-        if self.prompt_noise > 0.0:
-            for struct, _, mask_noise, mask_loss in self.train_structures:
-                i = self.problem.indices_select(quizzes=quizzes, struct=struct)
-                if i.any():
+        for struct, quad_ar, quad_noise, quad_loss in self.train_structures:
+            i = self.problem.indices_select(quizzes=quizzes, struct=struct)
+            if i.any():
+                if self.prompt_noise > 0.0:
                     quizzes[i] = self.problem.inject_noise(
-                        quizzes[i], self.prompt_noise, struct=struct, mask=mask_noise
-                    )
-                    quiz_mask_loss[i] = self.make_quiz_mask(
-                        quizzes=quizzes[i], struct=struct, mask=mask_loss
+                        quizzes[i], self.prompt_noise, struct=struct, quad=quad_noise
                     )
+                quiz_mask_ar[i] = self.make_quiz_mask(
+                    quizzes=quizzes[i], struct=struct, quad=quad_ar
+                )
+                quiz_mask_loss[i] = self.make_quiz_mask(
+                    quizzes=quizzes[i], struct=struct, quad=quad_loss
+                )
 
-        return quizzes, quiz_mask_loss
+        return quizzes, quiz_mask_ar, quiz_mask_loss
 
     ######################################################################
 
-    def make_quiz_mask(self, quizzes, struct, mask):
+    def make_quiz_mask(self, quizzes, struct, quad):
         assert struct in [s for s, _, _, _ in self.train_structures]
-        return self.problem.make_quiz_mask(quizzes, struct=struct, mask=mask)
+        return self.problem.make_quiz_mask(quizzes, struct=struct, quad=quad)
 
     ######################################################################
 
-    def predict(self, model, quizzes, struct, mask):
+    def predict(self, model, quizzes, struct, quad_ar):
         quizzes = quizzes.to(self.device)
-        ar_mask = self.make_quiz_mask(quizzes=quizzes, struct=struct, mask=mask)
-        result = quizzes * (1 - ar_mask)
+        mask_ar = self.make_quiz_mask(quizzes=quizzes, struct=struct, quad=quad_ar)
+        result = quizzes * (mask_ar == 0).long()
 
         seq_logprobas = torch.zeros(quizzes.size(0), device=self.device)
 
         self.autoregression(
             model=model,
             input=result,
-            ar_mask=ar_mask,
+            mask_ar=mask_ar,
             seq_logprobas=seq_logprobas,
             progress_bar_desc="autoregression",
         )
@@ -215,16 +224,14 @@ class QuizMachine:
         nb = 0
 
         # We consider all the configurations that we train for
-        for struct, mask_generate, _, _ in self.test_structures:
+        for struct, quad_ar, _, _ in self.test_structures:
             i = self.problem.indices_select(quizzes=input, struct=struct)
             nb += i.long().sum()
             result[i], correct[i], _ = self.predict(
-                model=model, quizzes=input[i], struct=struct, mask=mask_generate
+                model=model, quizzes=input[i], struct=struct, quad=quad_ar
             )
 
-            predicted_parts[i] = torch.tensor(mask_generate, device=self.device)[
-                None, :
-            ]
+            predicted_parts[i] = torch.tensor(quad_ar, device=self.device)[None, :]
             solution_is_deterministic = predicted_parts[i].sum(dim=-1) == 1
             correct[i] = (2 * correct[i] - 1) * (solution_is_deterministic).long()
 
@@ -351,7 +358,7 @@ class QuizMachine:
             self.autoregression(
                 model=model_for_generation,
                 input=c_quizzes,
-                ar_mask=self.make_quiz_mask(c_quizzes, s, m),
+                mask_ar=self.make_quiz_mask(c_quizzes, s, m),
                 seq_logprobas=seq_logprobas,
                 progress_bar_desc=f"autoregression {n_step+1}/{len(procedure)}",
             )