Update.
authorFrançois Fleuret <francois@fleuret.org>
Wed, 18 Sep 2024 07:21:53 +0000 (09:21 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Wed, 18 Sep 2024 07:21:53 +0000 (09:21 +0200)
main.py

diff --git a/main.py b/main.py
index 05bb108..22854a9 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -589,7 +589,8 @@ def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device):
     one_epoch(model, n_epoch, c_quizzes, local_device=local_device, train=True)
     one_epoch(model, n_epoch, c_quizzes, local_device=local_device, train=False)
 
-    #!!!!!!!!!!!!!!!!!!!!!!!!!
+    # Save some original world quizzes and the full prediction (the four grids)
+
     quizzes = quiz_machine.quiz_set(25, c_quizzes, args.c_quiz_multiplier).to(
         local_device
     )
@@ -600,11 +601,12 @@ def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device):
     quiz_machine.problem.save_quizzes_as_image(
         args.result_dir, f"test_{n_epoch}_{model.id}_predict_full.png", quizzes=result
     )
-    #!!!!!!!!!!!!!!!!!!!!!!!!!
 
-    # predict
+    # Save some images of the prediction results (one grid at random)
 
-    quizzes = quiz_machine.quiz_set(150, c_quizzes, args.c_quiz_multiplier)
+    quizzes = quiz_machine.quiz_set(
+        args.nb_test_samples, c_quizzes, args.c_quiz_multiplier
+    )
     imt_set = batch_prediction_imt(quizzes.to(local_device))
     result = predict(model, imt_set, local_device=local_device).to("cpu")
     masks = imt_set[:, 1].to("cpu")
@@ -623,6 +625,8 @@ def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device):
         correct_parts=correct_parts[:128],
     )
 
+    # Compute the test accuracy
+
     nb_correct, nb_total = correct.sum(), quizzes.size(0)
     model.test_accuracy = nb_correct / nb_total
 
@@ -630,7 +634,7 @@ def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device):
         f"test_accuracy {n_epoch} model {model.id} nb_correct {nb_correct} / {nb_total} ({model.test_accuracy:.02f}%)"
     )
 
-    # generate
+    # Save some images of the ex nihilo generation of the four grids
 
     result = generate(model, 150, local_device=local_device).to("cpu")
     quiz_machine.problem.save_quizzes_as_image(