From 391af504806952f7e9da52849167dfcfe8ab3036 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Fri, 16 Aug 2024 08:41:24 +0200 Subject: [PATCH] Update. --- main.py | 34 +++++++++++++++++++++++++++++----- quiz_machine.py | 13 +++++-------- 2 files changed, 34 insertions(+), 13 deletions(-) diff --git a/main.py b/main.py index dcd76ad..8802752 100755 --- a/main.py +++ b/main.py @@ -473,6 +473,12 @@ def save_additional_results(n_epoch, model, models, c_quizzes_procedure): + quiz_machine.models_logprobas( model, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0) ) + + quiz_machine.models_logprobas( + model, c_quizzes, ("B", "f_B", "A", "f_A"), (0, 0, 0, 1), (0, 0, 1, 0) + ) + + quiz_machine.models_logprobas( + model, c_quizzes, ("f_B", "B", "f_A", "A"), (0, 0, 0, 1), (0, 0, 1, 0) + ) for model in models ] @@ -516,11 +522,20 @@ def save_additional_results(n_epoch, model, models, c_quizzes_procedure): ###################################################################### -def model_proba_solutions(m, quizzes): - l = quiz_machine.models_logprobas( - m, quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0) - ) + quiz_machine.models_logprobas( - m, quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0) +def model_proba_solutions(model, quizzes): + l = ( + quiz_machine.models_logprobas( + model, quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0) + ) + + quiz_machine.models_logprobas( + model, quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0) + ) + + quiz_machine.models_logprobas( + model, quizzes, ("B", "f_B", "A", "f_A"), (0, 0, 0, 1), (0, 0, 1, 0) + ) + + quiz_machine.models_logprobas( + model, quizzes, ("f_B", "B", "f_A", "A"), (0, 0, 0, 1), (0, 0, 1, 0) + ) ) return l.exp() @@ -583,21 +598,27 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test): # the most consistent from a model which is confident for s in range(proba_own_solution.size(0)): + # At least one GPT does not 
understand at all if proba_own_solution[s, :].min() < args.proba_not_understands: dont_get_this_quiz = proba_own_solution[s, :] < args.proba_understands nb_fails = dont_get_this_quiz.long().sum() + # At most max_fail_to_validate do not understand (default 3/5) if nb_fails >= 1 and nb_fails <= args.max_fail_to_validate: for model in models: + # If a GPT does not get that quiz if dont_get_this_quiz[model.id]: assert ( proba_own_solution[s, model.id] < args.proba_understands ) + # Look at its estimate of the others' solutions proba_other_solutions = model_proba_solutions( model, solved_c_quizzes[s] ) + # Randomize the order a bit for the frequent P=1 case proba_other_solutions += ( torch.rand(proba_other_solutions.size()) * 1e-6 ) + # Remove the solutions whose confidence is under the threshold proba_other_solutions[dont_get_this_quiz] = -1 i = proba_other_solutions.argmax() model.recorded_c_quizzes.append(solved_c_quizzes[s, i]) @@ -1121,6 +1142,9 @@ for n_epoch in range(current_epoch, args.nb_epochs): for model in models: if model.test_accuracy >= model.best_test_accuracy: + log_string( + f"storing_best model {model.id} accuracy {model.best_test_accuracy} -> {model.test_accuracy}" + ) model.best_dict = copy.deepcopy(model.state_dict()) model.best_test_accuracy = model.test_accuracy diff --git a/quiz_machine.py b/quiz_machine.py index 6da9075..3c4a865 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -85,14 +85,11 @@ class QuizMachine: self.train_structures = [ (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)), (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)), - (("B", "f_B", "A", "f_A"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 1)), - (("f_B", "B", "f_A", "A"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 1)), + (("B", "f_B", "A", "f_A"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)), + (("f_B", "B", "f_A", "A"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)), + # (("B", "f_B", "A", "f_A"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 1)), + # (("f_B", "B", "f_A", 
"A"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 1)), (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)), - # (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 0)), - # (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 0, 1)), - # (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 0)), - # (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 0, 1)), - # (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)), ] self.test_structures = self.train_structures @@ -198,7 +195,7 @@ class QuizMachine: input=result, ar_mask=ar_mask, seq_logprobas=seq_logprobas, - progress_bar_desc="accuracy", + progress_bar_desc="autoregression", ) correct = (result == quizzes).min(dim=1).values.long() -- 2.39.5