From: François Fleuret Date: Fri, 12 Jul 2024 08:08:05 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=f2ab5fd489adebe9b34ac825d39e41f13f287cb2;p=culture.git Update. --- diff --git a/grids.py b/grids.py index cfc7d16..5dad6f3 100755 --- a/grids.py +++ b/grids.py @@ -37,11 +37,34 @@ class Grids(problem.Problem): max_nb_cached_chunks=None, chunk_size=None, nb_threads=-1, + tasks=None, ): self.colors = torch.tensor([c for _, c in self.named_colors]) self.height = 10 self.width = 10 self.cache_rec_coo = {} + + all_tasks = [ + self.task_replace_color, + self.task_translate, + self.task_grow, + self.task_half_fill, + self.task_frame, + self.task_detect, + self.task_count, + self.task_trajectory, + self.task_bounce, + self.task_scale, + self.task_symbols, + self.task_isometry, + # self.task_path, + ] + + if tasks is None: + self.all_tasks = all_tasks + else: + self.all_tasks = [getattr(self, "task_" + t) for t in tasks.split(",")] + super().__init__(max_nb_cached_chunks, chunk_size, nb_threads) ###################################################################### @@ -398,7 +421,7 @@ class Grids(problem.Problem): f_X[i1:i2, j1:j2] = c[n] # @torch.compile - def task_color_grow(self, A, f_A, B, f_B): + def task_half_fill(self, A, f_A, B, f_B): di, dj = torch.randint(2, (2,)) * 2 - 1 nb_rec = 3 c = torch.randperm(len(self.colors) - 1)[: 2 * nb_rec] + 1 @@ -715,7 +738,7 @@ class Grids(problem.Problem): f_X[i[0] : i[0] + delta, j[0] : j[0] + delta] = c[q] # @torch.compile - def task_ortho(self, A, f_A, B, f_B): + def task_isometry(self, A, f_A, B, f_B): nb_rec = 3 di, dj = torch.randint(3, (2,)) - 1 o = torch.tensor([[0.0, 1.0], [-1.0, 0.0]]) @@ -939,23 +962,6 @@ class Grids(problem.Problem): ###################################################################### - def all_tasks(self): - return [ - self.task_replace_color, - self.task_translate, - self.task_grow, - self.task_color_grow, - self.task_frame, - self.task_detect, - self.task_count, - self.task_trajectory, - self.task_bounce, - self.task_scale, - self.task_symbols, - self.task_ortho, - # self.task_path, - ] - def trivial_prompts_and_answers(self, prompts, answers): S = self.height * self.width Bs = prompts[:, 2 * (S + 1) : 2 * (S + 1) + S] @@ -964,7 +970,7 @@ class Grids(problem.Problem): def generate_prompts_and_answers_(self, nb, tasks=None, progress_bar=False): if tasks is None: - tasks = self.all_tasks() + tasks = self.all_tasks S = self.height * self.width prompts = torch.zeros(nb, 3 * S + 2, dtype=torch.int64) @@ -1012,7 +1018,7 @@ class Grids(problem.Problem): def save_some_examples(self, result_dir): nb, nrow = 72, 4 - for t in self.all_tasks(): + for t in self.all_tasks: print(t.__name__) prompts, answers = self.generate_prompts_and_answers_(nb, tasks=[t]) self.save_quizzes( @@ -1043,7 +1049,7 @@ if __name__ == "__main__": nb, nrow = 72, 4 # nb, nrow = 8, 2 - # for t in grids.all_tasks(): + # for t in grids.all_tasks: for t in [ grids.task_replace_color, grids.task_frame, @@ -1056,8 +1062,8 @@ if __name__ == "__main__": nb = 1000 - for t in grids.all_tasks(): - # for t in [ grids.task_replace_color ]: #grids.all_tasks(): + for t in grids.all_tasks: + # for t in [ grids.task_replace_color ]: #grids.all_tasks: start_time = time.perf_counter() prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t]) delay = time.perf_counter() - start_time diff --git a/main.py b/main.py index b88cbc4..fc55b9c 100755 --- a/main.py +++ b/main.py @@ -98,6 +98,19 @@ parser.add_argument("--dirty_debug", action="store_true", default=False) ###################################################################### +grids_tasks = ", ".join( + [x.__name__.removeprefix("task_") for x in grids.Grids().all_tasks] +) + +parser.add_argument( + "--grids_tasks", + type=str, + default=None, + help="A comma-separated subset of: " + grids_tasks + ", or None for all.", +) + +###################################################################### + parser.add_argument("--sky_height", type=int, default=6) parser.add_argument("--sky_width", type=int, default=8) @@ -250,6 +263,7 @@ elif args.problem == "grids": max_nb_cached_chunks=args.nb_gpus * args.nb_train_samples // 100, chunk_size=100, nb_threads=args.nb_threads, + tasks=args.grids_tasks, ) back_accuracy = True else: diff --git a/quiz_machine.py b/quiz_machine.py index 4f704a0..631d41b 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -416,7 +416,7 @@ class QuizMachine: def logproba_of_solutions(self, models, c_quizzes): logproba = c_quizzes.new_zeros( - c_quizzes.size(0), len(models), device=self.device + c_quizzes.size(0), len(models), device=self.device, dtype=torch.float32 ) for model in models: