Update.
authorFrançois Fleuret <francois@fleuret.org>
Wed, 18 Sep 2024 14:05:48 +0000 (16:05 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Wed, 18 Sep 2024 14:05:48 +0000 (16:05 +0200)
main.py

diff --git a/main.py b/main.py
index 77976a4..d8dffe2 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -351,26 +351,6 @@ def optimizer_to(optim, device):
 
 ######################################################################
 
-# quad_order, quad_generate, quad_noise, quad_loss
-
-data_structures = [
-    (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (0, 0, 0, 1)),
-    (("A", "f_A", "B", "f_B"), (0, 0, 1, 0), (0, 0, 0, 1), (0, 0, 1, 0)),
-    (("A", "f_A", "B", "f_B"), (0, 1, 0, 0), (1, 0, 0, 0), (0, 1, 0, 0)),
-    (("A", "f_A", "B", "f_B"), (1, 0, 0, 0), (0, 1, 0, 0), (1, 0, 0, 0)),
-    (("A", "f_A", "B", "f_B"), (1, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
-]
-
-
-######################################################################
-
-
-def masked_cross_entropy(output, targets, masks):
-    loss_per_token = F.cross_entropy(output.transpose(1, 2), targets, reduction="none")
-    return (loss_per_token * masks).mean()
-
-
-######################################################################
 
 # Make args.nb_hints holes in the mask and copy the corresponding cell
 # values from the target to the input
@@ -386,6 +366,20 @@ def add_hints(imt_set):
     return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
 
 
+# Make pixels from the available input (mask=0) noise with probability
+# args.prompt_noise
+
+
+def add_noise(imt_set):
+    input, masks, targets = imt_set[:, 0], imt_set[:, 1], imt_set[:, 2]
+    noise = quiz_machine.pure_noise(input.size(0), input.device)
+    change = (1 - masks) * (
+        torch.rand(input.size(), device=input.device) < args.prompt_noise
+    ).long()
+    input = (1 - change) * input + change * noise
+    return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
+
+
 # IMT for input / masks / target
 
 
@@ -429,7 +423,7 @@ def ae_predict(model, imt_set, local_device=main_device, desc="predict"):
     return torch.cat(record)
 
 
-def predict_full(model, input, with_hints=False, local_device=main_device):
+def predict_full(model, input, with_perturbations=False, local_device=main_device):
     input = input[:, None, :].expand(-1, 4, -1).reshape(-1, input.size(1))
     nb = input.size(0)
     masks = input.new_zeros(input.size())
@@ -439,8 +433,9 @@ def predict_full(model, input, with_hints=False, local_device=main_device):
     input = (1 - masks) * targets
     imt_set = torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
 
-    if with_hints:
+    if with_perturbations:
         imt_set = add_hints(imt_set)
+        imt_set = add_noise(imt_set)
 
     result = ae_predict(model, imt_set, local_device=local_device, desc=None)
     result = (result * masks).reshape(-1, 4, result.size(1)).sum(dim=1)
@@ -533,6 +528,7 @@ def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True):
     q_p, q_g = quizzes.to(local_device).chunk(2)
     b_p = batch_for_prediction_imt(q_p)
     i = torch.rand(b_p.size(0)) < 0.5
+    b_p = add_noise(b_p)
     b_p[i] = add_hints(b_p[i])
     b_g = batch_for_generation_imt(q_g)
     imt_set = torch.cat([b_p, b_g])
@@ -554,13 +550,17 @@ def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True):
         desc=label,
         total=quizzes.size(0) // args.physical_batch_size,
     ):
+        input, masks, targets = imt[:, 0], imt[:, 1], imt[:, 2]
         if train and nb_samples % args.batch_size == 0:
             model.optimizer.zero_grad()
 
         with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
-            logits = model(imt[:, 0] * 2 + imt[:, 1])
+            logits = model(input * 2 + masks)
 
-        loss = masked_cross_entropy(logits, targets=imt[:, 2], masks=imt[:, 1])
+        loss_per_token = F.cross_entropy(
+            logits.transpose(1, 2), targets, reduction="none"
+        )
+        loss = (loss_per_token * masks).mean()
         acc_loss += loss.item() * imt.size(0)
         nb_samples += imt.size(0)
 
@@ -673,7 +673,7 @@ def evaluate_quizzes(quizzes, models, fraction_with_hints, local_device):
         result = predict_full(
             model=model,
             input=quizzes,
-            with_hints=True,
+            with_perturbations=True,
             local_device=local_device,
         )
         nb_mistakes = (result != quizzes).long().sum(dim=1)