Update.
authorFrançois Fleuret <francois@fleuret.org>
Wed, 18 Sep 2024 08:21:02 +0000 (10:21 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Wed, 18 Sep 2024 08:21:02 +0000 (10:21 +0200)
main.py

diff --git a/main.py b/main.py
index 84224e9..86a3ae9 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -670,6 +670,26 @@ for i in range(args.nb_models):
 ######################################################################
 
 
+def evaluate_quizzes(c_quizzes, models, local_device):
+    nb_correct, nb_wrong = 0, 0
+
+    for model in models:
+        model = copy.deepcopy(model).to(local_device).eval()
+        result = predict_full(model, c_quizzes, local_device=local_device)
+        nb_mistakes = (result != c_quizzes).long().sum(dim=1)
+        nb_correct += (nb_mistakes == 0).long()
+        nb_wrong += nb_mistakes >= args.nb_mistakes_to_be_wrong
+
+    to_keep = (nb_correct >= args.nb_have_to_be_correct) & (
+        nb_wrong >= args.nb_have_to_be_wrong
+    )
+
+    return to_keep, nb_correct, nb_wrong
+
+
+######################################################################
+
+
 def generate_c_quizzes(models, nb, local_device=main_device):
     record = []
     nb_validated = 0
@@ -694,17 +714,8 @@ def generate_c_quizzes(models, nb, local_device=main_device):
         # Select the ones that are solved properly by some models and
         # not understood by others
 
-        nb_correct, nb_wrong = 0, 0
-
-        for i, model in enumerate(models):
-            model = copy.deepcopy(model).to(local_device).eval()
-            result = predict_full(model, c_quizzes, local_device=local_device)
-            nb_mistakes = (result != c_quizzes).long().sum(dim=1)
-            nb_correct += (nb_mistakes == 0).long()
-            nb_wrong += nb_mistakes >= args.nb_mistakes_to_be_wrong
-
-        to_keep = (nb_correct >= args.nb_have_to_be_correct) & (
-            nb_wrong >= args.nb_have_to_be_wrong
+        to_keep, nb_correct, nb_wrong = evaluate_quizzes(
+            c_quizzes, models, local_device
         )
 
         nb_validated += to_keep.long().sum().item()
@@ -743,31 +754,19 @@ def generate_c_quizzes(models, nb, local_device=main_device):
 ######################################################################
 
 
-def save_c_quizzes_with_scores(models, c_quizzes, filename, solvable_only=False):
-    l = []
-
-    c_quizzes = c_quizzes.to(main_device)
-
-    with torch.autograd.no_grad():
-        to_keep, nb_correct, nb_wrong, record_wrong = quiz_validation(
-            models,
-            c_quizzes,
-            main_device,
-            nb_have_to_be_correct=args.nb_have_to_be_correct,
-            nb_have_to_be_wrong=0,
-            nb_mistakes_to_be_wrong=args.nb_mistakes_to_be_wrong,
-            nb_hints=None,
-        )
+def save_quiz_image(
+    models, c_quizzes, filename, solvable_only=False, local_device=main_device
+):
+    c_quizzes = c_quizzes.to(local_device)
 
-        if solvable_only:
-            c_quizzes = c_quizzes[to_keep]
-            nb_correct = nb_correct[to_keep]
-            nb_wrong = nb_wrong[to_keep]
+    to_keep, nb_correct, nb_wrong = evaluate_quizzes(c_quizzes, models, local_device)
 
-        comments = []
+    if solvable_only:
+        c_quizzes = c_quizzes[to_keep]
+        nb_correct = nb_correct[to_keep]
+        nb_wrong = nb_wrong[to_keep]
 
-        for c, w in zip(nb_correct, nb_wrong):
-            comments.append(f"nb_correct {c} nb_wrong {w}")
+    comments = [f"nb_correct {c} nb_wrong {w}" for c, w in zip(nb_correct, nb_wrong)]
 
     quiz_machine.problem.save_quizzes_as_image(
         args.result_dir,
@@ -880,7 +879,10 @@ def multithread_execution(fun, arguments):
     records, threads = [], []
 
     def threadable_fun(*args):
-        records.append(fun(*args))
+        r = fun(*args)
+        if type(r) is not tuple:
+            r = (r,)
+        records.append(r)
 
     for args in arguments:
         # To get a different sequence between threads
@@ -952,19 +954,19 @@ for n_epoch in range(current_epoch, args.nb_epochs):
         nb_gpus = len(gpus)
         nb_c_quizzes_to_generate = (args.nb_c_quizzes + nb_gpus - 1) // nb_gpus
 
-        c_quizzes = multithread_execution(
+        (c_quizzes,) = multithread_execution(
             generate_c_quizzes,
             [(models, nb_c_quizzes_to_generate, gpu) for gpu in gpus],
         )
 
-        save_c_quizzes_with_scores(
+        save_quiz_image(
             models,
             c_quizzes[:256],
             f"culture_c_quiz_{n_epoch:04d}.png",
             solvable_only=False,
         )
 
-        save_c_quizzes_with_scores(
+        save_quiz_image(
             models,
             c_quizzes[:256],
             f"culture_c_quiz_{n_epoch:04d}_solvable.png",
@@ -974,7 +976,7 @@ for n_epoch in range(current_epoch, args.nb_epochs):
         u = c_quizzes.reshape(c_quizzes.size(0), 4, -1)[:, :, 1:]
         i = (u[:, 2] != u[:, 3]).long().sum(dim=1).sort(descending=True).indices
 
-        save_c_quizzes_with_scores(
+        save_quiz_image(
             models,
             c_quizzes[i][:256],
             f"culture_c_quiz_{n_epoch:04d}_solvable_high_delta.png",