From 03257cc01488588246fe23eabf54acaa2ac32442 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 15 Jul 2024 16:00:31 +0200 Subject: [PATCH] Update. --- grids.py | 4 ++-- main.py | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/grids.py b/grids.py index 7752136..a115f93 100755 --- a/grids.py +++ b/grids.py @@ -1126,7 +1126,7 @@ class Grids(problem.Problem): ) def save_some_examples(self, result_dir): - nb, nrow = 72, 4 + nb, nrow = 128, 4 for t in self.all_tasks: print(t.__name__) prompts, answers = self.generate_prompts_and_answers_(nb, tasks=[t]) @@ -1155,7 +1155,7 @@ if __name__ == "__main__": # exit(0) # if True: - nb, nrow = 72, 4 + nb, nrow = 128, 4 # nb, nrow = 8, 2 # for t in grids.all_tasks: diff --git a/main.py b/main.py index 4673f42..b372f12 100755 --- a/main.py +++ b/main.py @@ -372,12 +372,12 @@ def one_epoch(model, quiz_machine, local_device=main_device): # token_logprobas are NxMxT where M is the number of models +# def compute_valid_quizzes_(token_logprobas): +# warnings.warn("validation with uniform constraints", RuntimeWarning) +# l = token_logprobas.min(dim=-1).values.sort(dim=-1).values +# return (l[:, 0] < math.log(0.1)) & (l[:, 1] > math.log(0.5)) - -def compute_valid_quizzes_(token_logprobas): - warnings.warn("validation with uniform constraints", RuntimeWarning) - l = token_logprobas.min(dim=-1).values.sort(dim=-1).values - return (l[:, 0] < math.log(0.1)) & (l[:, 1] > math.log(0.5)) +# token_logprobas are NxMxT where M is the number of models def compute_valid_quizzes(token_logprobas): -- 2.39.5