From 5099d1d87e09a24859b70240477bbbcb6ac2f2a4 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 19 Aug 2024 23:13:08 +0200 Subject: [PATCH] Update. --- main.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/main.py b/main.py index 901e91c..19c8394 100755 --- a/main.py +++ b/main.py @@ -457,8 +457,8 @@ def model_modifier_cold(model): c_quizzes_procedure = [ # (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_modifier_hot), - # (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), model_modifier_cold), - (("A", "f_A", "B", "f_B"), (1, 1, 1, 1), model_modifier_hot), + (("f_B", "f_A", "A", "B"), (1, 1, 1, 1), model_modifier_hot), + # (("A", "f_A", "B", "f_B"), (1, 1, 1, 1), model_modifier_hot), # (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), model_modifier_cold), ] @@ -562,7 +562,6 @@ def create_c_quizzes( train_c_quiz_bags, nb_for_test, test_c_quiz_bags, - local_device=main_device, ): nb_validated, nb_to_validate = 0, nb_for_train + nb_for_test nb_generated, nb_to_generate_per_iteration = 0, nb_to_validate @@ -871,7 +870,7 @@ for n_epoch in range(current_epoch, args.nb_epochs): test_c_quiz_bags=test_c_quiz_bags, ) - c_quizzes = train_c_quiz_bags[-128:] + c_quizzes = train_c_quiz_bags[-1][:128] l = [model_proba_solutions(model, c_quizzes) for model in models] probas = torch.cat([x[:, None] for x in l], dim=1) comments = [] -- 2.39.5