From: François Fleuret Date: Wed, 3 Jul 2024 20:08:28 +0000 (+0300) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=167c56ace610c3b975c702203bb7c7ddf74930ae;p=culture.git Update. --- diff --git a/reasoning.py b/reasoning.py index b8d39ee..c442947 100755 --- a/reasoning.py +++ b/reasoning.py @@ -169,9 +169,13 @@ class Reasoning(problem.Problem): def nb_token_values(self): return len(self.colors) + # That's quite a tensorial spaghetti mess to sample + # non-overlapping rectangles quickly, but made the generation of + # 100k samples from 1h50 with a lame pure python code to 4min with + # this one. def rec_coo(self, x, n, min_height=3, min_width=3): K = 3 - N = 4000 + N = 1000 while True: v = ( @@ -365,12 +369,8 @@ class Reasoning(problem.Problem): self.task_frame, self.task_detect, ] - prompts = torch.zeros( - nb, self.height, self.width * 3, dtype=torch.int64, device=self.device - ) - answers = torch.zeros( - nb, self.height, self.width, dtype=torch.int64, device=self.device - ) + prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64) + answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64) w = self.width for prompt, answer in tqdm.tqdm( @@ -385,6 +385,7 @@ class Reasoning(problem.Problem): f_B = answer task = tasks[torch.randint(len(tasks), (1,))] task(A, f_A, B, f_B) + return prompts.flatten(1), answers.flatten(1) def save_quizzes(