Update.
authorFrançois Fleuret <francois@fleuret.org>
Wed, 18 Sep 2024 15:58:38 +0000 (17:58 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Wed, 18 Sep 2024 15:58:38 +0000 (17:58 +0200)
main.py

diff --git a/main.py b/main.py
index a7f9c9e..c4ecc49 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -497,7 +497,7 @@ def ae_generate(model, nb, local_device=main_device):
     all_changed = torch.full((all_input.size(0),), True, device=all_input.device)
 
     for it in range(args.diffusion_nb_iterations):
-        log_string(f"nb_changed {all_changed.long().sum().item()}")
+        log_string(f"nb_changed {all_changed.long().sum().item()}")
 
         if not all_changed.any():
             break
@@ -892,9 +892,6 @@ if args.quizzes is not None:
 
 c_quizzes = None
 
-time_c_quizzes = 0
-time_train = 0
-
 ######################################################################
 
 
@@ -980,36 +977,33 @@ for n_epoch in range(current_epoch, args.nb_epochs):
         nb_gpus = len(gpus)
         nb_c_quizzes_to_generate = (args.nb_c_quizzes + nb_gpus - 1) // nb_gpus
 
-        (c_quizzes,) = multithread_execution(
+        (new_c_quizzes,) = multithread_execution(
             generate_c_quizzes,
             [(models, nb_c_quizzes_to_generate, gpu) for gpu in gpus],
         )
 
         save_quiz_image(
             models,
-            c_quizzes[:256],
+            new_c_quizzes[:256],
             f"culture_c_quiz_{n_epoch:04d}.png",
             solvable_only=False,
         )
 
         save_quiz_image(
             models,
-            c_quizzes[:256],
+            new_c_quizzes[:256],
             f"culture_c_quiz_{n_epoch:04d}_solvable.png",
             solvable_only=True,
         )
 
-        u = c_quizzes.reshape(c_quizzes.size(0), 4, -1)[:, :, 1:]
-        i = (u[:, 2] != u[:, 3]).long().sum(dim=1).sort(descending=True).indices
+        log_string(f"generated_c_quizzes {new_c_quizzes.size()=}")
 
-        save_quiz_image(
-            models,
-            c_quizzes[i][:256],
-            f"culture_c_quiz_{n_epoch:04d}_solvable_high_delta.png",
-            solvable_only=True,
+        c_quizzes = (
+            new_c_quizzes
+            if c_quizzes is None
+            else torch.cat([c_quizzes, new_c_quizzes])
         )
-
-        log_string(f"generated_c_quizzes {c_quizzes.size()=}")
+        c_quizzes = c_quizzes[-args.nb_train_samples :]
 
         for model in models:
             model.test_accuracy = 0