From 70a72d39c93dacc3cebc4dbc5d18bbe5289a4b86 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 2 Sep 2024 13:35:21 +0200 Subject: [PATCH] Update. --- main.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/main.py b/main.py index bb79484..b48d2a8 100755 --- a/main.py +++ b/main.py @@ -1319,6 +1319,14 @@ def c_quiz_criterion_one_good_one_bad(probas): return (probas.max(dim=1).values >= 0.75) & (probas.min(dim=1).values <= 0.25) +def c_quiz_criterion_one_good_no_very_bad(probas): + return ( + (probas.max(dim=1).values >= 0.75) + & (probas.min(dim=1).values <= 0.75) + & (probas.min(dim=1).values >= 0.25) + ) + + def c_quiz_criterion_diff(probas): return (probas.max(dim=1).values - probas.min(dim=1).values) >= 0.5 @@ -1328,6 +1336,11 @@ def c_quiz_criterion_diff2(probas): return (v[:, -2] - v[:, 0]) >= 0.5 +def c_quiz_criterion_only_one(probas): + v = probas.sort(dim=1).values + return (v[:, -1] >= 0.75) & (v[:, -2] <= 0.25) + + def c_quiz_criterion_two_good(probas): return ((probas >= 0.5).long().sum(dim=1) >= 2) & (probas.min(dim=1).values <= 0.2) @@ -1340,7 +1353,9 @@ def c_quiz_criterion_some(probas): def generate_ae_c_quizzes(models, local_device=main_device): criteria = [ + # c_quiz_criterion_only_one, c_quiz_criterion_one_good_one_bad, + # c_quiz_criterion_one_good_no_very_bad, # c_quiz_criterion_diff, # c_quiz_criterion_diff2, # c_quiz_criterion_two_good, -- 2.39.5