parser.add_argument("--seed", type=int, default=0)
+parser.add_argument("--resume", action="store_true", default=False)
+
parser.add_argument("--max_percents_of_test_in_train", type=int, default=-1)
########################################
######################################################################
-try:
- os.mkdir(args.result_dir)
-except FileExistsError:
- print(f"result directory {args.result_dir} already exists")
- exit(1)
+if args.resume:
+ assert os.path.isdir(args.result_dir)
+
+else:
+ try:
+ os.mkdir(args.result_dir)
+ except FileExistsError:
+ print(f"result directory {args.result_dir} already exists")
+ exit(1)
log_file = open(os.path.join(args.result_dir, args.log_filename), "a")
quiz_machine.store_c_quizzes(new_c_quizzes[:nb_for_train], for_train=True)
quiz_machine.store_c_quizzes(new_c_quizzes[nb_for_train:], for_train=False)
- # save a bunch of images to investigate what quizzes with a
- # certain nb of correct predictions look like
+ # save images
q = new_c_quizzes[:72]
######################################################################
-nb_loaded_models = 0
-
models = []
for k in range(args.nb_gpts):
model.test_w_quizzes = quiz_machine.generate_token_sequences(args.nb_test_samples)
quiz_machine.reverse_random_half_in_place(model.test_w_quizzes)
- filename = f"gpt_{model.id:03d}.pth"
+ models.append(model)
+
+######################################################################
+if args.resume:
try:
- model.load_state_dict(torch.load(os.path.join(args.result_dir, filename)))
- log_string(f"model {model.id} successfully loaded from checkpoint.")
- nb_loaded_models += 1
-
- except FileNotFoundError:
- log_string(f"starting model {model.id} from scratch.")
+ for model in models:
+ filename = f"gpt_{model.id:03d}.pth"
+
+ try:
+ model.load_state_dict(
+ torch.load(os.path.join(args.result_dir, filename))
+ )
+            log_string(f"successfully loaded {filename}")
+ except FileNotFoundError:
+            log_string(f"cannot find {filename}")
+ pass
+
+ try:
+ filename = "c_quizzes.pth"
+ quiz_machine.load_c_quizzes(os.path.join(args.result_dir, filename))
+        log_string(f"successfully loaded {filename}")
+ except FileNotFoundError:
+        log_string(f"cannot find {filename}")
+ pass
except:
    log_string(f"error when loading {filename}.")
exit(1)
- models.append(model)
-
-assert nb_loaded_models == 0 or nb_loaded_models == len(models)
+######################################################################
nb_parameters = sum(p.numel() for p in models[0].parameters())
log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
for model in weakest_models:
filename = f"gpt_{model.id:03d}.pth"
torch.save(model.state_dict(), os.path.join(args.result_dir, filename))
+    log_string(f"wrote {filename}")
##################################################
# Replace a fraction of the w_quizzes with fresh ones
nb_for_test=nb_new_c_quizzes_for_test,
)
+ quiz_machine.save_c_quizzes(os.path.join(args.result_dir, "c_quizzes.pth"))
+
######################################################################