From: François Fleuret
Date: Tue, 13 Aug 2024 18:20:21 +0000 (+0200)
Subject: Update.
X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=b5339d8b12df46f410204e43c6b77cc74f82d954;p=culture.git

Update.
---

diff --git a/main.py b/main.py
index df29152..20acab3 100755
--- a/main.py
+++ b/main.py
@@ -47,15 +47,15 @@ parser.add_argument("--log_command", type=str, default=None)
 
 parser.add_argument("--nb_epochs", type=int, default=10000)
 
-parser.add_argument("--batch_size", type=int, default=None)
+parser.add_argument("--batch_size", type=int, default=25)
 
 parser.add_argument("--physical_batch_size", type=int, default=None)
 
-parser.add_argument("--inference_batch_size", type=int, default=None)
+parser.add_argument("--inference_batch_size", type=int, default=25)
 
-parser.add_argument("--nb_train_samples", type=int, default=None)
+parser.add_argument("--nb_train_samples", type=int, default=40000)
 
-parser.add_argument("--nb_test_samples", type=int, default=None)
+parser.add_argument("--nb_test_samples", type=int, default=1000)
 
 parser.add_argument("--nb_new_c_quizzes_for_train", type=int, default=None)
 
@@ -66,7 +66,7 @@ parser.add_argument("--learning_rate", type=float, default=5e-4)
 parser.add_argument("--schedule_free", action="store_true", default=False)
 
 # ----------------------------------
-parser.add_argument("--model", type=str, default=None)
+parser.add_argument("--model", type=str, default="37M")
 
 parser.add_argument("--dim_model", type=int, default=None)
 
@@ -147,20 +147,6 @@ if args.result_dir is None:
 
 ######################################################################
 
-default_args = {
-    "model": "37M",
-    "batch_size": 25,
-    "inference_batch_size": 25,
-    "nb_train_samples": 40000,
-    "nb_test_samples": 1000,
-}
-
-for k, v in default_args.items():
-    if getattr(args, k) is None:
-        setattr(args, k, v)
-
-######################################################################
-
 default_model_args = {
     "17K": {
         "dim_model": 32,
@@ -209,8 +195,9 @@ else:
 ######################################################################
 
 if args.resume:
-    assert os.path.isdir(args.result_dir)
-
+    if not os.path.isdir(args.result_dir):
+        print(f"Trying to resume with a non-existing result dir {args.result_dir}.")
+        exit(1)
 else:
     try:
         os.mkdir(args.result_dir)
@@ -287,6 +274,8 @@ else:
 assert args.nb_train_samples % args.batch_size == 0
 assert args.nb_test_samples % args.batch_size == 0
 
+######################################################################
+
 problem = grids.Grids(
     max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
     chunk_size=100,
@@ -316,6 +305,8 @@ log_string(f"vocabulary_size {vocabulary_size}")
 
 ######################################################################
 
+# If we need to move an optimizer to a different device
+
 def optimizer_to(optim, device):
     for param in optim.state.values():
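
Note on the first four hunks: the defaults (batch_size=25, inference_batch_size=25, nb_train_samples=40000, nb_test_samples=1000, model="37M") are moved into the parser.add_argument() calls themselves, which is what makes the default_args dict and the getattr/setattr fallback loop removed in the third hunk redundant. A hypothetical illustration of the resulting behavior (these invocations are for illustration only, not taken from the repository):

    # python main.py                  -> args.batch_size == 25, args.model == "37M"
    # python main.py --batch_size 50  -> args.batch_size == 50 (an explicit flag
    #                                    overrides the argparse default)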
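
The diff is truncated right after the definition of optimizer_to(). For context, below is a minimal sketch of how such a helper is commonly completed in PyTorch; this is a standard pattern for the stated purpose (moving an optimizer to a different device), not necessarily the body used in main.py:

    import torch

    def optimizer_to(optim, device):
        # optim.state maps each parameter to a dict of state tensors
        # (e.g. Adam's exp_avg / exp_avg_sq buffers); move each tensor
        # found there to the target device.
        for param_state in optim.state.values():
            for k, v in param_state.items():
                if isinstance(v, torch.Tensor):
                    param_state[k] = v.to(device)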