From 8558a1047d53598c79f6c9052e6d0282c93a218e Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 13 Jul 2024 07:21:40 +0200 Subject: [PATCH] Update. --- main.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/main.py b/main.py index a8ceac8..9599cf3 100755 --- a/main.py +++ b/main.py @@ -78,10 +78,6 @@ parser.add_argument("--gpus", type=str, default="all") parser.add_argument("--nb_gpts", type=int, default=5) -parser.add_argument("--min_to_validate", type=int, default=None) - -parser.add_argument("--max_to_validate", type=int, default=None) - parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.975) parser.add_argument("--proba_understands", type=float, default=0.99) @@ -121,12 +117,6 @@ parser.add_argument("--sky_speed", type=int, default=3) args = parser.parse_args() -if args.min_to_validate is None: - args.min_to_validate = args.nb_gpts - 1 - -if args.max_to_validate is None: - args.max_to_validate = args.nb_gpts - 1 - if args.result_dir is None: args.result_dir = f"results_culture" @@ -338,10 +328,10 @@ def run_tests(model, quiz_machine, deterministic_synthesis, local_device=main_de def one_epoch(model, quiz_machine, local_device=main_device): - optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) - model.to(local_device).train() + optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) + nb_train_samples, acc_train_loss = 0, 0.0 for input in quiz_machine.batches(model, split="train"): -- 2.39.5