if l > 3:
break
+ def task_scale(self, A, f_A, B, f_B):
+ c = torch.randperm(len(self.colors) - 1)[:2] + 1
+
+ i, j = torch.randint(self.height // 2, (1,)), torch.randint(
+ self.width // 2, (1,)
+ )
+
+ for X, f_X in [(A, f_A), (B, f_B)]:
+ for _ in range(3):
+ while True:
+ i1, j1 = torch.randint(self.height // 2 + 1, (1,)), torch.randint(
+ self.width // 2 + 1, (1,)
+ )
+ i2, j2 = torch.randint(self.height // 2 + 1, (1,)), torch.randint(
+ self.width // 2 + 1, (1,)
+ )
+ if i1 < i2 and j1 < j2 and min(i2 - i1, j2 - j1) <= 3:
+ break
+ X[i + i1 : i + i2, j + j1 : j + j2] = c[0]
+ f_X[2 * i1 : 2 * i2, 2 * j1 : 2 * j2] = c[0]
+
+ X[i, j] = c[1]
+ f_X[0:2, 0:2] = c[1]
+
######################################################################
def generate_prompts_and_answers(self, nb, device="cpu"):
self.task_count,
self.task_trajectory,
self.task_bounce,
+ self.task_scale,
]
prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64)
answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64)