From: François Fleuret Date: Mon, 19 Jun 2023 21:53:59 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=c253ddd45b809dd7389773f690a83366a17ccde6;p=picoclvr.git Update. --- diff --git a/main.py b/main.py index 3db87df..f8e451b 100755 --- a/main.py +++ b/main.py @@ -623,6 +623,71 @@ class TaskMaze(Task): ###################################################################### + +def generate_snake_sequences( + nb, height, width, nb_colors, length, device=torch.device("cpu") +): + world = torch.randint(nb_colors, (nb, height, width), device=device) + # nb x 2 + snake_position = torch.cat( + ( + torch.randint(height, (nb, 1), device=device), + torch.randint(width, (nb, 1), device=device), + ), + 1, + ) + snake_direction = torch.randint(4, (nb, 1), device=device) + result = torch.empty(nb, 2*length, device=device, dtype=torch.int64) + count = torch.arange(nb, device=device) # [:,None] + + for l in range(length): + # nb x 3 + snake_next_direction = torch.cat( + ( + (snake_direction - 1) % 4, + snake_direction, + (snake_direction + 1) % 4, + ), + 1, + ) + + # nb x 3 + vh = (snake_next_direction + 1) % 2 * (snake_next_direction - 1) + vw = snake_next_direction % 2 * (snake_next_direction - 2) + + # nb x 3 x 2 + snake_next_speed = torch.cat((vh[:, :, None], vw[:, :, None]), 2) + snake_next_position = snake_position[:, None, :] + snake_next_speed + + # nb x 3 + val = torch.logical_and( + torch.logical_and( + snake_next_position[:, :, 0] >= 0, snake_next_position[:, :, 0] < height + ), + torch.logical_and( + snake_next_position[:, :, 1] >= 0, snake_next_position[:, :, 1] < width + ), + ).float() + val = torch.rand_like(val) * val * torch.tensor([[1.,4.,1.]], device=device) + + # nb + i = torch.arange(val.size(0), device=device) + j = val.argmax(1) + + # nb x 1 + snake_direction = snake_next_direction[i[:, None], j[:, None]] + + result[:, 2*l] = world[count, snake_position[:, 0], snake_position[:, 1]] + result[:, 2*l+1] = snake_direction[:,0] + + # nb x 2 + snake_position = snake_next_position[i[:, None], j[:, None]].squeeze(1) + + return result + +generate_snake_sequences(nb=2, height=4, width=5, nb_colors=3, length=10) +exit(0) + class TaskSnake(Task): def __init__( self, @@ -631,7 +696,8 @@ class TaskSnake(Task): batch_size, height, width, - nb_walls, + nb_colors, + length, device=torch.device("cpu"), ): self.batch_size = batch_size @@ -639,10 +705,14 @@ class TaskSnake(Task): self.width = width self.device = device - # self.train_input = - # self.test_input = + self.train_input = generate_snake_sequences( + nb_train_samples, height, width, nb_colors, length, self.device + ) + self.test_input = generate_snake_sequences( + nb_test_samples, height, width, nb_colors, length, self.device + ) - self.nb_codes = max(self.train_input.max(), self.train_input.max()) + 1 + self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 def batches(self, split="train", nb_to_use=-1, desc=None): assert split in {"train", "test"} @@ -656,6 +726,9 @@ class TaskSnake(Task): ): yield batch + def vocabulary_size(self): + return self.nb_codes + ###################################################################### @@ -708,6 +781,18 @@ elif args.task == "maze": device=device, ) +elif args.task == "snake": + task = TaskSnake( + nb_train_samples=args.nb_train_samples, + nb_test_samples=args.nb_test_samples, + batch_size=args.batch_size, + height=6, + width=8, + nb_colors=5, + length=100, + device=device, + ) + else: raise ValueError(f"Unknown task {args.task}")