From 5ecadfde470059278aec2b8ded217219e6773c04 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 3 Jul 2024 17:51:03 +0300 Subject: [PATCH] Update. --- lang.py | 46 ++++++++++++++++++++++++++++++++-------------- main.py | 6 +++--- quizz_machine.py | 2 ++ 3 files changed, 37 insertions(+), 17 deletions(-) diff --git a/lang.py b/lang.py index 3d939bb..ce159e2 100755 --- a/lang.py +++ b/lang.py @@ -66,6 +66,9 @@ class Lang(problem.Problem): predicted_prompts=None, predicted_answers=None, ): + prompts = prompts.reshape(prompts.size(0), self.height, -1) + answers = answers.reshape(answers.size(0), self.height, -1) + if predicted_prompts is None: predicted_prompts = 255 @@ -73,7 +76,6 @@ class Lang(problem.Problem): predicted_answers = 255 def add_frame(x, c, margin, bottom=False): - print(f"{type(x)=} {type(c)=}") if bottom: h, w, di, dj = x.size(2) + margin, x.size(3), 0, 0 else: @@ -181,26 +183,42 @@ class Lang(problem.Problem): break return i1, j1, i2, j2 - def task_red_to_green(self, A, f_A, B, f_B): + def task_replace_color(self, A, f_A, B, f_B): + c1, c2 = torch.randperm(len(self.colors) - 1)[:2] + 1 + i1, j1, i2, j2 = self.rec_coo(A) + A[i1:i2, j1:j2] = c1 + f_A[i1:i2, j1:j2] = c2 + for _ in range(3): + i1, j1, i2, j2 = self.rec_coo(B) + B[i1:i2, j1:j2] = c1 + f_B[i1:i2, j1:j2] = c2 + + def move_color(self, A, f_A, B, f_B): + c1, c2 = torch.randperm(len(self.colors) - 1)[:2] + 1 + i1, j1, i2, j2 = self.rec_coo(A) - A[i1:i2, j1:j2] = self.name2color["red"] - f_A[i1:i2, j1:j2] = self.name2color["green"] - i1, j1, i2, j2 = self.rec_coo(B) - B[i1:i2, j1:j2] = self.name2color["red"] - f_B[i1:i2, j1:j2] = self.name2color["green"] + A[i1:i2, j1:j2] = c1 + f_A[i1:i2, j1:j2] = c1 + + while True: + i1, j1, i2, j2 = self.rec_coo(A) + if i2 < self.height - 1: + break + A[i1:i2, j1:j2] = c2 + f_A[i1:i2, j1:j2] = c2 def generate_prompts_and_answers(self, nb): prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64) answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64) w = self.width for prompt, answer in zip(prompts, answers): - self.task_red_to_green( - prompt[:, 0 * w : 1 * w], - prompt[:, 1 * w : 2 * w], - prompt[:, 2 * w : 3 * w], - answer, - ) - return prompts, answers + A = prompt[:, 0 * w : 1 * w] + f_A = prompt[:, 1 * w : 2 * w] + B = prompt[:, 2 * w : 3 * w] + f_B = answer + # self.task_replace_color(A, f_A, B, f_B) + self.move_color(A, f_A, B, f_B) + return prompts.flatten(1), answers.flatten(1) def save_quizzes( self, diff --git a/main.py b/main.py index a8a6191..b4e7318 100755 --- a/main.py +++ b/main.py @@ -13,7 +13,7 @@ from torch.nn import functional as F import ffutils import mygpt -import sky, wireworld, quizz_machine +import sky, lang, quizz_machine # world quizzes vs. culture quizzes @@ -249,8 +249,8 @@ if args.problem == "sky": nb_iterations=args.sky_nb_iterations, speed=args.sky_speed, ) -elif args.problem == "wireworld": - problem = wireworld.Wireworld(height=8, width=10, nb_iterations=2, speed=5) +elif args.problem == "lang": + problem = lang.Lang(nb_iterations=2) else: raise ValueError diff --git a/quizz_machine.py b/quizz_machine.py index 90f288e..3828e5b 100755 --- a/quizz_machine.py +++ b/quizz_machine.py @@ -167,6 +167,8 @@ class QuizzMachine: def generate_token_sequences(self, nb): prompts, answers = self.problem.generate_prompts_and_answers(nb) + print(f"{prompts.size()=} {answers.size()=}") + if self.prompt_len is None: self.prompt_len = prompts.size(1) -- 2.39.5