Update: rename struct/quad arguments to quad_order, quad_mask and quad_noise; factor the batch iteration of test_ae into an ae_batches helper.
author     François Fleuret <francois@fleuret.org>
           Fri, 23 Aug 2024 05:04:41 +0000 (07:04 +0200)
committer  François Fleuret <francois@fleuret.org>
           Fri, 23 Aug 2024 05:04:41 +0000 (07:04 +0200)
grids.py
main.py
quiz_machine.py

diff --git a/grids.py b/grids.py
index 35b3cff..98a0581 100755
--- a/grids.py
+++ b/grids.py
@@ -148,30 +148,30 @@ class Grids(problem.Problem):
         ("gray", [128, 128, 128]),
     ]
 
-    def check_structure(self, quizzes, struct):
+    def check_order(self, quizzes, quad_order):
         S = self.height * self.width
 
         return (
-            (quizzes[:, 0 * (S + 1)] == self.l2tok[struct[0]])
-            & (quizzes[:, 1 * (S + 1)] == self.l2tok[struct[1]])
-            & (quizzes[:, 2 * (S + 1)] == self.l2tok[struct[2]])
-            & (quizzes[:, 3 * (S + 1)] == self.l2tok[struct[3]])
+            (quizzes[:, 0 * (S + 1)] == self.l2tok[quad_order[0]])
+            & (quizzes[:, 1 * (S + 1)] == self.l2tok[quad_order[1]])
+            & (quizzes[:, 2 * (S + 1)] == self.l2tok[quad_order[2]])
+            & (quizzes[:, 3 * (S + 1)] == self.l2tok[quad_order[3]])
         ).all()
 
-    def get_structure(self, quizzes):
+    def get_order(self, quizzes):
         S = self.height * self.width
-        struct = tuple(
+        quad_order = tuple(
             self.tok2l[n.item()]
             for n in quizzes.reshape(quizzes.size(0), 4, S + 1)[0, :, 0]
         )
-        self.check_structure(quizzes, struct)
-        return struct
+        self.check_order(quizzes, quad_order)
+        return quad_order
 
-    def inject_noise(self, quizzes, noise, struct, quad):
-        assert self.check_structure(quizzes, struct=struct)
+    def inject_noise(self, quizzes, noise, quad_order, quad_noise):
+        assert self.check_order(quizzes, quad_order=quad_order)
         S = self.height * self.width
 
-        mask = torch.tensor(quad, device=quizzes.device)
+        mask = torch.tensor(quad_noise, device=quizzes.device)
         mask = mask[None, :, None].expand(1, 4, S + 1).clone()
         mask[:, :, 0] = 0
         mask = mask.reshape(1, -1).expand_as(quizzes)
@@ -182,20 +182,20 @@ class Grids(problem.Problem):
         return quizzes
 
     # What a mess
-    def reconfigure(self, quizzes, struct=("A", "f_A", "B", "f_B")):
+    def reconfigure(self, quizzes, quad_order=("A", "f_A", "B", "f_B")):
         if torch.is_tensor(quizzes):
-            return self.reconfigure([quizzes], struct=struct)[0]
+            return self.reconfigure([quizzes], quad_order=quad_order)[0]
 
         S = self.height * self.width
         result = [x.new(x.size()) for x in quizzes]
 
-        struct_from = self.get_structure(quizzes[0][:1])
-        i = self.indices_select(quizzes[0], struct_from)
+        quad_order_from = self.get_order(quizzes[0][:1])
+        i = self.indices_select(quizzes[0], quad_order_from)
 
-        sf = dict((l, n) for n, l in enumerate(struct_from))
+        sf = dict((l, n) for n, l in enumerate(quad_order_from))
 
         for q in range(4):
-            k = sf[struct[q]]
+            k = sf[quad_order[q]]
             for x, y in zip(quizzes, result):
                 l = x.size(1) // 4
                 y[i, q * l : (q + 1) * l] = x[i, k * l : (k + 1) * l]
@@ -204,7 +204,7 @@ class Grids(problem.Problem):
 
         if j.any():
             for z, y in zip(
-                self.reconfigure([x[j] for x in quizzes], struct=struct), result
+                self.reconfigure([x[j] for x in quizzes], quad_order=quad_order), result
             ):
                 y[j] = z
 
@@ -212,36 +212,36 @@ class Grids(problem.Problem):
 
     def trivial(self, quizzes):
         S = self.height * self.width
-        assert self.check_structure(quizzes, struct=("A", "f_A", "B", "f_B"))
+        assert self.check_order(quizzes, quad_order=("A", "f_A", "B", "f_B"))
         a = quizzes.reshape(quizzes.size(0), 4, S + 1)[:, :, 1:]
         return (a[:, 0] == a[:, 1]).min(dim=1).values | (a[:, 2] == a[:, 3]).min(
             dim=1
         ).values
 
     def make_quiz_mask(
-        self, quizzes, struct=("A", "f_A", "B", "f_B"), quad=(0, 0, 0, 1)
+        self, quizzes, quad_order=("A", "f_A", "B", "f_B"), quad_mask=(0, 0, 0, 1)
     ):
-        assert self.check_structure(quizzes, struct)
+        assert self.check_order(quizzes, quad_order)
 
         ar_mask = quizzes.new_zeros(quizzes.size())
 
         S = self.height * self.width
         a = ar_mask.reshape(ar_mask.size(0), 4, S + 1)[:, :, 1:]
-        a[:, 0, :] = quad[0]
-        a[:, 1, :] = quad[1]
-        a[:, 2, :] = quad[2]
-        a[:, 3, :] = quad[3]
+        a[:, 0, :] = quad_mask[0]
+        a[:, 1, :] = quad_mask[1]
+        a[:, 2, :] = quad_mask[2]
+        a[:, 3, :] = quad_mask[3]
 
         return ar_mask
 
-    def indices_select(self, quizzes, struct=("A", "f_A", "B", "f_B")):
+    def indices_select(self, quizzes, quad_order=("A", "f_A", "B", "f_B")):
         S = self.height * self.width
         q = quizzes.reshape(quizzes.size(0), 4, S + 1)
         return (
-            (q[:, 0, 0] == self.l2tok[struct[0]])
-            & (q[:, 1, 0] == self.l2tok[struct[1]])
-            & (q[:, 2, 0] == self.l2tok[struct[2]])
-            & (q[:, 3, 0] == self.l2tok[struct[3]])
+            (q[:, 0, 0] == self.l2tok[quad_order[0]])
+            & (q[:, 1, 0] == self.l2tok[quad_order[1]])
+            & (q[:, 2, 0] == self.l2tok[quad_order[2]])
+            & (q[:, 3, 0] == self.l2tok[quad_order[3]])
         )
 
     def __init__(
@@ -1707,13 +1707,13 @@ class Grids(problem.Problem):
 
     ######################################################################
 
-    def create_empty_quizzes(self, nb, struct=("A", "f_A", "B", "f_B")):
+    def create_empty_quizzes(self, nb, quad_order=("A", "f_A", "B", "f_B")):
         S = self.height * self.width
         quizzes = torch.zeros(nb, 4 * (S + 1), dtype=torch.int64)
-        quizzes[:, 0 * (S + 1)] = self.l2tok[struct[0]]
-        quizzes[:, 1 * (S + 1)] = self.l2tok[struct[1]]
-        quizzes[:, 2 * (S + 1)] = self.l2tok[struct[2]]
-        quizzes[:, 3 * (S + 1)] = self.l2tok[struct[3]]
+        quizzes[:, 0 * (S + 1)] = self.l2tok[quad_order[0]]
+        quizzes[:, 1 * (S + 1)] = self.l2tok[quad_order[1]]
+        quizzes[:, 2 * (S + 1)] = self.l2tok[quad_order[2]]
+        quizzes[:, 3 * (S + 1)] = self.l2tok[quad_order[3]]
 
         return quizzes
 
@@ -1764,10 +1764,10 @@ if __name__ == "__main__":
     # nb = 5
     # quizzes = grids.generate_w_quizzes_(nb, tasks=[grids.task_fill])
     # print(quizzes)
-    # print(grids.get_structure(quizzes))
+    # print(grids.get_order(quizzes))
     # quizzes = grids.reconfigure(quizzes, struct=("A", "B", "f_A", "f_B"))
     # print("DEBUG2", quizzes)
-    # print(grids.get_structure(quizzes))
+    # print(grids.get_order(quizzes))
     # print(quizzes)
 
     # i = torch.rand(quizzes.size(0)) < 0.5
@@ -1778,8 +1778,8 @@ if __name__ == "__main__":
 
     # print(
     # i.equal(j),
-    # grids.get_structure(quizzes[j]),
-    # grids.get_structure(quizzes[j == False]),
+    # grids.get_order(quizzes[j]),
+    # grids.get_order(quizzes[j == False]),
     # )
 
     #   exit(0)
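
For reference, a minimal usage sketch of the renamed grids.py API. It assumes that Grids() can be constructed with its default arguments and that the module is importable from the repository root; neither is shown in this diff.

# Hedged sketch: exercises the new quad_order / quad_mask keywords.
from grids import Grids

grids = Grids()  # assumed default construction

# The first token of each quadrant encodes its label, so the order is recoverable.
quizzes = grids.create_empty_quizzes(5, quad_order=("A", "f_A", "B", "f_B"))
assert grids.get_order(quizzes) == ("A", "f_A", "B", "f_B")

# Reorder the quadrants, then check the order directly.
quizzes = grids.reconfigure(quizzes, quad_order=("A", "B", "f_A", "f_B"))
assert grids.check_order(quizzes, ("A", "B", "f_A", "f_B"))

# Mask selecting only the tokens of the last quadrant (here f_B).
ar_mask = grids.make_quiz_mask(
    quizzes, quad_order=("A", "B", "f_A", "f_B"), quad_mask=(0, 0, 0, 1)
)
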
diff --git a/main.py b/main.py
index 2a35209..a65d893 100755
--- a/main.py
+++ b/main.py
@@ -541,7 +541,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test):
             (solved_c_quizzes[:, model.id], _, _) = quiz_machine.predict(
                 model,
                 solved_c_quizzes[:, model.id],
-                struct=("A", "f_A", "B", "f_B"),
-                quad=(0, 0, 0, 1),
+                quad_order=("A", "f_A", "B", "f_B"),
+                quad_mask=(0, 0, 0, 1),
             )
 
@@ -821,6 +821,33 @@ class MyAttentionVAE(nn.Module):
         return bs
 
 
+def ae_batches(quiz_machine, nb, data_structures, local_device, desc=None):
+    full_input, full_mask_generate, full_mask_loss = quiz_machine.data_input(
+        nb, data_structures=data_structures
+    )
+
+    src = zip(
+        full_input.split(args.batch_size),
+        full_mask_generate.split(args.batch_size),
+        full_mask_loss.split(args.batch_size),
+    )
+
+    if desc is not None:
+        src = tqdm.tqdm(
+            src,
+            dynamic_ncols=True,
+            desc=desc,
+            total=full_input.size(0) // args.batch_size,
+        )
+
+    for input, mask_generate, mask_loss in src:
+        yield (
+            input.to(local_device),
+            mask_generate.to(local_device),
+            mask_loss.to(local_device),
+        )
+
+
 def test_ae(local_device=main_device):
     model = MyAttentionVAE(
         vocabulary_size=vocabulary_size,
@@ -832,6 +859,14 @@ def test_ae(local_device=main_device):
         dropout=args.dropout,
     ).to(main_device)
 
+    data_structures = [
+        (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)),
+        (("A", "f_A", "B", "f_B"), (0, 0, 1, 0), (0, 0, 0, 1), (1, 1, 1, 1)),
+        (("A", "f_A", "B", "f_B"), (0, 1, 0, 0), (1, 0, 0, 0), (1, 1, 1, 1)),
+        (("A", "f_A", "B", "f_B"), (1, 0, 0, 0), (0, 1, 0, 0), (1, 1, 1, 1)),
+        (("A", "f_A", "B", "f_B"), (1, 1, 1, 0), (0, 0, 0, 0), (1, 1, 1, 1)),
+    ]
+
     model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
 
     model.to(local_device).train()
@@ -847,26 +882,13 @@ def test_ae(local_device=main_device):
         model.train()
         nb_train_samples, acc_train_loss = 0, 0.0
 
-        full_input, full_mask_generate, full_mask_loss = quiz_machine.data_input(
-            args.nb_train_samples
-        )
-
-        src = zip(
-            full_input.split(args.batch_size),
-            full_mask_generate.split(args.batch_size),
-            full_mask_loss.split(args.batch_size),
-        )
-
-        for input, mask_generate, mask_loss in tqdm.tqdm(
-            src,
-            dynamic_ncols=True,
-            desc="training",
-            total=full_input.size(0) // args.batch_size,
+        for input, mask_generate, mask_loss in ae_batches(
+            quiz_machine,
+            args.nb_train_samples,
+            data_structures,
+            local_device,
+            "training",
         ):
-            input = input.to(local_device)
-            mask_generate = mask_generate.to(local_device)
-            mask_loss = mask_loss.to(local_device)
-
             if nb_train_samples % args.batch_size == 0:
                 model.optimizer.zero_grad()
 
@@ -911,26 +933,13 @@ def test_ae(local_device=main_device):
 
             nb_test_samples, acc_test_loss = 0, 0.0
 
-            full_input, full_mask_generate, full_mask_loss = quiz_machine.data_input(
-                args.nb_test_samples
-            )
-
-            src = zip(
-                full_input.split(args.batch_size),
-                full_mask_generate.split(args.batch_size),
-                full_mask_loss.split(args.batch_size),
-            )
-
-            for input, mask_generate, mask_loss in tqdm.tqdm(
-                src,
-                dynamic_ncols=True,
-                desc="testing",
-                total=full_input.size(0) // args.batch_size,
+            for input, mask_generate, mask_loss in ae_batches(
+                quiz_machine,
+                args.nb_test_samples,
+                data_structures,
+                local_device,
+                "test",
             ):
-                input = input.to(local_device)
-                mask_generate = mask_generate.to(local_device)
-                mask_loss = mask_loss.to(local_device)
-
                 targets = input
 
                 mask_noise = (mask_generate != 0) & (
@@ -960,10 +969,10 @@ def test_ae(local_device=main_device):
 
             log_string(f"test_loss {n_epoch} model AE {acc_test_loss/nb_test_samples}")
 
-            input, mask_generate, mask_loss = quiz_machine.data_input(128)
-            input = input.to(local_device)
-            mask_generate = mask_generate.to(local_device)
-            mask_loss = mask_loss.to(local_device)
+            input, mask_generate, mask_loss = next(
+                ae_batches(quiz_machine, 128, data_structures, local_device)
+            )
+
             targets = input
 
             pred_result = None
@@ -1013,8 +1022,10 @@ def test_ae(local_device=main_device):
             nb = 0
 
             # We consider all the configurations that we train for
-            for struct, quad_generate, _, _ in quiz_machine.test_structures:
-                i = quiz_machine.problem.indices_select(quizzes=input, struct=struct)
+            for quad_order, quad_generate, _, _ in quiz_machine.test_structures:
+                i = quiz_machine.problem.indices_select(
+                    quizzes=input, quad_order=quad_order
+                )
                 nb += i.long().sum()
 
                 predicted_parts[i] = torch.tensor(quad_generate, device=result.device)[
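
Each data_structures entry above is a (quad_order, quad_generate, quad_noise, quad_loss) tuple, matching the unpacking in QuizMachine.data_input. The self-contained sketch below mirrors how such a per-quadrant tuple expands into a per-token mask, as in make_quiz_mask; the grid size is a placeholder.

# Hedged sketch of the quad tuple -> per-token mask expansion.
import torch

height, width = 10, 10  # placeholders; real values come from the Grids instance
S = height * width
quad_mask = (0, 0, 0, 1)  # e.g. generate / score only the f_B quadrant

quizzes = torch.zeros(3, 4 * (S + 1), dtype=torch.int64)  # dummy batch of 3
ar_mask = quizzes.new_zeros(quizzes.size())
a = ar_mask.reshape(ar_mask.size(0), 4, S + 1)[:, :, 1:]  # skip each order token
for q in range(4):
    a[:, q, :] = quad_mask[q]

# Exactly the S tokens of the selected quadrant are flagged for each sample.
assert ar_mask.sum().item() == 3 * S
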
diff --git a/quiz_machine.py b/quiz_machine.py
index bea0d78..0f13964 100755
--- a/quiz_machine.py
+++ b/quiz_machine.py
@@ -175,39 +175,46 @@ class QuizMachine:
         quizzes = quizzes[i]
 
         self.randomize_configuations_inplace(
-            quizzes, structs=[s for s, _, _, _ in data_structures]
+            quizzes, quad_orders=[s for s, _, _, _ in data_structures]
         )
 
         quiz_mask_generate = quizzes.new_full(quizzes.size(), 1)
         quiz_mask_loss = quizzes.new_full(quizzes.size(), 1)
 
-        for struct, quad_generate, quad_noise, quad_loss in data_structures:
-            i = self.problem.indices_select(quizzes=quizzes, struct=struct)
+        for quad_order, quad_generate, quad_noise, quad_loss in data_structures:
+            i = self.problem.indices_select(quizzes=quizzes, quad_order=quad_order)
             if i.any():
                 if self.prompt_noise > 0.0:
                     quizzes[i] = self.problem.inject_noise(
-                        quizzes[i], self.prompt_noise, struct=struct, quad=quad_noise
+                        quizzes[i],
+                        self.prompt_noise,
+                        quad_order=quad_order,
+                        quad_noise=quad_noise,
                     )
                 quiz_mask_generate[i] = self.make_quiz_mask(
-                    quizzes=quizzes[i], struct=struct, quad=quad_generate
+                    quizzes=quizzes[i], quad_order=quad_order, quad_mask=quad_generate
                 )
                 quiz_mask_loss[i] = self.make_quiz_mask(
-                    quizzes=quizzes[i], struct=struct, quad=quad_loss
+                    quizzes=quizzes[i], quad_order=quad_order, quad_mask=quad_loss
                 )
 
         return quizzes, quiz_mask_generate, quiz_mask_loss
 
     ######################################################################
 
-    def make_quiz_mask(self, quizzes, struct, quad):
-        assert struct in [s for s, _, _, _ in self.train_structures]
-        return self.problem.make_quiz_mask(quizzes, struct=struct, quad=quad)
+    def make_quiz_mask(self, quizzes, quad_order, quad_mask):
+        assert quad_order in [s for s, _, _, _ in self.train_structures]
+        return self.problem.make_quiz_mask(
+            quizzes, quad_order=quad_order, quad_mask=quad_mask
+        )
 
     ######################################################################
 
-    def predict(self, model, quizzes, struct, quad):
+    def predict(self, model, quizzes, quad_order, quad_mask):
         quizzes = quizzes.to(self.device)
-        ar_mask = self.make_quiz_mask(quizzes=quizzes, struct=struct, quad=quad)
+        ar_mask = self.make_quiz_mask(
+            quizzes=quizzes, quad_order=quad_order, quad_mask=quad_mask
+        )
         result = quizzes * (1 - ar_mask)
 
         seq_logprobas = torch.zeros(quizzes.size(0), device=self.device)
@@ -239,11 +246,11 @@ class QuizMachine:
         nb = 0
 
         # We consider all the configurations that we train for
-        for struct, quad_generate, _, _ in self.test_structures:
-            i = self.problem.indices_select(quizzes=input, struct=struct)
+        for quad_order, quad_generate, _, _ in self.test_structures:
+            i = self.problem.indices_select(quizzes=input, quad_order=quad_order)
             nb += i.long().sum()
             result[i], correct[i], _ = self.predict(
-                model=model, quizzes=input[i], struct=struct, quad=quad_generate
+                model=model, quizzes=input[i], quad_order=quad_order, quad_mask=quad_generate
             )
 
             predicted_parts[i] = torch.tensor(quad_generate, device=self.device)[
@@ -282,11 +289,11 @@ class QuizMachine:
 
     ######################################################################
 
-    def randomize_configuations_inplace(self, quizzes, structs):
-        r = torch.randint(len(structs), (quizzes.size(0),), device=quizzes.device)
-        for c in range(len(structs)):
+    def randomize_configuations_inplace(self, quizzes, quad_orders):
+        r = torch.randint(len(quad_orders), (quizzes.size(0),), device=quizzes.device)
+        for c in range(len(quad_orders)):
             quizzes[r == c] = self.problem.reconfigure(
-                quizzes[r == c], struct=structs[c]
+                quizzes[r == c], quad_order=quad_orders[c]
             )
 
     ######################################################################
@@ -310,7 +317,7 @@ class QuizMachine:
         self,
         model,
         c_quizzes,
-        struct,
+        quad_order,
         quad_loss,
         quad_noise=None,
         temperature=1.0,
@@ -319,7 +326,7 @@ class QuizMachine:
         if device is None:
             device = self.device
 
-        c_quizzes = self.problem.reconfigure(c_quizzes, struct)
+        c_quizzes = self.problem.reconfigure(c_quizzes, quad_order)
 
         seq_logprobas = torch.zeros(
             c_quizzes.size(0),
@@ -328,7 +335,7 @@ class QuizMachine:
 
         # if self.prompt_noise > 0.0 and quad_noise is not None:
         # c_quizzes = self.problem.inject_noise(
-        # c_quizzes, self.prompt_noise, struct=struct, quad=quad_noise
+        # c_quizzes, self.prompt_noise, quad_order=quad_order, quad_noise=quad_noise
         # )
 
         with torch.autograd.no_grad():
@@ -341,7 +348,7 @@ class QuizMachine:
             ):
                 input = input.to(device)
                 quiz_mask_loss = self.make_quiz_mask(
-                    input, struct=struct, quad=quad_loss
+                    input, quad_order=quad_order, quad_mask=quad_loss
                 )
                 output = model(mygpt.BracketedSequence(input)).x / temperature
                 l[...] = (
@@ -361,13 +368,13 @@ class QuizMachine:
         c_quizzes = None
 
         for n_step, setup in enumerate(procedure):
-            struct, quad_generate, model_modifier = setup
+            quad_order, quad_generate, model_modifier = setup
             if c_quizzes is None:
-                c_quizzes = self.problem.create_empty_quizzes(nb, struct)
+                c_quizzes = self.problem.create_empty_quizzes(nb, quad_order)
                 c_quizzes = c_quizzes.to(self.device)
-            elif struct != pred_struct:
-                c_quizzes = self.problem.reconfigure(c_quizzes, struct)
-            pred_struct = struct
+            elif quad_order != pred_quad_order:
+                c_quizzes = self.problem.reconfigure(c_quizzes, quad_order)
+            pred_quad_order = quad_order
 
             if model_modifier is not None:
                 model_modifier(model_for_generation)
@@ -375,7 +382,9 @@ class QuizMachine:
             self.autoregression(
                 model=model_for_generation,
                 input=c_quizzes,
-                ar_mask=self.make_quiz_mask(c_quizzes, struct, quad_generate),
+                ar_mask=self.make_quiz_mask(
+                    quizzes=c_quizzes, quad_order=quad_order, quad_mask=quad_generate
+                ),
                 seq_logprobas=seq_logprobas,
                 progress_bar_desc=f"autoregression {n_step+1}/{len(procedure)}",
             )
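
For reference, the generation loop above consumes a procedure made of (quad_order, quad_generate, model_modifier) triples. The two steps below are purely illustrative and are not taken from this commit.

# Hypothetical procedure: generate B and f_B given A and f_A, then redo f_B alone.
# model_modifier is left as None here, i.e. the model is used as-is.
procedure = [
    (("A", "f_A", "B", "f_B"), (0, 0, 1, 1), None),
    (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), None),
]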