Replace the iterative, model-based input degradation with a fixed diffusion-style noise schedule (degrade_input, ae_generate); drop --schedule_free.
author François Fleuret <francois@fleuret.org>
Sat, 24 Aug 2024 14:02:25 +0000 (16:02 +0200)
committer François Fleuret <francois@fleuret.org>
Sat, 24 Aug 2024 14:02:25 +0000 (16:02 +0200)
main.py

diff --git a/main.py b/main.py
index 7ce9b03..d3d237e 100755
--- a/main.py
+++ b/main.py
@@ -67,8 +67,6 @@ parser.add_argument("--learning_rate", type=float, default=5e-4)
 
 parser.add_argument("--reboot", action="store_true", default=False)
 
-parser.add_argument("--schedule_free", action="store_true", default=False)
-
 # ----------------------------------
 parser.add_argument("--model", type=str, default="37M")
 
@@ -335,8 +333,6 @@ def optimizer_to(optim, device):
 def run_tests(model, quiz_machine, local_device=main_device):
     with torch.autograd.no_grad():
         model.to(local_device).eval()
-        if args.schedule_free:
-            model.optimizer.eval()
 
         nb_test_samples, acc_test_loss = 0, 0.0
         nb_samples_accumulated = 0
@@ -389,9 +385,6 @@ def one_epoch(model, quiz_machine, local_device=main_device):
     model.to(local_device).train()
     optimizer_to(model.optimizer, local_device)
 
-    if args.schedule_free:
-        model.optimizer.train()
-
     nb_train_samples, acc_train_loss = 0, 0.0
 
     full_input, _, full_mask_loss = quiz_machine.data_input(
@@ -829,6 +822,8 @@ class MyAttentionAE(nn.Module):
                     m.weight.fill_(1.0)
 
     def forward(self, bs):
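+        # convenience path: accept a raw (N, T) token tensor, wrap it in a
+        # BracketedSequence, and return the bare output tensor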
+        if torch.is_tensor(bs):
+            return self.forward(mygpt.BracketedSequence(bs)).x
         bs = self.embedding(bs)
         bs = self.positional_encoding(bs)
         bs = self.trunk(bs)
@@ -836,15 +831,22 @@ class MyAttentionAE(nn.Module):
         return bs
 
 
-def ae_batches(quiz_machine, nb, data_structures, local_device, desc=None):
+def ae_batches(
+    quiz_machine,
+    nb,
+    data_structures,
+    local_device,
+    desc=None,
+    batch_size=args.batch_size,
+):
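+    # note: the batch_size default captures args.batch_size at definition
+    # time, not at call time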
     full_input, full_mask_generate, full_mask_loss = quiz_machine.data_input(
         nb, data_structures=data_structures
     )
 
     src = zip(
-        full_input.split(args.batch_size),
-        full_mask_generate.split(args.batch_size),
-        full_mask_loss.split(args.batch_size),
+        full_input.split(batch_size),
+        full_mask_generate.split(batch_size),
+        full_mask_loss.split(batch_size),
     )
 
     if desc is not None:
@@ -852,7 +854,7 @@ def ae_batches(quiz_machine, nb, data_structures, local_device, desc=None):
             src,
             dynamic_ncols=True,
             desc=desc,
-            total=full_input.size(0) // args.batch_size,
+            total=full_input.size(0) // batch_size,
         )
 
     for input, mask_generate, mask_loss in src:
@@ -863,34 +865,60 @@ def ae_batches(quiz_machine, nb, data_structures, local_device, desc=None):
         )
 
 
-def degrade_input_inplace(input, mask_generate, pure_noise=False):
-    if pure_noise:
-        mask_diffusion_noise = torch.rand(
-            mask_generate.size(), device=mask_generate.device
-        ) <= torch.rand((mask_generate.size(0), 1), device=mask_generate.device)
+def degrade_input(input, mask_generate, *ts):
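+    # Returns one degraded copy of input per noise level t in ts. Tokens for
+    # which mask_generate == 1 and r <= t are replaced with a uniformly
+    # random color; since noise and r are shared across levels, the
+    # corruptions are nested and consecutive levels differ by one noise
+    # increment.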
+    noise = torch.randint(
+        quiz_machine.problem.nb_colors, input.size(), device=input.device
+    )
 
-        mask_diffusion_noise = mask_diffusion_noise.long()
+    r = torch.rand(mask_generate.size(), device=mask_generate.device)
 
-        input[...] = (
-            mask_generate
-            * mask_diffusion_noise
-            * torch.randint(
-                quiz_machine.problem.nb_colors, input.size(), device=input.device
-            )
-            + (1 - mask_generate * mask_diffusion_noise) * input
-        )
+    result = []
 
-    else:
-        model.eval()
-        for it in range(torch.randint(5, (1,)).item()):
-            logits = model(
-                mygpt.BracketedSequence(
-                    torch.cat([input[:, :, None], mask_generate[:, :, None]], dim=2)
-                )
-            ).x
-            dist = torch.distributions.categorical.Categorical(logits=logits)
-            input[...] = (1 - mask_generate) * input + mask_generate * dist.sample()
-        model.train()
+    for t in ts:
+        mask_diffusion_noise = mask_generate * (r <= t).long()
+        x = (1 - mask_diffusion_noise) * input + mask_diffusion_noise * noise
+        result.append(x)
+
+    return result
+
+
+def NTC_masked_cross_entropy(output, targets, mask):
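+    # token-wise cross-entropy on (N, T, C) logits; mask zeroes out the
+    # positions that should not contribute (the mean still runs over all
+    # tokens)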
+    loss_per_token = F.cross_entropy(output.transpose(1, 2), targets, reduction="none")
+    return (loss_per_token * mask).mean()
+
+
+def NTC_channel_cat(*x):
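+    # stack (N, T)-shaped tensors as channels into an (N, T, C) tensor,
+    # broadcasting each argument to the shape of the first one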
+    return torch.cat([a.expand_as(x[0])[:, :, None] for a in x], dim=2)
+
+
+def ae_generate(model, input, mask_generate, n_epoch, nb_iterations):
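+    # Replace the masked positions with pure noise, then apply nb_iterations
+    # denoising steps, conditioning the model on the remaining noise level
+    # rho = nb_iterations - 1, ..., 0 and resampling the masked tokens from
+    # the predicted categorical distributions.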
+    noise = torch.randint(
+        quiz_machine.problem.nb_colors, input.size(), device=input.device
+    )
+    input = (1 - mask_generate) * input + mask_generate * noise
+
+    for it in range(nb_iterations):
+        rho = input.new_full((input.size(0),), nb_iterations - 1 - it)
+        input_with_mask = NTC_channel_cat(input, mask_generate, rho[:, None])
+        logits = model(input_with_mask)
+        dist = torch.distributions.categorical.Categorical(logits=logits)
+        input = (1 - mask_generate) * input + mask_generate * dist.sample()
+
+    return input
 
 
 def test_ae(local_device=main_device):
@@ -904,9 +932,8 @@ def test_ae(local_device=main_device):
         dropout=args.dropout,
     ).to(main_device)
 
-    pure_noise = True
-
     # quad_order, quad_generate, quad_noise, quad_loss
+
     data_structures = [
         (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (0, 0, 0, 1)),
         (("A", "f_A", "B", "f_B"), (0, 0, 1, 0), (0, 0, 0, 1), (0, 0, 1, 0)),
@@ -920,8 +947,7 @@ def test_ae(local_device=main_device):
     model.to(local_device).train()
     optimizer_to(model.optimizer, local_device)
 
-    if args.schedule_free:
-        model.optimizer.train()
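+    # number of noise levels in the degradation schedule, and of denoising
+    # steps at generation time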
+    nb_iterations = 5
 
     for n_epoch in range(args.nb_epochs):
         # ----------------------
@@ -940,34 +966,16 @@ def test_ae(local_device=main_device):
             if nb_train_samples % args.batch_size == 0:
                 model.optimizer.zero_grad()
 
-            targets = input.clone()
-            degrade_input_inplace(input, mask_generate, pure_noise=pure_noise)
-
-            output = model(
-                mygpt.BracketedSequence(
-                    torch.cat([input[:, :, None], mask_generate[:, :, None]], dim=2)
-                )
-            ).x
-
-            # for filename, quizzes in [
-            # ("targets.png", targets),
-            # ("input.png", input),
-            # ("mask_generate.png", mask_generate),
-            # ("mask_loss.png", mask_loss),
-            # ]:
-            # quiz_machine.problem.save_quizzes_as_image(
-            # args.result_dir,
-            # filename,
-            # quizzes=quizzes,
-            # )
-            # time.sleep(10000)
-
-            loss_per_token = F.cross_entropy(
-                output.transpose(1, 2), targets, reduction="none"
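+            # sample a noise level per quiz: the input is degraded to level
+            # (rho + 1) / nb_iterations and the target to rho / nb_iterations,
+            # so the model learns to undo a single noise increment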
+            rho = torch.randint(nb_iterations, (input.size(0), 1), device=input.device)
+            targets, input = degrade_input(
+                input, mask_generate, rho / nb_iterations, (rho + 1) / nb_iterations
             )
-            loss = (loss_per_token * mask_loss).mean()
+            input_with_mask = NTC_channel_cat(input, mask_generate, rho)
+            output = model(input_with_mask)
+            loss = NTC_masked_cross_entropy(output, targets, mask_loss)
             acc_train_loss += loss.item() * input.size(0)
             nb_train_samples += input.size(0)
+
             loss.backward()
 
             if nb_train_samples % args.batch_size == 0:
@@ -992,17 +1000,15 @@ def test_ae(local_device=main_device):
                 local_device,
                 "test",
             ):
-                targets = input.clone()
-                degrade_input_inplace(input, mask_generate, pure_noise=pure_noise)
-                output = model(
-                    mygpt.BracketedSequence(
-                        torch.cat([input[:, :, None], mask_generate[:, :, None]], dim=2)
-                    )
-                ).x
-                loss_per_token = F.cross_entropy(
-                    output.transpose(1, 2), targets, reduction="none"
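+                # same single-step denoising objective as in training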
+                rho = torch.randint(
+                    nb_iterations, (input.size(0), 1), device=input.device
+                )
+                targets, input = degrade_input(
+                    input, mask_generate, rho / nb_iterations, (rho + 1) / nb_iterations
                 )
-                loss = (loss_per_token * mask_loss).mean()
+                input_with_mask = NTC_channel_cat(input, mask_generate, rho)
+                output = model(input_with_mask)
+                loss = NTC_masked_cross_entropy(output, targets, mask_loss)
                 acc_test_loss += loss.item() * input.size(0)
                 nb_test_samples += input.size(0)
 
@@ -1014,73 +1020,36 @@ def test_ae(local_device=main_device):
             for ns, s in enumerate(data_structures):
                 quad_order, quad_generate, _, _ = s
 
-                input, mask_generate, mask_loss = next(
-                    ae_batches(quiz_machine, 128, [s], local_device)
+                input, mask_generate, _ = next(
+                    ae_batches(quiz_machine, 128, [s], local_device, batch_size=128)
                 )
 
                 targets = input.clone()
-                degrade_input_inplace(input, mask_generate, pure_noise=pure_noise)
-                result = input
-
-                not_converged = torch.full(
-                    (result.size(0),), True, device=result.device
-                )
-
-                for it in range(100):
-                    pred_result = result.clone()
-                    logits = model(
-                        mygpt.BracketedSequence(
-                            torch.cat(
-                                [
-                                    result[not_converged, :, None],
-                                    mask_generate[not_converged, :, None],
-                                ],
-                                dim=2,
-                            )
-                        )
-                    ).x
-                    dist = torch.distributions.categorical.Categorical(logits=logits)
-                    update = (1 - mask_generate[not_converged]) * input[
-                        not_converged
-                    ] + mask_generate[not_converged] * dist.sample()
-                    result[not_converged] = update
-                    not_converged = (pred_result != result).max(dim=1).values
-                    if not not_converged.any():
-                        log_string(f"diffusion_converged {it=}")
-                        break
-
-                correct = (result == targets).min(dim=1).values.long()
-                predicted_parts = input.new(input.size(0), 4)
-
-                nb = 0
-
-                predicted_parts = torch.tensor(quad_generate, device=result.device)[
-                    None, :
-                ]
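+                # regenerate the masked part from pure noise with the
+                # learned denoiser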
+                input = ae_generate(model, input, mask_generate, n_epoch, nb_iterations)
+                correct = (input == targets).min(dim=1).values.long()
+                predicted_parts = torch.tensor(quad_generate, device=input.device)
+                predicted_parts = predicted_parts[None, :].expand(input.size(0), -1)
                 solution_is_deterministic = predicted_parts.sum(dim=-1) == 1
                 correct = (2 * correct - 1) * (solution_is_deterministic).long()
-
                 nb_correct = (correct == 1).long().sum()
                 nb_total = (correct != 0).long().sum()
+                correct_parts = predicted_parts * correct[:, None]
 
                 log_string(
                     f"test_accuracy {n_epoch} model AE setup {ns} {nb_correct} / {nb_total} ({(nb_correct*100)/nb_total:.02f}%)"
                 )
 
-                correct_parts = predicted_parts * correct[:, None]
-                predicted_parts = predicted_parts.expand_as(correct_parts)
-
-                filename = f"prediction_ae_{n_epoch:04d}_{ns}.png"
+                filename = f"prediction_ae_{n_epoch:04d}_structure_{ns}.png"
 
                 quiz_machine.problem.save_quizzes_as_image(
                     args.result_dir,
                     filename,
-                    quizzes=result,
+                    quizzes=input,
                     predicted_parts=predicted_parts,
                     correct_parts=correct_parts,
                 )
 
                 log_string(f"wrote {filename}")
 
 
 if args.test == "ae":
@@ -1096,9 +1065,6 @@ def create_models():
     def compute_causal_attzero(t_q, t_k):
         return t_q < t_k
 
-    if args.schedule_free:
-        import schedulefree
-
     for k in range(args.nb_gpts):
         log_string(f"creating model {k}")
 
@@ -1132,14 +1098,7 @@ def create_models():
         model.train_c_quiz_bags = []
         model.test_c_quiz_bags = []
 
-        if args.schedule_free:
-            model.optimizer = schedulefree.AdamWScheduleFree(
-                model.parameters(), lr=args.learning_rate
-            )
-        else:
-            model.optimizer = torch.optim.Adam(
-                model.parameters(), lr=args.learning_rate
-            )
+        model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
 
         model.test_accuracy = 0.0
         model.gen_test_accuracy = 0.0