From: Francois Fleuret
Date: Sat, 2 Jul 2022 19:07:54 +0000 (+0200)
Subject: Update.
X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=56c850d27962d5132ac855da677594272d92161b;p=mygpt.git

Update.
---

diff --git a/main.py b/main.py
index 3bf7587..85cf4cf 100755
--- a/main.py
+++ b/main.py
@@ -69,6 +69,9 @@ parser.add_argument('--dropout',
 parser.add_argument('--synthesis_sampling',
                     type = bool, default = True)
 
+parser.add_argument('--checkpoint_name',
+                    type = str, default = 'checkpoint.pth')
+
 ######################################################################
 
 args = parser.parse_args()
@@ -366,11 +369,11 @@ model = mygpt.MyGPT(
     nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout
 )
 
+model.to(device)
+
 nb_parameters = sum(p.numel() for p in model.parameters())
 log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)')
 
-model.to(device)
-
 ######################################################################
 
 if args.optim == 'sgd':
@@ -382,7 +385,27 @@ elif args.optim == 'adamw':
 else:
     raise ValueError(f'Unknown optimizer {args.optim}.')
 
-for k in range(args.nb_epochs):
+######################################################################
+
+nb_epochs_finished = 0
+
+try:
+    checkpoint = torch.load(args.checkpoint_name, map_location = device)
+    nb_epochs_finished = checkpoint['nb_epochs_finished']
+    model.load_state_dict(checkpoint['model_state'])
+    optimizer.load_state_dict(checkpoint['optimizer_state'])
+    print(f'Checkpoint loaded with {nb_epochs_finished} epochs finished.')
+
+except FileNotFoundError:
+    print('Starting from scratch.')
+
+except:
+    print('Error when loading the checkpoint.')
+    exit(1)
+
+######################################################################
+
+for k in range(nb_epochs_finished, args.nb_epochs):
 
     model.train()
 
@@ -419,4 +442,12 @@ for k in range(args.nb_epochs):
 
     task.produce_results(k, model)
 
+    checkpoint = {
+        'nb_epochs_finished': k + 1,
+        'model_state': model.state_dict(),
+        'optimizer_state': optimizer.state_dict()
+    }
+
+    torch.save(checkpoint, args.checkpoint_name)
+
 ######################################################################