Update: reorder top-level code in main.py. The construction of problem, the vocabulary logging, and the model/optimizer setup move down to just before the resume logic, while multithread_execution and save_models move up, ahead of save_quiz_image. Pure code movement, no functional change.
author    François Fleuret <francois@fleuret.org>
          Thu, 19 Sep 2024 14:16:03 +0000 (16:16 +0200)
committer François Fleuret <francois@fleuret.org>
          Thu, 19 Sep 2024 14:16:03 +0000 (16:16 +0200)
main.py

diff --git a/main.py b/main.py
index c08b04d..7c8c836 100755
--- a/main.py
+++ b/main.py
@@ -249,16 +249,6 @@ assert args.nb_test_samples % args.batch_size == 0
 
 ######################################################################
 
-problem = grids.Grids(
-    max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
-    chunk_size=100,
-    nb_threads=args.nb_threads,
-    tasks=args.grids_world_tasks,
-)
-
-if not args.resume:
-    problem.save_some_examples(args.result_dir)
-
 
 def pure_noise(nb, device):
     r = problem.pure_noise(nb, device)
@@ -308,14 +298,6 @@ def quiz_set(nb_samples, c_quizzes, c_quiz_multiplier=1):
 
 ######################################################################
 
-log_string(f"main_device {main_device} gpus {[ str(g) for g in gpus]}")
-
-vocabulary_size = problem.vocabulary_size()
-
-log_string(f"vocabulary_size {vocabulary_size}")
-
-######################################################################
-
 
 def optimizer_to(optim, device):
     """Move the optimizer optim to the device"""
@@ -637,30 +619,6 @@ def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device):
 
 ######################################################################
 
-models = []
-
-for i in range(args.nb_models):
-    # model = attae.FunctionalAttentionAE(
-    model = attae.AttentionAE(
-        vocabulary_size=vocabulary_size * 2,
-        dim_model=args.dim_model,
-        dim_keys=args.dim_keys,
-        dim_hidden=args.dim_hidden,
-        nb_heads=args.nb_heads,
-        nb_blocks=args.nb_blocks,
-        dropout=args.dropout,
-    )
-
-    # model = torch.compile(model)
-
-    model.id = i
-    model.test_accuracy = 0.0
-    model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
-
-    models.append(model)
-
-######################################################################
-
 
 def evaluate_quizzes(quizzes, models, local_device):
     nb_correct, nb_wrong = 0, 0
@@ -768,6 +726,62 @@ def generate_c_quizzes(models, nb_to_generate, local_device=main_device):
 ######################################################################
 
 
+def multithread_execution(fun, arguments):
+    # Single instance, no thread
+    if len(arguments) == 1:
+        return fun(*(arguments[0]))
+
+    records, threads = [], []
+
+    def threadable_fun(*args):
+        r = fun(*args)
+        if type(r) is not tuple:
+            r = (r,)
+        records.append(r)
+
+    for args in arguments:
+        # To get a different sequence between threads
+        # log_string(f"dummy_rand {torch.rand(1)}")
+        torch.rand(1)
+        t = threading.Thread(target=threadable_fun, daemon=True, args=args)
+        threads.append(t)
+        t.start()
+
+    for t in threads:
+        t.join()
+
+    if records[0] == (None,):
+        return
+    else:
+        return [
+            torch.cat([x[k] for x in records], dim=0) for k in range(len(records[0]))
+        ]
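
[multithread_execution, added above, runs fun once per argument tuple, each in its own thread, and concatenates tensor results along dim 0; with a single argument tuple it calls fun directly, and a (None,) record yields None. A hypothetical usage, with a made-up worker that returns a CPU tensor so the final torch.cat is valid across devices:

    # Illustration only; make_batch is a made-up worker, not a name from main.py.
    def make_batch(nb, device):
        x = torch.zeros(nb, 4, device=device)  # some per-GPU work
        return x.cpu()  # results must share a device for torch.cat

    # One thread per GPU. Returns [tensor of 128 * len(gpus) rows] when
    # there are several argument tuples, or the worker's raw result when
    # there is only one.
    quizzes = multithread_execution(make_batch, [(128, gpu) for gpu in gpus])
]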
+
+
+######################################################################
+
+
+def save_models(models, suffix=""):
+    if suffix != "":
+        suffix = "_" + suffix
+
+    for model in models:
+        filename = f"ae_{model.id:03d}{suffix}.pth"
+        torch.save(
+            {
+                "state_dict": model.state_dict(),
+                "optimizer_state_dict": model.optimizer.state_dict(),
+                "test_accuracy": model.test_accuracy,
+            },
+            os.path.join(args.result_dir, filename),
+        )
+
+    log_string(f"wrote ae_*{suffix}.pth")
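
[save_models writes one checkpoint per model, holding the model and optimizer state plus the running test accuracy. A minimal loading counterpart, assuming the same filename scheme and checkpoint keys; a sketch, not the resume code used later in main.py:

    # Sketch matching save_models' format; load_models is a hypothetical helper.
    def load_models(models, suffix=""):
        if suffix != "":
            suffix = "_" + suffix
        for model in models:
            filename = f"ae_{model.id:03d}{suffix}.pth"
            d = torch.load(os.path.join(args.result_dir, filename), map_location="cpu")
            model.load_state_dict(d["state_dict"])
            model.optimizer.load_state_dict(d["optimizer_state_dict"])
            model.test_accuracy = d["test_accuracy"]
]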
+
+
+######################################################################
+
+
 def save_quiz_image(models, c_quizzes, filename, local_device=main_device):
     c_quizzes = c_quizzes.to(local_device)
 
@@ -793,6 +807,49 @@ def save_quiz_image(models, c_quizzes, filename, local_device=main_device):
 
 ######################################################################
 
+problem = grids.Grids(
+    max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
+    chunk_size=100,
+    nb_threads=args.nb_threads,
+    tasks=args.grids_world_tasks,
+)
+
+if not args.resume:
+    problem.save_some_examples(args.result_dir)
+
+
+log_string(f"main_device {main_device} gpus {[ str(g) for g in gpus]}")
+
+vocabulary_size = problem.vocabulary_size()
+
+log_string(f"vocabulary_size {vocabulary_size}")
+
+######################################################################
+
+models = []
+
+for i in range(args.nb_models):
+    # model = attae.FunctionalAttentionAE(
+    model = attae.AttentionAE(
+        vocabulary_size=vocabulary_size * 2,
+        dim_model=args.dim_model,
+        dim_keys=args.dim_keys,
+        dim_hidden=args.dim_hidden,
+        nb_heads=args.nb_heads,
+        nb_blocks=args.nb_blocks,
+        dropout=args.dropout,
+    )
+
+    # model = torch.compile(model)
+
+    model.id = i
+    model.test_accuracy = 0.0
+    model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+
+    models.append(model)
+
+######################################################################
+
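
[Moving the construction of problem and models below pure_noise, quiz_set, and the other functions that reference them is safe because Python resolves global names at call time, not at definition time; the functions only need the globals to exist once they are actually called, after this point in the file. A two-line demonstration of the rule:

    # Globals are looked up when a function runs, not when it is defined.
    def f():
        return g  # fine: g only has to exist by the time f() is called

    g = 42
    print(f())  # 42
]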
 current_epoch = 0
 
 if args.resume:
@@ -837,62 +894,6 @@ c_quizzes = None
 
 ######################################################################
 
-
-def multithread_execution(fun, arguments):
-    # Single instance, no thread
-    if len(arguments) == 1:
-        return fun(*(arguments[0]))
-
-    records, threads = [], []
-
-    def threadable_fun(*args):
-        r = fun(*args)
-        if type(r) is not tuple:
-            r = (r,)
-        records.append(r)
-
-    for args in arguments:
-        # To get a different sequence between threads
-        # log_string(f"dummy_rand {torch.rand(1)}")
-        torch.rand(1)
-        t = threading.Thread(target=threadable_fun, daemon=True, args=args)
-        threads.append(t)
-        t.start()
-
-    for t in threads:
-        t.join()
-
-    if records[0] == (None,):
-        return
-    else:
-        return [
-            torch.cat([x[k] for x in records], dim=0) for k in range(len(records[0]))
-        ]
-
-
-######################################################################
-
-
-def save_models(models, suffix=""):
-    if suffix != "":
-        suffix = "_" + suffix
-
-    for model in models:
-        filename = f"ae_{model.id:03d}{suffix}.pth"
-        torch.save(
-            {
-                "state_dict": model.state_dict(),
-                "optimizer_state_dict": model.optimizer.state_dict(),
-                "test_accuracy": model.test_accuracy,
-            },
-            os.path.join(args.result_dir, filename),
-        )
-
-    log_string(f"wrote ae_*{suffix}.pth")
-
-
-######################################################################
-
 for n_epoch in range(current_epoch, args.nb_epochs):
     start_time = time.perf_counter()