From: François Fleuret
Date: Fri, 16 Aug 2024 17:53:55 +0000 (+0200)
Subject: Update.
X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=df1da600ffc7cdf5f8bdfbfcc098515fe57e8e2a;p=culture.git

Update.
---

diff --git a/main.py b/main.py
index f1fb834..dc3fd1a 100755
--- a/main.py
+++ b/main.py
@@ -945,8 +945,8 @@ if args.resume:
         state = torch.load(os.path.join(args.result_dir, filename))
         log_string(f"successfully loaded {filename}")
         current_epoch = state["current_epoch"]
-        total_time_generating_c_quizzes = state["total_time_generating_c_quizzes"]
-        total_time_training_models = state["total_time_training_models"]
+        # total_time_generating_c_quizzes = state["total_time_generating_c_quizzes"]
+        # total_time_training_models = state["total_time_training_models"]
     except FileNotFoundError:
         log_string(f"cannot find {filename}")
         pass
@@ -1161,7 +1161,7 @@ for n_epoch in range(current_epoch, args.nb_epochs):
 
     if (
         min([m.best_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes
-        and total_time_training_models > total_time_generating_c_quizzes
+        and total_time_training_models >= total_time_generating_c_quizzes
     ):
         for model in models:
             model.current_dict = copy.deepcopy(model.state_dict())
@@ -1183,7 +1183,7 @@ for n_epoch in range(current_epoch, args.nb_epochs):
     ##################################################
     # Select, improve, and eval the worst model(s)
 
-    if total_time_training_models <= total_time_generating_c_quizzes:
+    if total_time_training_models < total_time_generating_c_quizzes:
         ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
 
         weakest_models = ranked_models[: len(gpus)]
@@ -1191,6 +1191,7 @@ for n_epoch in range(current_epoch, args.nb_epochs):
         threads = []
 
         start_time = time.perf_counter()
+
         for gpu, model in zip(gpus, weakest_models):
             log_string(f"training model {model.id}")
 