Update.
author: François Fleuret <francois@fleuret.org>
Wed, 18 Sep 2024 13:49:23 +0000 (15:49 +0200)
committer: François Fleuret <francois@fleuret.org>
Wed, 18 Sep 2024 13:49:23 +0000 (15:49 +0200)
main.py

diff --git a/main.py b/main.py
index 6b137bf..77976a4 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -372,24 +372,18 @@ def masked_cross_entropy(output, targets, masks):
 
 ######################################################################
 
+# Make args.nb_hints holes in the mask and copy the corresponding cell
+# values from the target to the input
 
-def add_hints_(imt_set):
-    input, masks, targets = imt_set
+
+def add_hints(imt_set):
+    input, masks, targets = imt_set[:, 0], imt_set[:, 1], imt_set[:, 2]
     h = torch.rand(masks.size(), device=masks.device) - masks
     t = h.sort(dim=1).values[:, args.nb_hints, None]
     mask_hints = (h < t).long()
-    masks[...] = (1 - mask_hints) * masks
-    input[...] = (1 - mask_hints) * input + mask_hints * targets
-
-
-def add_hints(masks, fraction_with_hints):
-    if fraction_with_hints > 0:
-        h = torch.rand(masks.size(), device=masks.device) - masks
-        t = h.sort(dim=1).values[:, args.nb_hints, None]
-        mask_hints = (h < t).long()
-        return (1 - mask_hints) * masks
-    else:
-        return masks
+    masks = (1 - mask_hints) * masks
+    input = (1 - mask_hints) * input + mask_hints * targets
+    return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
 
 
 # IMT for input / masks / target
@@ -435,7 +429,7 @@ def ae_predict(model, imt_set, local_device=main_device, desc="predict"):
     return torch.cat(record)
 
 
-def predict_full(model, input, local_device=main_device):
+def predict_full(model, input, with_hints=False, local_device=main_device):
     input = input[:, None, :].expand(-1, 4, -1).reshape(-1, input.size(1))
     nb = input.size(0)
     masks = input.new_zeros(input.size())
@@ -445,6 +439,9 @@ def predict_full(model, input, local_device=main_device):
     input = (1 - masks) * targets
     imt_set = torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
 
+    if with_hints:
+        imt_set = add_hints(imt_set)
+
     result = ae_predict(model, imt_set, local_device=local_device, desc=None)
     result = (result * masks).reshape(-1, 4, result.size(1)).sum(dim=1)
 
@@ -533,15 +530,12 @@ def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True):
         args.c_quiz_multiplier,
     )
 
-    q1, q2 = quizzes.to(local_device).chunk(2)
-
-    imt_set = torch.cat(
-        [
-            batch_for_prediction_imt(q1),
-            batch_for_generation_imt(q2),
-        ]
-    )
-
+    q_p, q_g = quizzes.to(local_device).chunk(2)
+    b_p = batch_for_prediction_imt(q_p)
+    i = torch.rand(b_p.size(0)) < 0.5
+    b_p[i] = add_hints(b_p[i])
+    b_g = batch_for_generation_imt(q_g)
+    imt_set = torch.cat([b_p, b_g])
     imt_set = imt_set[torch.randperm(imt_set.size(0), device=imt_set.device)]
 
     if train:
@@ -679,7 +673,7 @@ def evaluate_quizzes(quizzes, models, fraction_with_hints, local_device):
         result = predict_full(
             model=model,
             input=quizzes,
-            fraction_with_hints=fraction_with_hints,
+            with_hints=True,
             local_device=local_device,
         )
         nb_mistakes = (result != quizzes).long().sum(dim=1)