Update.
authorFrançois Fleuret <francois@fleuret.org>
Wed, 18 Sep 2024 08:01:41 +0000 (10:01 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Wed, 18 Sep 2024 08:01:41 +0000 (10:01 +0200)
main.py

diff --git a/main.py b/main.py
index 44035f9..84224e9 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -656,9 +656,6 @@ for i in range(args.nb_models):
         dropout=args.dropout,
     ).to(main_device)
 
-    # if i < args.nb_models//2:
-    # model = TokenCat(model, 10)
-
     # model = torch.compile(model)
 
     model.id = i
@@ -740,7 +737,7 @@ def generate_c_quizzes(models, nb, local_device=main_device):
 
     log_string(f"generate_c_quizz_speed {int(3600 * nb / duration)}/h")
 
-    return torch.cat(record)
+    return torch.cat(record).to("cpu")
 
 
 ######################################################################
@@ -876,6 +873,7 @@ time_train = 0
 
 
 def multithread_execution(fun, arguments):
+    # Single instance, no thread
     if len(arguments) == 1:
         return fun(*(arguments[0]))
 
@@ -954,7 +952,7 @@ for n_epoch in range(current_epoch, args.nb_epochs):
         nb_gpus = len(gpus)
         nb_c_quizzes_to_generate = (args.nb_c_quizzes + nb_gpus - 1) // nb_gpus
 
-        c_quizzes, agreements = multithread_execution(
+        c_quizzes = multithread_execution(
             generate_c_quizzes,
             [(models, nb_c_quizzes_to_generate, gpu) for gpu in gpus],
         )