Update.
author: François Fleuret <francois@fleuret.org>
Mon, 16 Sep 2024 20:16:15 +0000 (22:16 +0200)
committer: François Fleuret <francois@fleuret.org>
Mon, 16 Sep 2024 20:16:15 +0000 (22:16 +0200)
grids.py
main.py
quiz_machine.py

index 882c113..490750b 100755 (executable)
--- a/grids.py
+++ b/grids.py
@@ -181,6 +181,16 @@ class Grids(problem.Problem):
 
         return quizzes
 
+    def pure_noise(self, nb, device):
+        result = torch.randint(
+            self.nb_colors, (nb, 4 * (self.height * self.height + 1)), device=device
+        )
+        result.view(nb, 4, -1)[:, 0, 0] = self.token_A
+        result.view(nb, 4, -1)[:, 1, 0] = self.token_f_A
+        result.view(nb, 4, -1)[:, 2, 0] = self.token_B
+        result.view(nb, 4, -1)[:, 3, 0] = self.token_f_B
+        return result
+
     # What a mess
     def reconfigure(self, quizzes, quad_order=("A", "f_A", "B", "f_B")):
         if torch.is_tensor(quizzes):
diff --git a/main.py b/main.py
index 0d46aa2..649c889 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -69,10 +69,6 @@ parser.add_argument("--nb_test_alien_samples", type=int, default=0)
 
 parser.add_argument("--nb_c_quizzes", type=int, default=10000)
 
-parser.add_argument("--nb_new_c_quizzes_for_train", type=int, default=None)
-
-parser.add_argument("--nb_new_c_quizzes_for_test", type=int, default=None)
-
 parser.add_argument("--c_quiz_multiplier", type=int, default=1)
 
 parser.add_argument("--learning_rate", type=float, default=5e-4)
@@ -115,9 +111,9 @@ parser.add_argument("--gpus", type=str, default="all")
 
 parser.add_argument("--nb_models", type=int, default=5)
 
-parser.add_argument("--nb_diffusion_iterations", type=int, default=25)
+parser.add_argument("--diffusion_nb_iterations", type=int, default=25)
 
-parser.add_argument("--proba_diffusion_corruption", type=float, default=0.05)
+parser.add_argument("--diffusion_proba_corruption", type=float, default=0.05)
 
 parser.add_argument("--min_succeed_to_validate", type=int, default=2)
 
@@ -336,7 +332,7 @@ def mu_T_sampler(shape, device="cpu"):
 
 
 diffuser = diffusion.Diffuser(
-    mu_T_sampler, args.nb_diffusion_iterations, args.proba_diffusion_corruption
+    mu_T_sampler, args.diffusion_nb_iterations, args.diffusion_proba_corruption
 )
 
 ######################################################################
@@ -470,6 +466,10 @@ def batches(
         )
 
 
+def NTC_channel_cat(*x):
+    return torch.cat([a.expand_as(x[0])[:, :, None] for a in x], dim=2)
+
+
 def NTC_masked_cross_entropy(output, targets, mask):
     loss_per_token = F.cross_entropy(output.transpose(1, 2), targets, reduction="none")
     return (loss_per_token * mask).mean()
@@ -567,7 +567,7 @@ def run_test(
 ######################################################################
 
 
-def one_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_device):
+def one_epoch_(model, n_epoch, c_quizzes, local_device=main_device):
     model.train().to(local_device)
     optimizer_to(model.optimizer, local_device)
 
@@ -635,6 +635,193 @@ def one_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_device)
         )
 
 
+######################################################################
+
+
+def batch_prediction(input, proba_hints=0.0):
+    nb = input.size(0)
+    mask_generate = input.new_zeros(input.size())
+    u = F.one_hot(torch.randint(4, (nb,), device=mask_generate.device), num_classes=4)
+    mask_generate.view(nb, 4, -1)[:, :, 1:] = u[:, :, None]
+
+    if proba_hints > 0:
+        h = torch.rand(input.size(), device=input.device) * mask_generate
+        mask_hints = h.sort(dim=1, descending=True).values < args.nb_hints
+        v = torch.rand(nb, device=input.device)[:, None]
+        mask_hints = mask_hints * (v < proba_hints).long()
+        mask_generate = (1 - mask_hints) * mask_generate
+
+    # noise = quiz_machine.problem.pure_noise(nb, input.device)
+    targets = input
+    input = (1 - mask_generate) * targets  # + mask_generate * noise
+
+    return input, targets, mask_generate
+
+
+def predict(model, quizzes, local_device=main_device):
+    model.eval().to(local_device)
+
+    input, targets, mask = batch_prediction(quizzes.to(local_device))
+
+    input_batches = input.reshape(-1, args.physical_batch_size, input.size(1))
+    targets_batches = targets.reshape(-1, args.physical_batch_size, targets.size(1))
+    mask_batches = mask.reshape(-1, args.physical_batch_size, mask.size(1))
+
+    record = []
+
+    for input, targets, mask in tqdm.tqdm(
+        zip(input_batches, targets_batches, mask_batches),
+        dynamic_ncols=True,
+        desc="predict",
+        total=quizzes.size(0) // args.physical_batch_size,
+    ):
+        # noise = quiz_machine.problem.pure_noise(input.size(0), input.device)
+        input = (1 - mask) * input  # + mask * noise
+        with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
+            logits = model(NTC_channel_cat(input, mask))
+        dist = torch.distributions.categorical.Categorical(logits=logits)
+        result = (1 - mask) * input + mask * dist.sample()
+        record.append(result)
+
+    return torch.cat(record)
+
+
+######################################################################
+
+
+def batch_generation(input):
+    nb = input.size(0)
+    probs_iterations = 0.1 ** torch.linspace(
+        0, 1, args.diffusion_nb_iterations, device=input.device
+    )
+    probs_iterations = probs_iterations[None, :] / probs_iterations.sum()
+    probs_iterations = probs_iterations.expand(nb, -1)
+    dist = torch.distributions.categorical.Categorical(probs=probs_iterations)
+    t = dist.sample() + 1
+    r = torch.rand(input.size(), device=input.device)
+    proba_erased = 1 - (1 - args.diffusion_proba_corruption) ** t
+    mask_erased = (r <= proba_erased[:, None]).long()
+
+    noise = quiz_machine.problem.pure_noise(nb, input.device)
+
+    targets = input
+    input = (1 - mask_erased) * input + mask_erased * noise
+    mask_generate = input.new_full(input.size(), 1)
+    mask_generate.reshape(mask_generate.size(0), 4, -1)[:, :, 0] = 0
+
+    return input, targets, mask_generate
+
+
+def prioritized_rand(low):
+    x = torch.rand(low.size(), device=low.device).sort(dim=1, descending=True).values
+    k = torch.rand(low.size(), device=low.device) + low.long()
+    k = k.sort(dim=1).indices
+    y = x.new(x.size())
+    y.scatter_(dim=1, index=k, src=x)
+    return y
+
+
+def generate(model, nb, local_device=main_device):
+    input = quiz_machine.problem.pure_noise(nb, local_device)
+    mask_generate = input.new_full(input.size(), 1)
+    mask_generate.reshape(mask_generate.size(0), 4, -1)[:, :, 0] = 0
+
+    changed = True
+    for it in range(args.diffusion_nb_iterations):
+        with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
+            logits = model(NTC_channel_cat(input, mask_generate))
+        dist = torch.distributions.categorical.Categorical(logits=logits)
+        output = dist.sample()
+
+        r = prioritized_rand(input != output)
+        mask_changes = (r <= args.diffusion_proba_corruption).long()
+        update = (1 - mask_changes) * input + mask_changes * output
+
+        if update.equal(input):
+            break
+        else:
+            changed = changed & (update != input).max(dim=1).values
+            input[changed] = update[changed]
+
+    return input
+
+
+######################################################################
+
+
+def batch_interleave(a, b, perm):
+    return torch.cat([a, b])[perm].reshape(-1, args.physical_batch_size, a.size(1))
+
+
+def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True):
+    if train:
+        label = "train"
+        model.train().to(local_device)
+        optimizer_to(model.optimizer, local_device)
+    else:
+        label = "test"
+        model.eval().to(local_device)
+
+    nb_samples, acc_loss = 0, 0.0
+
+    quizzes = quiz_machine.quiz_set(
+        args.nb_train_samples if train else args.nb_test_samples,
+        c_quizzes,
+        args.c_quiz_multiplier,
+    )
+
+    input_p, input_g = quizzes.to(local_device).chunk(2)
+    input_p, targets_p, mask_p = batch_prediction(input_p, proba_hints=0.5)
+    input_g, targets_g, mask_g = batch_generation(input_g)
+
+    perm = torch.randperm(quizzes.size(0), device=local_device)
+    input_batches = batch_interleave(input_p, input_g, perm)
+    targets_batches = batch_interleave(targets_p, targets_g, perm)
+    mask_batches = batch_interleave(mask_p, mask_g, perm)
+
+    for input, targets, mask in tqdm.tqdm(
+        zip(input_batches, targets_batches, mask_batches),
+        dynamic_ncols=True,
+        desc=label,
+        total=quizzes.size(0) // args.physical_batch_size,
+    ):
+        if train and nb_samples % args.batch_size == 0:
+            model.optimizer.zero_grad()
+
+        with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
+            logits = model(NTC_channel_cat(input, mask))
+
+        loss = NTC_masked_cross_entropy(logits, targets, mask)
+        acc_loss += loss.item() * input.size(0)
+        nb_samples += input.size(0)
+
+        if train:
+            loss.backward()
+
+            if nb_samples % args.batch_size == 0:
+                model.optimizer.step()
+
+    log_string(f"{label}_loss {n_epoch} model {model.id} {acc_loss/nb_samples}")
+
+
+def one_train_test_epoch(model, n_epoch, c_quizzes, local_device=main_device):
+    one_epoch(model, n_epoch, c_quizzes, local_device=local_device, train=True)
+
+    one_epoch(model, n_epoch, c_quizzes, local_device=local_device, train=False)
+
+    quizzes = quiz_machine.quiz_set(150, c_quizzes, args.c_quiz_multiplier)
+    result = predict(model, quizzes).to("cpu")
+
+    quiz_machine.problem.save_quizzes_as_image(
+        args.result_dir,
+        f"culture_prediction_{n_epoch}_{model.id}.png",
+        quizzes=result[:128],
+    )
+
+    nb_correct = (quizzes == result).min(dim=1).values.long().sum()
+    model.test_accuracy = nb_correct / quizzes.size(0)
+
+
 ######################################################################
 
 import attae
@@ -1099,11 +1286,8 @@ for n_epoch in range(current_epoch, args.nb_epochs):
     # None if c_quizzes is None else c_quizzes[agreements[:, model.id]],
 
     multithread_execution(
-        one_epoch,
-        [
-            (model, quiz_machine, n_epoch, c_quizzes, gpu)
-            for model, gpu in zip(weakest_models, gpus)
-        ],
+        one_train_test_epoch,
+        [(model, n_epoch, c_quizzes, gpu) for model, gpu in zip(weakest_models, gpus)],
     )
 
     # --------------------------------------------------------------------
index f1eb9db..781c1cf 100755 (executable)
--- a/quiz_machine.py
+++ b/quiz_machine.py
@@ -195,6 +195,37 @@ class QuizMachine:
 
     ######################################################################
 
+    def quiz_set(self, nb_samples, c_quizzes, c_quiz_multiplier=1):
+        if c_quizzes is None:
+            quizzes = self.problem.generate_w_quizzes(nb_samples)
+        else:
+            if c_quiz_multiplier > 1:
+                n = min(c_quiz_multiplier, (nb_samples // 2) // c_quizzes.size(0))
+                body = c_quizzes.repeat(n, 1)
+                if n < c_quiz_multiplier:
+                    tail = c_quizzes[
+                        torch.randperm(c_quizzes.size(0))[
+                            : nb_samples // 2 - body.size(0)
+                        ]
+                    ]
+                    c_quizzes = torch.cat([body, tail], dim=0)
+                else:
+                    c_quizzes = body
+
+            if c_quizzes.size(0) > nb_samples // 2:
+                i = torch.randperm(c_quizzes.size(0))[: nb_samples // 2]
+                c_quizzes = c_quizzes[i]
+
+            w_quizzes = self.problem.generate_w_quizzes(nb_samples - c_quizzes.size(0))
+            quizzes = torch.cat([w_quizzes, c_quizzes], dim=0)
+
+        i = torch.randperm(quizzes.size(0), device=quizzes.device)
+        quizzes = quizzes[i].contiguous()
+
+        return quizzes
+
+    ######################################################################
+
     def make_quiz_mask(self, quizzes, quad_order, quad_mask):
         assert quad_order in [s for s, _, _, _ in self.train_structures]
         return self.problem.make_quiz_mask(