From: François Fleuret Date: Wed, 3 Jul 2024 15:06:20 +0000 (+0300) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=5a77666812f943678094edea26bc17dff8304073;p=culture.git Update. --- diff --git a/lang.py b/lang.py index ce159e2..5adf50f 100755 --- a/lang.py +++ b/lang.py @@ -185,29 +185,45 @@ class Lang(problem.Problem): def task_replace_color(self, A, f_A, B, f_B): c1, c2 = torch.randperm(len(self.colors) - 1)[:2] + 1 - i1, j1, i2, j2 = self.rec_coo(A) - A[i1:i2, j1:j2] = c1 - f_A[i1:i2, j1:j2] = c2 - for _ in range(3): - i1, j1, i2, j2 = self.rec_coo(B) - B[i1:i2, j1:j2] = c1 - f_B[i1:i2, j1:j2] = c2 - - def move_color(self, A, f_A, B, f_B): - c1, c2 = torch.randperm(len(self.colors) - 1)[:2] + 1 - - i1, j1, i2, j2 = self.rec_coo(A) - A[i1:i2, j1:j2] = c1 - f_A[i1:i2, j1:j2] = c1 - - while True: - i1, j1, i2, j2 = self.rec_coo(A) - if i2 < self.height - 1: - break - A[i1:i2, j1:j2] = c2 - f_A[i1:i2, j1:j2] = c2 + for n, X, f_X in [(1, A, f_A), (3, B, f_B)]: + for _ in range(torch.randint(n, (1,)) + 1): + i1, j1, i2, j2 = self.rec_coo(X) + X[i1:i2, j1:j2] = c1 + f_X[i1:i2, j1:j2] = c2 + + def task_move(self, A, f_A, B, f_B): + c = torch.randperm(len(self.colors) - 1)[:1] + 1 + di, dj = torch.randint(2, (2,)) * 2 - 1 + for n, X, f_X in [(1, A, f_A), (3, B, f_B)]: + for _ in range(torch.randint(n, (1,)) + 1): + while True: + i1, j1, i2, j2 = self.rec_coo(X) + if ( + i1 + di >= 0 + and i2 + di < X.size(0) + and j1 + dj >= 0 + and j2 + dj < X.size(1) + ): + break + + X[i1:i2, j1:j2] = c + f_X[i1 + di : i2 + di, j1 + dj : j2 + dj] = c + + def task_grow(self, A, f_A, B, f_B): + c = torch.randperm(len(self.colors) - 1)[:1] + 1 + + for n, X, f_X in [(1, A, f_A), (3, B, f_B)]: + for _ in range(torch.randint(n, (1,)) + 1): + while True: + i1, j1, i2, j2 = self.rec_coo(X) + if i1 + 3 < i2 and j1 + 3 < j2: + break + + X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = c + f_X[i1:i2, j1:j2] = c def generate_prompts_and_answers(self, nb): + tasks = [self.task_replace_color, self.task_move, self.task_grow] prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64) answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64) w = self.width @@ -216,8 +232,7 @@ class Lang(problem.Problem): f_A = prompt[:, 1 * w : 2 * w] B = prompt[:, 2 * w : 3 * w] f_B = answer - # self.task_replace_color(A, f_A, B, f_B) - self.move_color(A, f_A, B, f_B) + tasks[torch.randint(len(tasks), (1,))](A, f_A, B, f_B) return prompts.flatten(1), answers.flatten(1) def save_quizzes( @@ -246,13 +261,16 @@ if __name__ == "__main__": lang = Lang(nb_iterations=4) - prompts, answers = lang.generate_prompts_and_answers(24) + prompts, answers = lang.generate_prompts_and_answers(36) predicted_prompts = torch.rand(prompts.size(0)) < 0.5 predicted_answers = torch.logical_not(predicted_prompts) lang.save_quizzes( - "/tmp", "test", prompts, answers, predicted_prompts, predicted_answers + "/tmp", + "test", + prompts, + answers, # predicted_prompts, predicted_answers ) # start_time = time.perf_counter()