From cf01c3df0d1e693bdee02e01c20bb33b3dfffe67 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Fri, 16 Aug 2024 19:52:32 +0200 Subject: [PATCH] Update. --- main.py | 52 ++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 36 insertions(+), 16 deletions(-) diff --git a/main.py b/main.py index 0ca8153..f1fb834 100755 --- a/main.py +++ b/main.py @@ -919,6 +919,8 @@ for k in range(args.nb_gpts): ###################################################################### current_epoch = 0 +total_time_generating_c_quizzes = 0 +total_time_training_models = 0 if args.resume: for model in models: @@ -943,6 +945,8 @@ if args.resume: state = torch.load(os.path.join(args.result_dir, filename)) log_string(f"successfully loaded {filename}") current_epoch = state["current_epoch"] + total_time_generating_c_quizzes = state["total_time_generating_c_quizzes"] + total_time_training_models = state["total_time_training_models"] except FileNotFoundError: log_string(f"cannot find {filename}") pass @@ -973,7 +977,6 @@ if args.dirty_debug: args.nb_new_c_quizzes_for_test = 10 ###################################################################### -###################################################################### class Folder(nn.Module): @@ -1126,7 +1129,11 @@ if args.test == "reject": ###################################################################### for n_epoch in range(current_epoch, args.nb_epochs): - state = {"current_epoch": n_epoch} + state = { + "current_epoch": n_epoch, + "total_time_training_models": total_time_training_models, + "total_time_generating_c_quizzes": total_time_generating_c_quizzes, + } filename = "state.pth" torch.save(state, os.path.join(args.result_dir, filename)) log_string(f"wrote {filename}") @@ -1148,17 +1155,26 @@ for n_epoch in range(current_epoch, args.nb_epochs): model.best_dict = copy.deepcopy(model.state_dict()) model.best_test_accuracy = model.test_accuracy - if min([m.best_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes: + # we restart + if total_time_generating_c_quizzes == 0: + total_time_training_models = 0 + + if ( + min([m.best_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes + and total_time_training_models > total_time_generating_c_quizzes + ): for model in models: model.current_dict = copy.deepcopy(model.state_dict()) model.load_state_dict(model.best_dict) + start_time = time.perf_counter() record_new_c_quizzes( models, quiz_machine, args.nb_new_c_quizzes_for_train, args.nb_new_c_quizzes_for_test, ) + total_time_generating_c_quizzes += time.perf_counter() - start_time # Force one epoch of training for model in models: @@ -1167,29 +1183,33 @@ for n_epoch in range(current_epoch, args.nb_epochs): ################################################## # Select, improve, and eval the worst model(s) - ranked_models = sorted(models, key=lambda m: float(m.test_accuracy)) + if total_time_training_models <= total_time_generating_c_quizzes: + ranked_models = sorted(models, key=lambda m: float(m.test_accuracy)) - weakest_models = ranked_models[: len(gpus)] + weakest_models = ranked_models[: len(gpus)] - threads = [] + threads = [] - for gpu, model in zip(gpus, weakest_models): - log_string(f"training model {model.id}") + start_time = time.perf_counter() + for gpu, model in zip(gpus, weakest_models): + log_string(f"training model {model.id}") - t = threading.Thread( - target=one_epoch, daemon=True, args=(model, quiz_machine, gpu) - ) + t = threading.Thread( + target=one_epoch, daemon=True, args=(model, quiz_machine, gpu) + ) + + threads.append(t) - threads.append(t) + t.start() - t.start() + for t in threads: + t.join() - for t in threads: - t.join() + total_time_training_models += time.perf_counter() - start_time # Save the models to disk - for model in weakest_models: + for model in models: filename = f"gpt_{model.id:03d}.pth" torch.save( { -- 2.39.5