Update.
authorFrançois Fleuret <francois@fleuret.org>
Mon, 19 Aug 2024 21:13:08 +0000 (23:13 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Mon, 19 Aug 2024 21:13:08 +0000 (23:13 +0200)
main.py

diff --git a/main.py b/main.py
index 901e91c..19c8394 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -457,8 +457,8 @@ def model_modifier_cold(model):
 
 c_quizzes_procedure = [
     # (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_modifier_hot),
-    # (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), model_modifier_cold),
-    (("A", "f_A", "B", "f_B"), (1, 1, 1, 1), model_modifier_hot),
+    (("f_B", "f_A", "A", "B"), (1, 1, 1, 1), model_modifier_hot),
+    (("A", "f_A", "B", "f_B"), (1, 1, 1, 1), model_modifier_hot),
     # (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), model_modifier_cold),
 ]
 
@@ -562,7 +562,6 @@ def create_c_quizzes(
     train_c_quiz_bags,
     nb_for_test,
     test_c_quiz_bags,
-    local_device=main_device,
 ):
     nb_validated, nb_to_validate = 0, nb_for_train + nb_for_test
     nb_generated, nb_to_generate_per_iteration = 0, nb_to_validate
@@ -871,7 +870,7 @@ for n_epoch in range(current_epoch, args.nb_epochs):
             test_c_quiz_bags=test_c_quiz_bags,
         )
 
-        c_quizzes = train_c_quiz_bags[-128:]
+        c_quizzes = train_c_quiz_bags[-1][:128]
         l = [model_proba_solutions(model, c_quizzes) for model in models]
         probas = torch.cat([x[:, None] for x in l], dim=1)
         comments = []