From: Francois Fleuret Date: Sun, 7 Aug 2022 19:50:36 +0000 (+0200) Subject: Added the rng state in the checkpoint. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=f3a734b6c522b2be0004a1b8bc2fe2eab2a90263;p=mygpt.git Added the rng state in the checkpoint. --- diff --git a/main.py b/main.py index b01ea0a..d4a8cfb 100755 --- a/main.py +++ b/main.py @@ -430,17 +430,6 @@ log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)') ###################################################################### -if args.optim == 'sgd': - optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate) -elif args.optim == 'adam': - optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate) -elif args.optim == 'adamw': - optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate) -else: - raise ValueError(f'Unknown optimizer {args.optim}.') - -###################################################################### - nb_epochs_finished = 0 if args.no_checkpoint: @@ -448,10 +437,12 @@ if args.no_checkpoint: else: try: - checkpoint = torch.load(args.checkpoint_name, map_location = device) + checkpoint = torch.load(args.checkpoint_name) nb_epochs_finished = checkpoint['nb_epochs_finished'] model.load_state_dict(checkpoint['model_state']) - optimizer.load_state_dict(checkpoint['optimizer_state']) + torch.set_rng_state(checkpoint['rng_state']) + if torch.cuda.is_available(): + torch.cuda.set_rng_state(checkpoint['cuda_rng_state']) log_string(f'checkpoint loaded with {nb_epochs_finished} epochs finished.') except FileNotFoundError: @@ -472,7 +463,16 @@ token_probas = token_count / token_count.sum() entropy = -torch.xlogy(token_probas, token_probas).sum() train_set_perplexity = math.exp(entropy) -for k in range(nb_epochs_finished, nb_epochs): +for n_epoch in range(nb_epochs_finished, nb_epochs): + + if args.optim == 'sgd': + optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate) + elif args.optim == 'adam': + optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate) + elif args.optim == 'adamw': + optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate) + else: + raise ValueError(f'Unknown optimizer {args.optim}.') model.train() @@ -505,16 +505,19 @@ for k in range(nb_epochs_finished, nb_epochs): train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples)) test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples)) - log_string(f'perplexity {k} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}') + log_string(f'perplexity {n_epoch} train_set {train_set_perplexity} train_prediction {train_perplexity} test_prediction {test_perplexity}') - task.produce_results(k, model) + task.produce_results(n_epoch, model) checkpoint = { - 'nb_epochs_finished': k + 1, + 'nb_epochs_finished': n_epoch + 1, 'model_state': model.state_dict(), - 'optimizer_state': optimizer.state_dict() + 'rng_state': torch.get_rng_state(), } + if torch.cuda.is_available(): + checkpoint['cuda_rng_state'] = torch.cuda.get_rng_state() + torch.save(checkpoint, args.checkpoint_name) ######################################################################