Update.
authorFrançois Fleuret <francois@fleuret.org>
Sun, 11 Aug 2024 20:36:22 +0000 (22:36 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sun, 11 Aug 2024 20:36:22 +0000 (22:36 +0200)
main.py

diff --git a/main.py b/main.py
index 40772c2..cd6e3a9 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -78,7 +78,7 @@ parser.add_argument("--nb_heads", type=int, default=None)
 
 parser.add_argument("--nb_blocks", type=int, default=None)
 
-parser.add_argument("--dropout", type=float, default=0.1)
+parser.add_argument("--dropout", type=float, default=0.5)
 
 # ----------------------------------
 parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
@@ -93,13 +93,15 @@ parser.add_argument("--gpus", type=str, default="all")
 
 parser.add_argument("--nb_gpts", type=int, default=5)
 
+parser.add_argument("--min_succeed_to_validate", type=int, default=2)
+
 parser.add_argument("--max_fail_to_validate", type=int, default=3)
 
 parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.95)
 
 parser.add_argument("--proba_understands", type=float, default=0.95)
 
-parser.add_argument("--proba_not_understands", type=float, default=0.1)
+parser.add_argument("--proba_not_understands", type=float, default=0.5)
 
 parser.add_argument("--temperature_hot", type=float, default=1.5)
 
@@ -663,7 +665,8 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
         nb_fail = (probas <= args.proba_not_understands).long().sum(dim=1)
 
         to_keep = (
-            (nb_succeed + nb_fail == probas.size(1))
+            # (nb_succeed + nb_fail == probas.size(1))
+            (nb_succeed >= args.min_succeed_to_validate)
             & (nb_fail >= 1)
             & (nb_fail <= args.max_fail_to_validate)
         )