From: François Fleuret Date: Mon, 19 Aug 2024 21:13:08 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=5099d1d87e09a24859b70240477bbbcb6ac2f2a4;p=culture.git Update. --- diff --git a/main.py b/main.py index 901e91c..19c8394 100755 --- a/main.py +++ b/main.py @@ -457,8 +457,8 @@ def model_modifier_cold(model): c_quizzes_procedure = [ # (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_modifier_hot), - # (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), model_modifier_cold), - (("A", "f_A", "B", "f_B"), (1, 1, 1, 1), model_modifier_hot), + (("f_B", "f_A", "A", "B"), (1, 1, 1, 1), model_modifier_hot), + # (("A", "f_A", "B", "f_B"), (1, 1, 1, 1), model_modifier_hot), # (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), model_modifier_cold), ] @@ -562,7 +562,6 @@ def create_c_quizzes( train_c_quiz_bags, nb_for_test, test_c_quiz_bags, - local_device=main_device, ): nb_validated, nb_to_validate = 0, nb_for_train + nb_for_test nb_generated, nb_to_generate_per_iteration = 0, nb_to_validate @@ -871,7 +870,7 @@ for n_epoch in range(current_epoch, args.nb_epochs): test_c_quiz_bags=test_c_quiz_bags, ) - c_quizzes = train_c_quiz_bags[-128:] + c_quizzes = train_c_quiz_bags[-1][:128] l = [model_proba_solutions(model, c_quizzes) for model in models] probas = torch.cat([x[:, None] for x in l], dim=1) comments = []