######################################################################
 
 
+# ra_mask is boolean, with 1s on the values to generate
+
+
 def masked_inplace_autoregression(
-    model, batch_size, input, ar_mask, forbidden_tokens=None, device=torch.device("cpu")
+    model,
+    batch_size,
+    input,
+    ar_mask,
+    forbidden_tokens=None,
+    progress_bar_desc="autoregression",
+    device=torch.device("cpu"),
 ):
-    for input, ar_mask in tqdm.tqdm(
-        zip(input.split(batch_size), ar_mask.split(batch_size)),
-        dynamic_ncols=True,
-        desc="autoregression",
-        total=input.size(0) // batch_size,
-    ):
+    batches = zip(input.split(batch_size), ar_mask.split(batch_size))
+    if progress_bar_desc is not None:
+        tqdm.tqdm(
+            batches,
+            dynamic_ncols=True,
+            desc=progress_bar_desc,
+            total=input.size(0) // batch_size,
+        )
+    for input, ar_mask in batches:
         i = (ar_mask.sum(0) > 0).nonzero()
         if i.min() > 0:
             model(
                 input,
                 ar_masks,
                 forbidden_tokens,
+                progress_bar_desc=None,
                 device=self.device,
             )
             model.train(t)
         for input in task.batches(split="test"):
             input = input.to(device)
 
-            # input, loss_masks, true_images = task.excise_last_image(input)
-            # input, loss_masks = task.add_true_image(input, true_images, loss_masks)
-
             output = model(mygpt.BracketedSequence(input)).x
             loss = F.cross_entropy(output.transpose(1, 2), input)
             acc_test_loss += loss.item() * input.size(0)
 
 
 # Written by Francois Fleuret <francois@fleuret.org>
 
+# This is an implementation from scratch of a "GPT", that is a model
+# composed of several causal self-attention blocks. It is equipped
+# with a caching mechanism for keys and values to avoid a O(N^3) cost
+# for auto-regression.
+
 import math
 
 import torch
 
 ######################################################################
 
-
-class WithResidual(nn.Module):
-    def __init__(self, *f):
-        super().__init__()
-        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
-
-    def forward(self, bs):
-        bs.x = bs.x + self.f(bs).x
-        return bs
-
-
-######################################################################
-
 # A BracketedSequence is a BxTx... tensor with a first and a nb time
 # steps to compute.
 
 ##############################
 
 
+class WithResidual(nn.Module):
+    def __init__(self, *f):
+        super().__init__()
+        self.f = f[0] if len(f) == 1 else nn.Sequential(*f)
+
+    def forward(self, bs):
+        bs.x = bs.x + self.f(bs).x
+        return bs
+
+
+##############################
+
+
 class AddPositionalEncoding(nn.Module):
     def __init__(self, len_max):
         super().__init__()