Update.
author     François Fleuret <francois@fleuret.org>
           Tue, 13 Aug 2024 20:35:30 +0000 (22:35 +0200)
committer  François Fleuret <francois@fleuret.org>
           Tue, 13 Aug 2024 20:35:30 +0000 (22:35 +0200)
main.py

diff --git a/main.py b/main.py
index 20acab3..8e06bb2 100755
--- a/main.py
+++ b/main.py
@@ -384,8 +384,6 @@ def one_epoch(model, quiz_machine, local_device=main_device):
 
     nb_train_samples, acc_train_loss = 0, 0.0
 
-    hard_w_quizzes = []
-
     full_input, full_mask_loss = quiz_machine.data_input(
         args.nb_train_samples, model.train_c_quiz_bags
     )
@@ -427,13 +425,6 @@ def one_epoch(model, quiz_machine, local_device=main_device):
 
     run_tests(model, quiz_machine)
 
-    # threshold = torch.cat([l for _, l in hard_w_quizzes], dim=0).sort().values
-    # threshold = threshold[threshold.size(0) // 2]
-
-    # model.hard_w_quizzes = torch.cat(
-    # [x[l >= threshold] for x, l in hard_w_quizzes], dim=0
-    # )
-
     model.to(main_device)
     optimizer_to(model.optimizer, main_device)
 
@@ -441,28 +432,28 @@ def one_epoch(model, quiz_machine, local_device=main_device):
 ######################################################################
 
 
-def model_transformer_hot(model):
+def model_modifier_hot(model):
     model.temperature = args.temperature_hot
     # model.set_noise_injection(1.0, ("ffw", args.nb_blocks // 2))
 
 
-def model_transformer_cold(model):
+def model_modifier_cold(model):
     model.temperature = args.temperature_cold
     # pass
 
 
 c_quizzes_procedure = [
-    (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_transformer_hot),
-    (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), model_transformer_cold),
-    (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_transformer_cold),
-    (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), model_transformer_cold),
+    (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_modifier_hot),
+    (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), model_modifier_cold),
+    (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_modifier_cold),
+    (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), model_modifier_cold),
 ]
 
 ######################################################################
 
 
 def save_additional_results(model, models):
-    # Save generated quizzes with the successive steps
+    # Save generated quizzes with the successive generation steps
 
     recorder = []
 
@@ -660,6 +651,196 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test):
         )
 
 
+######################################################################
+
+from mygpt import (
+    WithResidual,
+    CacheWrapper,
+    AddPositionalEncoding,
+    QKVAttention,
+    BracketedSequence,
+)
+
+
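+# The Thinker is a two-stage transformer: a bottom trunk reads the pair
+# (A, f_A) together with f_len learned "function" tokens, and a top trunk
+# combines those function tokens with B to predict the last grid
+# (presumably f_B).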
+class Thinker(nn.Module):
+    def __init__(
+        self,
+        vocabulary_size,
+        dim_model,
+        dim_keys,
+        dim_hidden,
+        nb_heads,
+        nb_blocks,
+        f_len,
+        dropout=0.0,
+        len_max=1e5,
+    ):
+        super().__init__()
+
+        assert dim_model % nb_heads == 0
+
+        self.embedding = nn.Sequential(
+            CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
+            AddPositionalEncoding(len_max),
+        )
+
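+        # a trunk is a stack of pre-norm residual blocks: multi-head
+        # attention followed by a two-layer MLP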
+        def trunk(depth):
+            trunk_blocks = []
+
+            for b in range(depth):
+                trunk_blocks += [
+                    WithResidual(
+                        CacheWrapper(
+                            nn.LayerNorm((dim_model,)),
+                        ),
+                        QKVAttention(
+                            dim_in=dim_model,
+                            dim_qk=dim_keys,
+                            dim_v=dim_model // nb_heads,
+                            nb_heads=nb_heads,
+                            attention_dropout=dropout,
+                        ),
+                    ),
+                    WithResidual(
+                        CacheWrapper(
+                            nn.LayerNorm((dim_model,)),
+                            nn.Linear(in_features=dim_model, out_features=dim_hidden),
+                            nn.ReLU(),
+                            nn.Linear(in_features=dim_hidden, out_features=dim_model),
+                            nn.Dropout(dropout),
+                        ),
+                    ),
+                ]
+
+            return nn.Sequential(*trunk_blocks)
+
+        self.bottom_trunk = trunk(nb_blocks // 2)
+
+        self.top_trunk = trunk(nb_blocks // 2)
+
+        self.readout = CacheWrapper(
+            nn.Linear(in_features=dim_model, out_features=vocabulary_size)
+        )
+
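+        # f_len learned tokens, intended to carry the inferred transformation
+        # from the bottom trunk to the top trunk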
+        self.f_len = f_len
+        self.fun_embedding = nn.Parameter(torch.randn(1, f_len, dim_model))
+
+        with torch.no_grad():
+            for m in self.modules():
+                if isinstance(m, nn.Embedding):
+                    m.weight.normal_(mean=0, std=2e-2)
+                elif isinstance(m, nn.LayerNorm):
+                    m.bias.zero_()
+                    m.weight.fill_(1.0)
+
+    def forward(self, bs):
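+        # zero the auxiliary loss of every submodule; whatever they accumulate
+        # during the forward pass is summed back into self.loss at the end
+        # (presumably for wrappers that add regularization terms)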
+        for m in self.modules():
+            m.loss = 0
+
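+        # the input sequence holds three grids of equal length: A, f_A, and B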
+        L = bs.x.size(1) // 3
+
+        bs = self.embedding(bs)
+        A_fA = BracketedSequence(bs.x[:, : 2 * L])
+        B = BracketedSequence(bs.x[:, -L:])
+
+        bs = BracketedSequence(
+            torch.cat([A_fA.x, self.fun_embedding.expand(bs.x.size(0), -1, -1)], dim=1)
+        )
+        bs = self.bottom_trunk(bs)
+        bs = BracketedSequence(torch.cat([bs.x[:, -self.f_len :, :], B.x], dim=1))
+        bs = self.top_trunk(bs)
+        bs = BracketedSequence(bs.x[:, self.f_len :, :])
+        bs = self.readout(bs)
+
+        for m in self.modules():
+            if m is not self:
+                self.loss += m.loss
+
+        return bs
+
+
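+# Quick experiment: train a Thinker on world quizzes, using the first three
+# grids (A, f_A, B) as input and the fourth (f_B) as target.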
+if args.test == "func":
+    train_input = quiz_machine.problem.generate_w_quizzes(args.nb_train_samples)
+    test_input = quiz_machine.problem.generate_w_quizzes(args.nb_test_samples)
+
+    L = train_input.size(1) // 4
+    f_len = 25
+
+    model = Thinker(
+        vocabulary_size=vocabulary_size,
+        dim_model=args.dim_model,
+        dim_keys=args.dim_keys,
+        dim_hidden=args.dim_hidden,
+        nb_heads=args.nb_heads,
+        nb_blocks=args.nb_blocks,
+        f_len=f_len,
+        dropout=args.dropout,
+    ).to(main_device)
+
+    model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+
+    for n_epoch in range(args.nb_epochs):
+        model.train()
+
+        nb_train_samples, acc_train_loss = 0, 0.0
+
+        for input in tqdm.tqdm(
+            train_input.split(args.batch_size),
+            dynamic_ncols=True,
+            desc="training",
+            total=train_input.size(0) // args.batch_size,
+        ):
+            input = input.to(main_device)
+
+            if nb_train_samples % args.batch_size == 0:
+                model.optimizer.zero_grad()
+
+            output = model(mygpt.BracketedSequence(input[:, : 3 * L])).x
+            targets = input[:, 3 * L :]
+            loss = F.cross_entropy(output.transpose(1, 2), targets)
+            acc_train_loss += loss.item() * input.size(0)
+
+            nb_train_samples += input.size(0)
+
+            loss.backward()
+
+            if nb_train_samples % args.batch_size == 0:
+                model.optimizer.step()
+
+        train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
+
+        log_string(f"train_perplexity {n_epoch} model thinker {train_perplexity}")
+
+        with torch.no_grad():
+            model.eval()
+
+            nb_test_samples, acc_test_loss = 0, 0.0
+
+            for input in tqdm.tqdm(
+                test_input.split(args.batch_size),
+                dynamic_ncols=True,
+                desc="testing",
+                total=test_input.size(0) // args.batch_size,
+            ):
+                input = input.to(main_device)
+
+                output = model(mygpt.BracketedSequence(input[:, : 3 * L])).x
+                targets = input[:, 3 * L :]
+                loss = F.cross_entropy(output.transpose(1, 2), targets)
+                acc_test_loss += loss.item() * input.size(0)
+
+                nb_test_samples += input.size(0)
+
+            test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
+
+            log_string(f"test_perplexity {n_epoch} model thinker {test_perplexity}")
+
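+            # sample completions of the last grid for a few test quizzes, as a
+            # quick qualitative check (the samples are not saved or scored here)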
+            input = test_input[:128].clone().to(main_device)
+
+            output = model(mygpt.BracketedSequence(input[:, : 3 * L])).x
+            dist = torch.distributions.categorical.Categorical(logits=output)
+            input[:, 3 * L :] = dist.sample()
+
+
 ######################################################################
 
 models = []