From 6efd95bf99834bbf42b4326063a750030033ad7d Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Thu, 5 Sep 2024 14:30:00 +0200
Subject: [PATCH] Update.

---
 grids.py | 2 +-
 main.py  | 5 +++--
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/grids.py b/grids.py
index 9372922..2717b22 100755
--- a/grids.py
+++ b/grids.py
@@ -1754,7 +1754,7 @@ class Grids(problem.Problem):
         return quizzes
 
     def save_some_examples(self, result_dir, prefix=""):
-        nb, nrow = 128, 4
+        nb, nrow = 256, 8
         for t in self.all_tasks:
             print(t.__name__)
             quizzes = self.generate_w_quizzes_(nb, tasks=[t])
diff --git a/main.py b/main.py
index e95c4f6..8562e32 100755
--- a/main.py
+++ b/main.py
@@ -107,7 +107,7 @@ parser.add_argument("--min_succeed_to_validate", type=int, default=2)
 
 parser.add_argument("--max_fail_to_validate", type=int, default=3)
 
-parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.98)
+parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.95)
 
 parser.add_argument("--proba_understands", type=float, default=0.95)
 
@@ -1198,6 +1198,7 @@ def generate_ae_c_quizzes(models, nb, local_device=main_device):
 
     while bag_len(records) < wanted_nb:
         model = copy_for_inference(models[torch.randint(len(models), (1,)).item()])
+        generator_id = model.id
 
         c_quizzes = ae_generate(model, template, mask_generate)
 
@@ -1248,7 +1249,7 @@ def generate_ae_c_quizzes(models, nb, local_device=main_device):
             e = "???"
 
         log_string(
-            f"nb_generated {bag_len(records)} model {model.id} (finishes {e} -- {int((nb_generated * 3600)/duration)}/h)"
+            f"nb_generated {bag_len(records)} model {generator_id} (finishes {e} -- {int((nb_generated * 3600)/duration)}/h)"
         )
 
     duration = time.perf_counter() - start_time
-- 
2.39.5
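
Note (not part of the patch): one plausible reading of the generator_id change in main.py is that it captures the id of the model selected as generator before the name `model` can be rebound further down the loop body, so the later log line still reports the generating model. Whether main.py actually rebinds `model` between those two points is an assumption here; the sketch below only illustrates the shadowing pattern the change guards against, with Model and the loop as hypothetical stand-ins.

    # Minimal sketch of variable shadowing; everything here is a stand-in
    # except the generator_id-before-rebinding idea from the patch.
    class Model:
        def __init__(self, id):
            self.id = id

    models = [Model(0), Model(1), Model(2)]

    model = models[1]          # model selected to generate
    generator_id = model.id    # captured immediately, as in the patch

    for model in models:       # `model` is rebound, e.g. while validating
        pass

    print(model.id)        # 2 -- the last model iterated over, not the generator
    print(generator_id)    # 1 -- the model that actually generated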