Update.
authorFrançois Fleuret <francois@fleuret.org>
Wed, 18 Sep 2024 16:40:29 +0000 (18:40 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Wed, 18 Sep 2024 16:40:29 +0000 (18:40 +0200)
main.py

diff --git a/main.py b/main.py
index c4ecc49..71adf30 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -42,8 +42,6 @@ parser.add_argument("--resume", action="store_true", default=False)
 
 parser.add_argument("--max_percents_of_test_in_train", type=int, default=-1)
 
-parser.add_argument("--log_command", type=str, default=None)
-
 # ----------------------------------
 
 parser.add_argument("--nb_epochs", type=int, default=10000)
@@ -58,23 +56,17 @@ parser.add_argument("--nb_train_samples", type=int, default=50000)
 
 parser.add_argument("--nb_test_samples", type=int, default=1000)
 
-parser.add_argument("--nb_train_alien_samples", type=int, default=0)
-
-parser.add_argument("--nb_test_alien_samples", type=int, default=0)
-
-parser.add_argument("--nb_c_quizzes", type=int, default=2500)
+parser.add_argument("--nb_c_quizzes", type=int, default=10000)
 
 parser.add_argument("--c_quiz_multiplier", type=int, default=1)
 
 parser.add_argument("--learning_rate", type=float, default=5e-4)
 
-parser.add_argument("--reboot", action="store_true", default=False)
-
 parser.add_argument("--nb_have_to_be_correct", type=int, default=3)
 
 parser.add_argument("--nb_have_to_be_wrong", type=int, default=1)
 
-parser.add_argument("--nb_mistakes_to_be_wrong", type=int, default=5)
+parser.add_argument("--nb_mistakes_to_be_wrong", type=int, default=10)
 
 # ----------------------------------
 
@@ -94,10 +86,6 @@ parser.add_argument("--dropout", type=float, default=0.5)
 
 # ----------------------------------
 
-parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
-
-parser.add_argument("--problem", type=str, default="grids")
-
 parser.add_argument("--nb_threads", type=int, default=1)
 
 parser.add_argument("--gpus", type=str, default="all")
@@ -110,20 +98,12 @@ parser.add_argument("--diffusion_nb_iterations", type=int, default=25)
 
 parser.add_argument("--diffusion_proba_corruption", type=float, default=0.05)
 
-parser.add_argument("--min_succeed_to_validate", type=int, default=2)
-
 parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.95)
 
 parser.add_argument("--proba_prompt_noise", type=float, default=0.05)
 
 parser.add_argument("--proba_hint", type=float, default=0.01)
 
-# parser.add_argument("--nb_hints", type=int, default=25)
-
-parser.add_argument("--nb_runs", type=int, default=1)
-
-parser.add_argument("--test", type=str, default=None)
-
 parser.add_argument("--quizzes", type=str, default=None)
 
 ######################################################################
@@ -141,18 +121,6 @@ parser.add_argument(
 
 ######################################################################
 
-parser.add_argument("--sky_height", type=int, default=6)
-
-parser.add_argument("--sky_width", type=int, default=8)
-
-parser.add_argument("--sky_nb_birds", type=int, default=3)
-
-parser.add_argument("--sky_nb_iterations", type=int, default=2)
-
-parser.add_argument("--sky_speed", type=int, default=3)
-
-######################################################################
-
 args = parser.parse_args()
 
 if args.result_dir is None:
@@ -358,7 +326,7 @@ def optimizer_to(optim, device):
 # values from the target to the input
 
 
-def add_hints(imt_set):
+def add_hints_imt(imt_set):
     input, masks, targets = imt_set[:, 0], imt_set[:, 1], imt_set[:, 2]
     # h = torch.rand(masks.size(), device=masks.device) - masks
     # t = h.sort(dim=1).values[:, args.nb_hints, None]
@@ -375,7 +343,7 @@ def add_hints(imt_set):
 # args.proba_prompt_noise
 
 
-def add_noise(imt_set):
+def add_noise_imt(imt_set):
     input, masks, targets = imt_set[:, 0], imt_set[:, 1], imt_set[:, 2]
     noise = quiz_machine.pure_noise(input.size(0), input.device)
     change = (1 - masks) * (
@@ -443,8 +411,8 @@ def predict_full(model, input, with_perturbations=False, local_device=main_devic
     imt_set = torch.cat([input[:, None], masks[:, None], targets[:, None]], dim=1)
 
     if with_perturbations:
-        imt_set = add_hints(imt_set)
-        imt_set = add_noise(imt_set)
+        imt_set = add_hints_imt(imt_set)
+        imt_set = add_noise_imt(imt_set)
 
     result = ae_predict(model, imt_set, local_device=local_device, desc=None)
     result = (result * masks).reshape(-1, 4, result.size(1)).sum(dim=1)
@@ -542,11 +510,17 @@ def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True):
     )
 
     q_p, q_g = quizzes.to(local_device).chunk(2)
+
+    # Half of the samples train the prediction, and we inject noise in
+    # all, and hints in half
     b_p = batch_for_prediction_imt(q_p)
     i = torch.rand(b_p.size(0)) < 0.5
-    b_p = add_noise(b_p)
-    b_p[i] = add_hints(b_p[i])
+    b_p = add_noise_imt(b_p)
+    b_p[i] = add_hints_imt(b_p[i])
+
+    # The other half are denoising examples for the generation
     b_g = batch_for_generation_imt(q_g)
+
     imt_set = torch.cat([b_p, b_g])
     imt_set = imt_set[torch.randperm(imt_set.size(0), device=imt_set.device)]
 
@@ -642,7 +616,7 @@ def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device):
     model.test_accuracy = nb_correct / nb_total
 
     log_string(
-        f"test_accuracy {n_epoch} model {model.id} nb_correct {nb_correct} / {nb_total} ({model.test_accuracy:.02f}%)"
+        f"test_accuracy {n_epoch} model {model.id} nb_correct {nb_correct} / {nb_total} ({model.test_accuracy*100:.02f}%)"
     )
 
     # Save some images of the ex nihilo generation of the four grids
@@ -782,9 +756,7 @@ def generate_c_quizzes(models, nb_to_generate, local_device=main_device):
 ######################################################################
 
 
-def save_quiz_image(
-    models, c_quizzes, filename, solvable_only=False, local_device=main_device
-):
+def save_quiz_image(models, c_quizzes, filename, local_device=main_device):
     c_quizzes = c_quizzes.to(local_device)
 
     to_keep, nb_correct, nb_wrong = evaluate_quizzes(
@@ -794,11 +766,6 @@ def save_quiz_image(
         local_device=local_device,
     )
 
-    if solvable_only:
-        c_quizzes = c_quizzes[to_keep]
-        nb_correct = nb_correct[to_keep]
-        nb_wrong = nb_wrong[to_keep]
-
     comments = [f"nb_correct {c} nb_wrong {w}" for c, w in zip(nb_correct, nb_wrong)]
 
     quiz_machine.problem.save_quizzes_as_image(
@@ -821,29 +788,31 @@ if args.resume:
     for model in models:
         filename = f"ae_{model.id:03d}.pth"
 
-        try:
-            d = torch.load(os.path.join(args.result_dir, filename), map_location="cpu")
-            model.load_state_dict(d["state_dict"])
-            model.optimizer.load_state_dict(d["optimizer_state_dict"])
-            model.test_accuracy = d["test_accuracy"]
-            # model.gen_test_accuracy = d["gen_test_accuracy"]
-            # model.gen_state_dict = d["gen_state_dict"]
-            # model.train_c_quiz_bags = d["train_c_quiz_bags"]
-            # model.test_c_quiz_bags = d["test_c_quiz_bags"]
-            log_string(f"successfully loaded {filename}")
-        except FileNotFoundError:
-            log_string(f"cannot find {filename}")
-            pass
-
-    try:
-        filename = "state.pth"
-        state = torch.load(os.path.join(args.result_dir, filename))
+        d = torch.load(
+            os.path.join(args.result_dir, filename),
+            map_location="cpu",
+            weights_only=False,
+        )
+        model.load_state_dict(d["state_dict"])
+        model.optimizer.load_state_dict(d["optimizer_state_dict"])
+        model.test_accuracy = d["test_accuracy"]
+        # model.gen_test_accuracy = d["gen_test_accuracy"]
+        # model.gen_state_dict = d["gen_state_dict"]
+        # model.train_c_quiz_bags = d["train_c_quiz_bags"]
+        # model.test_c_quiz_bags = d["test_c_quiz_bags"]
         log_string(f"successfully loaded {filename}")
-        current_epoch = state["current_epoch"]
-        c_quizzes = state["c_quizzes"]
-    except FileNotFoundError:
-        log_string(f"cannot find {filename}")
-        pass
+
+    filename = "state.pth"
+    state = torch.load(
+        os.path.join(args.result_dir, filename),
+        map_location="cpu",
+        weights_only=False,
+    )
+
+    log_string(f"successfully loaded {filename}")
+
+    current_epoch = state["current_epoch"]
+    c_quizzes = state["c_quizzes"]
 
 ######################################################################
 
@@ -918,9 +887,8 @@ def multithread_execution(fun, arguments):
     for t in threads:
         t.join()
 
-    if records[0] is None:
+    if records[0] == (None,):
         return
-
     else:
         return [
             torch.cat([x[k] for x in records], dim=0) for k in range(len(records[0]))
@@ -944,7 +912,8 @@ def save_models(models, suffix=""):
             },
             os.path.join(args.result_dir, filename),
         )
-        log_string(f"wrote {filename}")
+
+    log_string(f"wrote ae_*{prefix}.pth")
 
 
 ######################################################################
@@ -983,20 +952,10 @@ for n_epoch in range(current_epoch, args.nb_epochs):
         )
 
         save_quiz_image(
-            models,
-            new_c_quizzes[:256],
-            f"culture_c_quiz_{n_epoch:04d}.png",
-            solvable_only=False,
-        )
-
-        save_quiz_image(
-            models,
-            new_c_quizzes[:256],
-            f"culture_c_quiz_{n_epoch:04d}_solvable.png",
-            solvable_only=True,
+            models, new_c_quizzes[:256], f"culture_c_quiz_{n_epoch:04d}.png"
         )
 
-        log_string(f"generated_c_quizzes {new_c_quizzes.size()=}")
+        log_string(f"generated_c_quizzes {new_c_quizzes.size()}")
 
         c_quizzes = (
             new_c_quizzes