Update: rename the part-selection quadruple from mask to quad throughout, add an attention-based denoising autoencoder (MyAttentionVAE) with a test_ae() train/test loop, and scale the c-quiz count check by c_quiz_multiplier.
author    François Fleuret <francois@fleuret.org>
Thu, 22 Aug 2024 16:06:38 +0000 (18:06 +0200)
committer François Fleuret <francois@fleuret.org>
Thu, 22 Aug 2024 16:06:38 +0000 (18:06 +0200)
grids.py
main.py
quiz_machine.py

diff --git a/grids.py b/grids.py
index 0564f3b..35b3cff 100755
--- a/grids.py
+++ b/grids.py
@@ -167,11 +167,11 @@ class Grids(problem.Problem):
         self.check_structure(quizzes, struct)
         return struct
 
-    def inject_noise(self, quizzes, noise, struct, mask):
+    def inject_noise(self, quizzes, noise, struct, quad):
         assert self.check_structure(quizzes, struct=struct)
         S = self.height * self.width
 
-        mask = torch.tensor(mask, device=quizzes.device)
+        mask = torch.tensor(quad, device=quizzes.device)
         mask = mask[None, :, None].expand(1, 4, S + 1).clone()
         mask[:, :, 0] = 0
         mask = mask.reshape(1, -1).expand_as(quizzes)
@@ -219,7 +219,7 @@ class Grids(problem.Problem):
         ).values
 
     def make_quiz_mask(
-        self, quizzes, struct=("A", "f_A", "B", "f_B"), mask=(0, 0, 0, 1)
+        self, quizzes, struct=("A", "f_A", "B", "f_B"), quad=(0, 0, 0, 1)
     ):
         assert self.check_structure(quizzes, struct)
 
@@ -227,10 +227,10 @@ class Grids(problem.Problem):
 
         S = self.height * self.width
         a = ar_mask.reshape(ar_mask.size(0), 4, S + 1)[:, :, 1:]
-        a[:, 0, :] = mask[0]
-        a[:, 1, :] = mask[1]
-        a[:, 2, :] = mask[2]
-        a[:, 3, :] = mask[3]
+        a[:, 0, :] = quad[0]
+        a[:, 1, :] = quad[1]
+        a[:, 2, :] = quad[2]
+        a[:, 3, :] = quad[3]
 
         return ar_mask
 
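
A minimal, self-contained sketch of what the renamed quad argument carries: a 4-tuple selecting which of the four quiz parts (A, f_A, B, f_B) an operation applies to, expanded into a per-token mask with the same reshape/expand pattern as inject_noise and make_quiz_mask above. The grid size and the dummy quizzes below are assumptions for illustration only.

import torch

height, width = 2, 3                 # hypothetical grid size
S = height * width                   # tokens per grid part
quad = (0, 0, 0, 1)                  # select only the f_B part

quizzes = torch.zeros(5, 4 * (S + 1), dtype=torch.long)   # 5 dummy quizzes

mask = torch.tensor(quad)
mask = mask[None, :, None].expand(1, 4, S + 1).clone()
mask[:, :, 0] = 0                    # the first token of each part is never selected
mask = mask.reshape(1, -1).expand_as(quizzes)

# Only the S grid tokens of the selected part are flagged in each quiz.
assert mask.sum().item() == quizzes.size(0) * S
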
diff --git a/main.py b/main.py
index 148a917..35ba763 100755
--- a/main.py
+++ b/main.py
@@ -542,7 +542,7 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test):
                 model,
                 solved_c_quizzes[:, model.id],
                 struct=("A", "f_A", "B", "f_B"),
-                mask=(0, 0, 0, 1),
+                quad=(0, 0, 0, 1),
             )
 
             proba_own_solution[:, model.id] = model_proba_solutions(
@@ -740,6 +740,207 @@ class Thinker(nn.Module):
 ######################################################################
 
 
+from mygpt import (
+    WithResidual,
+    CacheWrapper,
+    AddPositionalEncoding,
+    QKVAttention,
+    BracketedSequence,
+)
+
+
+class MyAttentionVAE(nn.Module):
+    def __init__(
+        self,
+        vocabulary_size,
+        dim_model,
+        dim_keys,
+        dim_hidden,
+        nb_heads,
+        nb_blocks,
+        dropout=0.0,
+        len_max=1e5,
+    ):
+        super().__init__()
+
+        assert dim_model % nb_heads == 0
+
+        self.embedding = nn.Sequential(
+            CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
+        )
+
+        self.positional_encoding = AddPositionalEncoding(len_max)
+
+        trunk_blocks = []
+
+        for b in range(nb_blocks):
+            trunk_blocks += [
+                WithResidual(
+                    CacheWrapper(
+                        nn.LayerNorm((dim_model,)),
+                    ),
+                    QKVAttention(
+                        dim_in=dim_model,
+                        dim_qk=dim_keys,
+                        dim_v=dim_model // nb_heads,
+                        nb_heads=nb_heads,
+                        attention_dropout=dropout,
+                    ),
+                ),
+                WithResidual(
+                    CacheWrapper(
+                        nn.LayerNorm((dim_model,)),
+                        nn.Linear(in_features=dim_model, out_features=dim_hidden),
+                        nn.ReLU(),
+                        nn.Linear(in_features=dim_hidden, out_features=dim_model),
+                        nn.Dropout(dropout),
+                    ),
+                ),
+            ]
+
+        self.trunk = nn.Sequential(*trunk_blocks)
+
+        self.readout = CacheWrapper(
+            nn.Linear(in_features=dim_model, out_features=vocabulary_size)
+        )
+
+        with torch.no_grad():
+            for m in self.modules():
+                if isinstance(m, nn.Embedding):
+                    m.weight.normal_(mean=0, std=2e-2)
+                elif isinstance(m, nn.LayerNorm):
+                    m.bias.zero_()
+                    m.weight.fill_(1.0)
+
+    def forward(self, bs):
+        bs = self.embedding(bs)
+        bs = self.positional_encoding(bs)
+        bs = self.trunk(bs)
+        bs = self.readout(bs)
+        return bs
+
+
+def test_ae(local_device=main_device):
+    model = MyAttentionVAE(
+        vocabulary_size=vocabulary_size,
+        dim_model=args.dim_model,
+        dim_keys=args.dim_keys,
+        dim_hidden=args.dim_hidden,
+        nb_heads=args.nb_heads,
+        nb_blocks=args.nb_blocks,
+        dropout=args.dropout,
+    ).to(local_device)
+
+    model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+
+    model.to(local_device).train()
+    optimizer_to(model.optimizer, local_device)
+
+    if args.schedule_free:
+        model.optimizer.train()
+
+    for n_epoch in range(args.nb_epochs):
+        # ----------------------
+        # Train
+
+        model.train()
+        nb_train_samples, acc_train_loss = 0, 0.0
+
+        full_input, full_mask_loss = quiz_machine.data_input(args.nb_train_samples)
+
+        src = zip(
+            full_input.split(args.batch_size), full_mask_loss.split(args.batch_size)
+        )
+
+        for input, mask_loss in tqdm.tqdm(
+            src,
+            dynamic_ncols=True,
+            desc="training",
+            total=full_input.size(0) // args.batch_size,
+        ):
+            input = input.to(local_device)
+            mask_loss = mask_loss.to(local_device)
+
+            if nb_train_samples % args.batch_size == 0:
+                model.optimizer.zero_grad()
+
+            targets = input
+            input = (mask_loss == 0).long() * input
+            output = model(mygpt.BracketedSequence(input)).x
+            loss = F.cross_entropy(output.transpose(1, 2), targets)
+            acc_train_loss += loss.item() * input.size(0)
+            nb_train_samples += input.size(0)
+            loss.backward()
+
+            if nb_train_samples % args.batch_size == 0:
+                model.optimizer.step()
+
+        train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
+
+        log_string(f"train_loss {n_epoch} model AE {acc_train_loss/nb_train_samples}")
+
+        # ----------------------
+        # Test
+
+        with torch.autograd.no_grad():
+            model.eval()
+
+            nb_test_samples, acc_test_loss = 0, 0.0
+
+            full_input, full_mask_loss = quiz_machine.data_input(args.nb_test_samples)
+
+            src = zip(
+                full_input.split(args.batch_size), full_mask_loss.split(args.batch_size)
+            )
+
+            for input, mask_loss in tqdm.tqdm(
+                src,
+                dynamic_ncols=True,
+                desc="testing",
+                total=full_input.size(0) // args.batch_size,
+            ):
+                input = input.to(local_device)
+                mask_loss = mask_loss.to(local_device)
+                targets = input
+                input = (mask_loss == 0).long() * input
+                output = model(mygpt.BracketedSequence(input)).x
+                loss = F.cross_entropy(output.transpose(1, 2), targets)
+                acc_test_loss += loss.item() * input.size(0)
+                nb_test_samples += input.size(0)
+
+            log_string(f"test_loss {n_epoch} model AE {acc_test_loss/nb_test_samples}")
+
+            input, mask_loss = quiz_machine.data_input(128)
+            input = input.to(local_device)
+            mask_loss = mask_loss.to(local_device)
+            targets = input
+            input = (mask_loss == 0).long() * input
+            logits = model(mygpt.BracketedSequence(input)).x
+            dist = torch.distributions.categorical.Categorical(logits=logits)
+            result = dist.sample()
+            L = input.size(1) // 4
+            result[:, 0 * L] = input[:, 0 * L]
+            result[:, 1 * L] = input[:, 1 * L]
+            result[:, 2 * L] = input[:, 2 * L]
+            result[:, 3 * L] = input[:, 3 * L]
+            filename = f"prediction_ae_{n_epoch:04d}.png"
+
+            quiz_machine.problem.save_quizzes_as_image(
+                args.result_dir,
+                filename,
+                quizzes=result,
+            )
+
+            log_string(f"wrote {filename}")
+
+
+if args.test == "ae":
+    test_ae(local_device=main_device)
+    exit(0)
+
+######################################################################
+
+
 def create_models():
     models = []
 
@@ -1018,9 +1219,11 @@ if args.test == "entropy":
             procedure=c_quizzes_procedure,
         )
 
+        filename = f"test_{n_epoch:04d}.png"
+
         quiz_machine.problem.save_quizzes_as_image(
             args.result_dir,
-            f"test_{n_epoch:04d}.png",
+            filename,
             quizzes=input,
         )
 
@@ -1119,7 +1322,7 @@ for n_epoch in range(current_epoch, args.nb_epochs):
 
                 m = max(nb_c_quizzes_per_model)
 
-                if m >= args.nb_train_samples:
+                if m * args.c_quiz_multiplier >= args.nb_train_samples:
                     break
 
             model = models[nb_c_quizzes_per_model.index(m)]
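
The training loop in test_ae() uses a masked-denoising objective: the tokens flagged by the loss mask are blanked to 0 in the input and the model is trained to reconstruct the full sequence, with a cross-entropy over every position. A minimal sketch of that step; the toy embedding+linear model is a stand-in assumption, the commit itself feeds MyAttentionVAE through mygpt.BracketedSequence.

import torch
import torch.nn as nn
import torch.nn.functional as F

vocabulary_size, B, T = 16, 4, 12

model = nn.Sequential(
    nn.Embedding(vocabulary_size, 32),
    nn.Linear(32, vocabulary_size),
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

targets = torch.randint(1, vocabulary_size, (B, T))   # 0 is reserved as "blank"
mask_loss = (torch.rand(B, T) < 0.25).long()          # 1 = token to reconstruct

input = (mask_loss == 0).long() * targets             # blank the masked tokens
output = model(input)                                 # (B, T, vocabulary_size)
loss = F.cross_entropy(output.transpose(1, 2), targets)

optimizer.zero_grad()
loss.backward()
optimizer.step()
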
diff --git a/quiz_machine.py b/quiz_machine.py
index a0b007a..ceb527a 100755
--- a/quiz_machine.py
+++ b/quiz_machine.py
@@ -81,7 +81,7 @@ class QuizMachine:
         self.answer_len = None
         self.prompt_noise = prompt_noise
 
-        # struct, mask_generate, mask_noise, mask_loss
+        # struct, quad_generate, quad_noise, quad_loss
         self.train_structures = [
             (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)),
             (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)),
@@ -140,7 +140,7 @@ class QuizMachine:
 
     ######################################################################
 
-    def data_input(self, nb_samples, c_quiz_bags, c_quiz_multiplier=1):
+    def data_input(self, nb_samples, c_quiz_bags=[], c_quiz_multiplier=1):
         if len(c_quiz_bags) > 0:
             c_quizzes = torch.cat(c_quiz_bags, dim=0)
 
@@ -176,29 +176,29 @@ class QuizMachine:
         quiz_mask_loss = quizzes.new_full(quizzes.size(), 1)
 
         if self.prompt_noise > 0.0:
-            for struct, _, mask_noise, mask_loss in self.train_structures:
+            for struct, _, quad_noise, quad_loss in self.train_structures:
                 i = self.problem.indices_select(quizzes=quizzes, struct=struct)
                 if i.any():
                     quizzes[i] = self.problem.inject_noise(
-                        quizzes[i], self.prompt_noise, struct=struct, mask=mask_noise
+                        quizzes[i], self.prompt_noise, struct=struct, quad=quad_noise
                     )
                     quiz_mask_loss[i] = self.make_quiz_mask(
-                        quizzes=quizzes[i], struct=struct, mask=mask_loss
+                        quizzes=quizzes[i], struct=struct, quad=quad_loss
                     )
 
         return quizzes, quiz_mask_loss
 
     ######################################################################
 
-    def make_quiz_mask(self, quizzes, struct, mask):
+    def make_quiz_mask(self, quizzes, struct, quad):
         assert struct in [s for s, _, _, _ in self.train_structures]
-        return self.problem.make_quiz_mask(quizzes, struct=struct, mask=mask)
+        return self.problem.make_quiz_mask(quizzes, struct=struct, quad=quad)
 
     ######################################################################
 
-    def predict(self, model, quizzes, struct, mask):
+    def predict(self, model, quizzes, struct, quad):
         quizzes = quizzes.to(self.device)
-        ar_mask = self.make_quiz_mask(quizzes=quizzes, struct=struct, mask=mask)
+        ar_mask = self.make_quiz_mask(quizzes=quizzes, struct=struct, quad=quad)
         result = quizzes * (1 - ar_mask)
 
         seq_logprobas = torch.zeros(quizzes.size(0), device=self.device)
@@ -230,14 +230,14 @@ class QuizMachine:
         nb = 0
 
         # We consider all the configurations that we train for
-        for struct, mask_generate, _, _ in self.test_structures:
+        for struct, quad_generate, _, _ in self.test_structures:
             i = self.problem.indices_select(quizzes=input, struct=struct)
             nb += i.long().sum()
             result[i], correct[i], _ = self.predict(
-                model=model, quizzes=input[i], struct=struct, mask=mask_generate
+                model=model, quizzes=input[i], struct=struct, quad=quad_generate
             )
 
-            predicted_parts[i] = torch.tensor(mask_generate, device=self.device)[
+            predicted_parts[i] = torch.tensor(quad_generate, device=self.device)[
                 None, :
             ]
             solution_is_deterministic = predicted_parts[i].sum(dim=-1) == 1
@@ -302,8 +302,8 @@ class QuizMachine:
         model,
         c_quizzes,
         struct,
-        mask_loss,
-        mask_noise=None,
+        quad_loss,
+        quad_noise=None,
         temperature=1.0,
         device=None,
     ):
@@ -317,9 +317,9 @@ class QuizMachine:
             device=device,
         )
 
-        # if self.prompt_noise > 0.0 and mask_noise is not None:
+        # if self.prompt_noise > 0.0 and quad_noise is not None:
         # c_quizzes = self.problem.inject_noise(
-        # c_quizzes, self.prompt_noise, struct=struct, mask=mask_noise
+        # c_quizzes, self.prompt_noise, struct=struct, quad=quad_noise
         # )
 
         with torch.autograd.no_grad():
@@ -332,7 +332,7 @@ class QuizMachine:
             ):
                 input = input.to(device)
                 quiz_mask_loss = self.make_quiz_mask(
-                    input, struct=struct, mask=mask_loss
+                    input, struct=struct, quad=quad_loss
                 )
                 output = model(mygpt.BracketedSequence(input)).x / temperature
                 l[...] = (
@@ -352,21 +352,21 @@ class QuizMachine:
         c_quizzes = None
 
         for n_step, setup in enumerate(procedure):
-            s, m, mt = setup
+            struct, quad_generate, model_modifier = setup
             if c_quizzes is None:
-                c_quizzes = self.problem.create_empty_quizzes(nb, s)
+                c_quizzes = self.problem.create_empty_quizzes(nb, struct)
                 c_quizzes = c_quizzes.to(self.device)
-            elif s != pred_s:
-                c_quizzes = self.problem.reconfigure(c_quizzes, s)
-            pred_s = s
+            elif struct != pred_struct:
+                c_quizzes = self.problem.reconfigure(c_quizzes, struct)
+            pred_struct = struct
 
-            if mt is not None:
-                mt(model_for_generation)
+            if model_modifier is not None:
+                model_modifier(model_for_generation)
 
             self.autoregression(
                 model=model_for_generation,
                 input=c_quizzes,
-                ar_mask=self.make_quiz_mask(c_quizzes, s, m),
+                ar_mask=self.make_quiz_mask(c_quizzes, struct, quad_generate),
                 seq_logprobas=seq_logprobas,
                 progress_bar_desc=f"autoregression {n_step+1}/{len(procedure)}",
             )
@@ -375,7 +375,9 @@ class QuizMachine:
 
             if recorder is not None:
                 x = c_quizzes.clone()
-                t = torch.tensor(m, device=x.device)[None, :].expand(x.size(0), -1)
+                t = torch.tensor(quad_generate, device=x.device)[None, :].expand(
+                    x.size(0), -1
+                )
                 recorder.append(
                     self.problem.reconfigure([x, t], ("A", "f_A", "B", "f_B"))
                 )
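
The generation loop above now unpacks each procedure step into a named triple (struct, quad_generate, model_modifier) instead of (s, m, mt). A sketch of what such a procedure can look like; this particular sequence of steps, and the set_temperature helper, are hypothetical examples rather than values taken from the repository.

def set_temperature(t):
    # Hypothetical modifier: mutates the generating model before the step runs.
    def modifier(model):
        model.temperature = t
    return modifier

c_quizzes_procedure = [
    # 1) generate f_B given an (A, f_A, B) prompt, model left as-is
    (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), None),
    # 2) reconfigure to the flipped struct and re-generate its last part
    (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), set_temperature(1.5)),
]
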