Update.
authorFrançois Fleuret <francois@fleuret.org>
Thu, 15 Aug 2024 15:03:41 +0000 (17:03 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Thu, 15 Aug 2024 15:03:41 +0000 (17:03 +0200)
main.py

diff --git a/main.py b/main.py
index 78defa6..dcd76ad 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -101,7 +101,7 @@ parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.95)
 
 parser.add_argument("--proba_understands", type=float, default=0.95)
 
-parser.add_argument("--proba_not_understands", type=float, default=0.5)
+parser.add_argument("--proba_not_understands", type=float, default=0.1)
 
 parser.add_argument("--temperature_hot", type=float, default=1.5)
 
@@ -583,26 +583,26 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test):
         # the most consistent from a model which is confident
 
         for s in range(proba_own_solution.size(0)):
-            dont_get_this_quiz = proba_own_solution[s, :] < args.proba_understands
-            nb_fails = dont_get_this_quiz.long().sum()
-            if nb_fails >= 1 and nb_fails <= args.max_fail_to_validate:
-                for model in models:
-                    if dont_get_this_quiz[model.id]:
-                        assert proba_own_solution[s, model.id] < args.proba_understands
-                        proba_other_solutions = model_proba_solutions(
-                            model, solved_c_quizzes[s]
-                        )
-
-                        # proba_other_solutions += torch.rand(proba_other_solutions.size()) * 1e-6
-
-                        proba_other_solutions[dont_get_this_quiz] = -1
-                        # print(
-                        # f"\nDEBUG {proba_own_solution[s,model.id]=} {proba_other_solutions=}\n"
-                        # )
-                        i = proba_other_solutions.argmax()
-                        model.recorded_c_quizzes.append(solved_c_quizzes[s, i])
-                        teaching_count[i, model.id] += 1
-                        nb_validated += 1
+            if proba_own_solution[s, :].min() < args.proba_not_understands:
+                dont_get_this_quiz = proba_own_solution[s, :] < args.proba_understands
+                nb_fails = dont_get_this_quiz.long().sum()
+                if nb_fails >= 1 and nb_fails <= args.max_fail_to_validate:
+                    for model in models:
+                        if dont_get_this_quiz[model.id]:
+                            assert (
+                                proba_own_solution[s, model.id] < args.proba_understands
+                            )
+                            proba_other_solutions = model_proba_solutions(
+                                model, solved_c_quizzes[s]
+                            )
+                            proba_other_solutions += (
+                                torch.rand(proba_other_solutions.size()) * 1e-6
+                            )
+                            proba_other_solutions[dont_get_this_quiz] = -1
+                            i = proba_other_solutions.argmax()
+                            model.recorded_c_quizzes.append(solved_c_quizzes[s, i])
+                            teaching_count[i, model.id] += 1
+                            nb_validated += 1
 
         duration = time.perf_counter() - start_time