Update.
author	François Fleuret <francois@fleuret.org>
Wed, 18 Sep 2024 15:03:33 +0000 (17:03 +0200)
committer	François Fleuret <francois@fleuret.org>
Wed, 18 Sep 2024 15:03:33 +0000 (17:03 +0200)
main.py

diff --git a/main.py b/main.py
index 704707d..182b907 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -114,9 +114,11 @@ parser.add_argument("--min_succeed_to_validate", type=int, default=2)
 
 parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.95)
 
-parser.add_argument("--prompt_noise", type=float, default=0.05)
+parser.add_argument("--prompt_noise_proba", type=float, default=0.05)
 
-parser.add_argument("--nb_hints", type=int, default=25)
+parser.add_argument("--hint_proba", type=float, default=0.01)
+
+# parser.add_argument("--nb_hints", type=int, default=25)
 
 parser.add_argument("--nb_runs", type=int, default=1)
 
@@ -358,23 +360,26 @@ def optimizer_to(optim, device):
 
 def add_hints(imt_set):
     input, masks, targets = imt_set[:, 0], imt_set[:, 1], imt_set[:, 2]
-    h = torch.rand(masks.size(), device=masks.device) - masks
-    t = h.sort(dim=1).values[:, args.nb_hints, None]
-    mask_hints = (h < t).long()
+    # h = torch.rand(masks.size(), device=masks.device) - masks
+    # t = h.sort(dim=1).values[:, args.nb_hints, None]
+    # mask_hints = (h < t).long()
+    mask_hints = (
+        torch.rand(input.size(), device=input.device) < args.hint_proba
+    ).long() * masks
     masks = (1 - mask_hints) * masks
     input = (1 - mask_hints) * input + mask_hints * targets
     return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
 
 
 # Make pixels from the available input (mask=0) noise with probability
-# args.prompt_noise
+# args.prompt_noise_proba
 
 
 def add_noise(imt_set):
     input, masks, targets = imt_set[:, 0], imt_set[:, 1], imt_set[:, 2]
     noise = quiz_machine.pure_noise(input.size(0), input.device)
     change = (1 - masks) * (
-        torch.rand(input.size(), device=input.device) < args.prompt_noise
+        torch.rand(input.size(), device=input.device) < args.prompt_noise_proba
     ).long()
     input = (1 - change) * input + change * noise
     return torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
@@ -403,14 +408,14 @@ def ae_predict(model, imt_set, local_device=main_device, desc="predict"):
 
     record = []
 
-    src = imt_set.split(args.train_batch_size)
+    src = imt_set.split(args.eval_batch_size)
 
     if desc is not None:
         src = tqdm.tqdm(
             src,
             dynamic_ncols=True,
             desc=desc,
-            total=imt_set.size(0) // args.train_batch_size,
+            total=imt_set.size(0) // args.eval_batch_size,
         )
 
     for imt in src:
@@ -502,9 +507,9 @@ def ae_generate(model, nb, local_device=main_device):
         sub_changed = all_changed[all_changed].clone()
 
         src = zip(
-            sub_input.split(args.train_batch_size),
-            sub_masks.split(args.train_batch_size),
-            sub_changed.split(args.train_batch_size),
+            sub_input.split(args.eval_batch_size),
+            sub_masks.split(args.eval_batch_size),
+            sub_changed.split(args.eval_batch_size),
         )
 
         for input, masks, changed in src:
@@ -549,17 +554,19 @@ def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True):
         label = "train"
         model.train().to(local_device)
         optimizer_to(model.optimizer, local_device)
+        batch_size = args.train_batch_size
     else:
         label = "test"
         model.eval().to(local_device)
+        batch_size = args.eval_batch_size
 
     nb_samples, acc_loss = 0, 0.0
 
     for imt in tqdm.tqdm(
-        imt_set.split(args.train_batch_size),
+        imt_set.split(batch_size),
         dynamic_ncols=True,
         desc=label,
-        total=quizzes.size(0) // args.train_batch_size,
+        total=quizzes.size(0) // batch_size,
     ):
         input, masks, targets = imt[:, 0], imt[:, 1], imt[:, 2]
         if train and nb_samples % args.batch_size == 0:
@@ -716,7 +723,7 @@ def generate_c_quizzes(models, nb_to_generate, local_device=main_device):
         generator_id = model.id
 
         c_quizzes = ae_generate(
-            model=model, nb=args.train_batch_size * 10, local_device=local_device
+            model=model, nb=args.eval_batch_size * 10, local_device=local_device
         )
 
         # Select the ones that are solved properly by some models and