Update.
authorFrançois Fleuret <francois@fleuret.org>
Tue, 17 Sep 2024 06:27:36 +0000 (08:27 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 17 Sep 2024 06:27:36 +0000 (08:27 +0200)
main.py

diff --git a/main.py b/main.py
index 899a099..e921ccd 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -804,14 +804,14 @@ def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True):
 def one_train_test_epoch(model, n_epoch, c_quizzes, local_device=main_device):
     # train
 
-    one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True)
-    one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=False)
+    one_epoch(model, n_epoch, c_quizzes, local_device=local_device, train=True)
+    one_epoch(model, n_epoch, c_quizzes, local_device=local_device, train=False)
 
     # predict
 
     quizzes = quiz_machine.quiz_set(150, c_quizzes, args.c_quiz_multiplier)
     input, targets, mask = batch_prediction(quizzes.to(local_device))
-    result = predict(model, input, targets, mask).to("cpu")
+    result = predict(model, input, targets, mask, local_device=local_device).to("cpu")
     mask = mask.to("cpu")
     correct = (quizzes == result).min(dim=1).values.long()
     correct_parts = (2 * correct - 1)[:, None] * mask.reshape(mask.size(0), 4, -1)[
@@ -830,7 +830,7 @@ def one_train_test_epoch(model, n_epoch, c_quizzes, local_device=main_device):
 
     # generate
 
-    result = generate(model, 25).to("cpu")
+    result = generate(model, 25, local_device=local_device).to("cpu")
     quiz_machine.problem.save_quizzes_as_image(
         args.result_dir,
         f"culture_generation_{n_epoch}_{model.id}.png",