Update.
authorFrançois Fleuret <francois@fleuret.org>
Tue, 13 Aug 2024 18:20:21 +0000 (20:20 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 13 Aug 2024 18:20:21 +0000 (20:20 +0200)
main.py

diff --git a/main.py b/main.py
index df29152..20acab3 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -47,15 +47,15 @@ parser.add_argument("--log_command", type=str, default=None)
 
 parser.add_argument("--nb_epochs", type=int, default=10000)
 
-parser.add_argument("--batch_size", type=int, default=None)
+parser.add_argument("--batch_size", type=int, default=25)
 
 parser.add_argument("--physical_batch_size", type=int, default=None)
 
-parser.add_argument("--inference_batch_size", type=int, default=None)
+parser.add_argument("--inference_batch_size", type=int, default=25)
 
-parser.add_argument("--nb_train_samples", type=int, default=None)
+parser.add_argument("--nb_train_samples", type=int, default=40000)
 
-parser.add_argument("--nb_test_samples", type=int, default=None)
+parser.add_argument("--nb_test_samples", type=int, default=1000)
 
 parser.add_argument("--nb_new_c_quizzes_for_train", type=int, default=None)
 
@@ -66,7 +66,7 @@ parser.add_argument("--learning_rate", type=float, default=5e-4)
 parser.add_argument("--schedule_free", action="store_true", default=False)
 
 # ----------------------------------
-parser.add_argument("--model", type=str, default=None)
+parser.add_argument("--model", type=str, default="37M")
 
 parser.add_argument("--dim_model", type=int, default=None)
 
@@ -147,20 +147,6 @@ if args.result_dir is None:
 
 ######################################################################
 
-default_args = {
-    "model": "37M",
-    "batch_size": 25,
-    "inference_batch_size": 25,
-    "nb_train_samples": 40000,
-    "nb_test_samples": 1000,
-}
-
-for k, v in default_args.items():
-    if getattr(args, k) is None:
-        setattr(args, k, v)
-
-######################################################################
-
 default_model_args = {
     "17K": {
         "dim_model": 32,
@@ -209,8 +195,9 @@ else:
 ######################################################################
 
 if args.resume:
-    assert os.path.isdir(args.result_dir)
-
+    if not os.path.isdir(args.result_dir):
+        print("Trying to resume with a non-existing result dir {args.result_dir}.")
+        exit(1)
 else:
     try:
         os.mkdir(args.result_dir)
@@ -287,6 +274,8 @@ else:
 assert args.nb_train_samples % args.batch_size == 0
 assert args.nb_test_samples % args.batch_size == 0
 
+######################################################################
+
 problem = grids.Grids(
     max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
     chunk_size=100,
@@ -316,6 +305,8 @@ log_string(f"vocabulary_size {vocabulary_size}")
 
 ######################################################################
 
+# If we need to move an optimizer to a different device
+
 
 def optimizer_to(optim, device):
     for param in optim.state.values():