Update.
authorFrançois Fleuret <francois@fleuret.org>
Wed, 18 Sep 2024 07:27:44 +0000 (09:27 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Wed, 18 Sep 2024 07:27:44 +0000 (09:27 +0200)
main.py

diff --git a/main.py b/main.py
index 22854a9..380be1e 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -67,7 +67,7 @@ parser.add_argument("--nb_train_alien_samples", type=int, default=0)
 
 parser.add_argument("--nb_test_alien_samples", type=int, default=0)
 
-parser.add_argument("--nb_c_quizzes", type=int, default=10000)
+parser.add_argument("--nb_c_quizzes", type=int, default=2500)
 
 parser.add_argument("--c_quiz_multiplier", type=int, default=1)
 
@@ -710,10 +710,10 @@ def generate_c_quizzes(models, nb, local_device=main_device):
             nb_wrong >= args.nb_have_to_be_wrong
         )
 
-        nb_validated += to_keep.long().sum()
+        nb_validated += to_keep.long().sum().item()
         record.append(c_quizzes[to_keep])
 
-        log_string(f"generate_c_quizzes {nb_validated}")
+        log_string(f"generate_c_quizzes {nb_validated}")
 
         #####################
 
@@ -722,8 +722,8 @@ def generate_c_quizzes(models, nb, local_device=main_device):
         if last_log < 0 or duration > last_log + 10:
             last_log = duration
             if nb_validated > 0:
-                if nb_validated < wanted_nb:
-                    d = (wanted_nb - nb_validated) * duration / nb_validated
+                if nb_validated < nb:
+                    d = (nb - nb_validated) * duration / nb_validated
                     e = (
                         datetime.datetime.now() + datetime.timedelta(seconds=d)
                     ).strftime("%a %H:%M")
@@ -740,7 +740,7 @@ def generate_c_quizzes(models, nb, local_device=main_device):
 
     duration = time.perf_counter() - start_time
 
-    log_string(f"generate_c_quizz_speed {int(3600 * wanted_nb / duration)}/h")
+    log_string(f"generate_c_quizz_speed {int(3600 * nb / duration)}/h")
 
     return torch.cat(record)