From: François Fleuret
Date: Thu, 8 Aug 2024 18:01:16 +0000 (+0200)
Subject: Update.
X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=a2b35c224e66f7e17612c0e8de2462c9e998e051;p=culture.git

Update.
---

diff --git a/main.py b/main.py
index 8bca425..3196fbd 100755
--- a/main.py
+++ b/main.py
@@ -363,10 +363,29 @@ log_string(f"vocabulary_size {vocabulary_size}")
 ######################################################################


+def optimizer_to(optim, device):
+    for param in optim.state.values():
+        # Not sure there are any global tensors in the state dict
+        if isinstance(param, torch.Tensor):
+            param.data = param.data.to(device)
+            if param._grad is not None:
+                param._grad.data = param._grad.data.to(device)
+        elif isinstance(param, dict):
+            for subparam in param.values():
+                if isinstance(subparam, torch.Tensor):
+                    subparam.data = subparam.data.to(device)
+                    if subparam._grad is not None:
+                        subparam._grad.data = subparam._grad.data.to(device)
+
+
+######################################################################
+
+
 def run_tests(model, quiz_machine, local_device=main_device):
     with torch.autograd.no_grad():
         model.to(local_device).eval()
-        model.optimizer.eval()
+        if args.schedule_free:
+            model.optimizer.eval()

         nb_test_samples, acc_test_loss = 0, 0.0
         nb_samples_accumulated = 0
@@ -398,7 +417,10 @@ def run_tests(model, quiz_machine, local_device=main_device):

 def one_epoch(model, quiz_machine, local_device=main_device):
     model.to(local_device).train()
-    model.optimizer.train()
+    optimizer_to(model.optimizer, local_device)
+
+    if args.schedule_free:
+        model.optimizer.train()

     nb_train_samples, acc_train_loss = 0, 0.0

@@ -454,6 +476,7 @@ def one_epoch(model, quiz_machine, local_device=main_device):
 #        )

     model.to(main_device)
+    optimizer_to(model.optimizer, main_device)


 ######################################################################
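
Note (not part of the commit): the optimizer_to() helper added above addresses a standard PyTorch pitfall: model.to(device) moves the parameters and their gradients, but not the optimizer state (e.g. Adam's exp_avg / exp_avg_sq buffers), which stays on whatever device the first step() ran on. The sketch below is a minimal standalone reproduction of the situation and of a simplified variant of the fix; the names net and opt are hypothetical, and the loop rebinds the state-dict entries instead of mutating .data in place as the commit does.

import torch

net = torch.nn.Linear(4, 4)
opt = torch.optim.Adam(net.parameters())

# One backward/step on the CPU populates opt.state with per-parameter
# tensors ("step", "exp_avg", "exp_avg_sq"), all living on the CPU.
net(torch.randn(2, 4)).sum().backward()
opt.step()

device = "cuda" if torch.cuda.is_available() else "cpu"

# Module.to() moves parameters in place, so opt still references the
# same Parameter objects and opt.state stays keyed correctly.
net.to(device)

# Simplified variant of optimizer_to(): without this, the next
# opt.step() would mix CPU state tensors with CUDA gradients. "step"
# is left on the CPU, where recent PyTorch versions expect it when
# capturable=False.
for state in opt.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor) and k != "step":
            state[k] = v.to(device)

net.zero_grad()
net(torch.randn(2, 4, device=device)).sum().backward()
opt.step()  # parameters, gradients and optimizer state now agree on device

The commit's version instead assigns to .data on the existing tensors, which keeps any outside references to the state tensors valid; either way the per-model optimizer travels with the model, which is what the optimizer_to() calls in one_epoch() make explicit.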