From 276d3ec2f05b3e7061cb8389eb528719084a3905 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 10 Sep 2024 09:05:29 +0200 Subject: [PATCH] Update. --- main.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/main.py b/main.py index 97d37ce..b7050df 100755 --- a/main.py +++ b/main.py @@ -1230,7 +1230,9 @@ if args.resume: filename = f"ae_{model.id:03d}.pth" try: - d = torch.load(os.path.join(args.result_dir, filename)) + d = torch.load( + os.path.join(args.result_dir, filename), map_location=main_device + ) model.load_state_dict(d["state_dict"]) model.optimizer.load_state_dict(d["optimizer_state_dict"]) model.test_accuracy = d["test_accuracy"] @@ -1378,9 +1380,7 @@ for n_epoch in range(current_epoch, args.nb_epochs): else: records.append( - generate_ae_c_quizzes( - models, nb_c_quizzes_to_generate, records, gpus[0] - ) + generate_ae_c_quizzes(models, nb_c_quizzes_to_generate, gpus[0]) ) time_c_quizzes = int(time.perf_counter() - start_time) -- 2.39.5