From: François Fleuret Date: Tue, 27 Aug 2024 14:59:16 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=59e4d17787d71b72c823057ecd8a17dc1ef0c07c;p=culture.git Update. --- diff --git a/main.py b/main.py index 2e8ec43..0ecd492 100755 --- a/main.py +++ b/main.py @@ -1133,7 +1133,66 @@ for i in range(args.nb_models): models.append(model) +###################################################################### + +current_epoch = 0 + +if args.resume: + for model in models: + filename = f"ae_{model.id:03d}.pth" + + try: + d = torch.load(os.path.join(args.result_dir, filename)) + model.load_state_dict(d["state_dict"]) + model.optimizer.load_state_dict(d["optimizer_state_dict"]) + model.test_accuracy = d["test_accuracy"] + # model.gen_test_accuracy = d["gen_test_accuracy"] + # model.gen_state_dict = d["gen_state_dict"] + # model.train_c_quiz_bags = d["train_c_quiz_bags"] + # model.test_c_quiz_bags = d["test_c_quiz_bags"] + log_string(f"successfully loaded {filename}") + except FileNotFoundError: + log_string(f"cannot find {filename}") + pass + + try: + filename = "state.pth" + state = torch.load(os.path.join(args.result_dir, filename)) + log_string(f"successfully loaded {filename}") + current_epoch = state["current_epoch"] + # total_time_generating_c_quizzes = state["total_time_generating_c_quizzes"] + # total_time_training_models = state["total_time_training_models"] + # common_c_quiz_bags = state["common_c_quiz_bags"] + except FileNotFoundError: + log_string(f"cannot find {filename}") + pass + +###################################################################### + +nb_parameters = sum(p.numel() for p in models[0].parameters()) +log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)") + + +###################################################################### + for n_epoch in range(args.nb_epochs): + state = { + "current_epoch": n_epoch, + # "total_time_generating_c_quizzes": total_time_generating_c_quizzes, + # "total_time_training_models": total_time_training_models, + # "common_c_quiz_bags": common_c_quiz_bags, + } + filename = "state.pth" + torch.save(state, os.path.join(args.result_dir, filename)) + log_string(f"wrote {filename}") + + log_string(f"--- epoch {n_epoch} ----------------------------------------") + + cta = " ".join([f"{float(m.test_accuracy):.04f}" for m in models]) + log_string(f"current_test_accuracies {cta}") + + # -------------------------------------------------------------------- + ranked_models = sorted(models, key=lambda m: float(m.test_accuracy)) weakest_models = ranked_models[: len(gpus)] @@ -1155,6 +1214,24 @@ for n_epoch in range(args.nb_epochs): for t in threads: t.join() + # -------------------------------------------------------------------- + + for model in models: + filename = f"ae_{model.id:03d}.pth" + torch.save( + { + "state_dict": model.state_dict(), + "optimizer_state_dict": model.optimizer.state_dict(), + "test_accuracy": model.test_accuracy, + # "gen_test_accuracy": model.gen_test_accuracy, + # "gen_state_dict": model.gen_state_dict, + # "train_c_quiz_bags": model.train_c_quiz_bags, + # "test_c_quiz_bags": model.test_c_quiz_bags, + }, + os.path.join(args.result_dir, filename), + ) + log_string(f"wrote {filename}") + ###################################################################### @@ -1213,14 +1290,14 @@ models = create_models() ###################################################################### -current_epoch = 0 - # We balance the computing time between training the models and # generating c_quizzes total_time_generating_c_quizzes = 0 total_time_training_models = 0 +current_epoch = 0 + if args.resume: for model in models: filename = f"gpt_{model.id:03d}.pth"