From aae01e186a959131b446d0365c6b951bacfd71d9 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 29 Jun 2024 23:14:38 +0300 Subject: [PATCH] Update. --- main.py | 2 +- sky.py | 21 +++++++++++-------- wireworld.py | 57 +++++++++++++++++++++++++++++++++++++++++++--------- 3 files changed, 62 insertions(+), 18 deletions(-) diff --git a/main.py b/main.py index 590bfa1..b62b4c0 100755 --- a/main.py +++ b/main.py @@ -222,7 +222,7 @@ assert args.nb_train_samples % args.batch_size == 0 assert args.nb_test_samples % args.batch_size == 0 if args.problem == "sky": - problem = sky.Sky(height=6, width=8, nb_birds=3, nb_iterations=2, speed=2) + problem = sky.Sky(height=6, width=8, nb_birds=3, nb_iterations=2, speed=3) elif args.problem == "wireworld": problem = wireworld.Wireworld(height=8, width=10, nb_iterations=4) else: diff --git a/sky.py b/sky.py index 6ba3882..1164185 100755 --- a/sky.py +++ b/sky.py @@ -118,8 +118,14 @@ class Sky(problem.Problem): dtype=torch.int64, ) + fine = torch.empty(self.nb_iterations * self.speed) + + t_to_keep = ( + torch.arange(self.nb_iterations, device=result.device) * self.speed + ) + for l in range(self.nb_iterations * self.speed): - fine = collision_okay() + fine[l] = collision_okay() for n in range(self.nb_birds): c = col[n] result[l, i[n], j[n]] = c @@ -139,14 +145,13 @@ class Sky(problem.Problem): i[n] += vi[n] j[n] += vj[n] - if fine: + result = result[t_to_keep] + fine = fine[t_to_keep] + + if fine[-1]: break - frame_sequences.append( - result[ - torch.arange(self.nb_iterations, device=result.device) * self.speed - ] - ) + frame_sequences.append(result) return frame_sequences @@ -296,7 +301,7 @@ class Sky(problem.Problem): if __name__ == "__main__": import time - sky = Sky(height=6, width=8, speed=2, nb_iterations=2) + sky = Sky(height=6, width=8, speed=4, nb_iterations=2) start_time = time.perf_counter() token_sequences = sky.generate_token_sequences(nb=64) diff --git a/wireworld.py b/wireworld.py index aff236d..76c00e5 100755 --- a/wireworld.py +++ b/wireworld.py @@ -55,7 +55,8 @@ class Wireworld(problem.Problem): frame_sequences = [] result = torch.full( - (nb * 4, self.nb_iterations, self.height, self.width), self.token_empty + (nb * 4, self.nb_iterations * self.speed, self.height, self.width), + self.token_empty, ) for n in range(result.size(0)): @@ -68,17 +69,52 @@ class Wireworld(problem.Problem): while True: if i < 0 or i >= self.height or j < 0 or j >= self.width: break + o = 0 + if i > 0: + o += (result[n, 0, i - 1, j] == self.token_conductor).long() + if i < self.height - 1: + o += (result[n, 0, i + 1, j] == self.token_conductor).long() + if j > 0: + o += (result[n, 0, i, j - 1] == self.token_conductor).long() + if j < self.width - 1: + o += (result[n, 0, i, j + 1] == self.token_conductor).long() + if o > 1: + break result[n, 0, i, j] = self.token_conductor i += vi j += vj - if torch.rand(1) < 0.5: + if ( + result[n, 0] == self.token_conductor + ).long().sum() > self.width and torch.rand(1) < 0.5: + break + + while True: + for _ in range(self.height * self.width): + i = torch.randint(self.height, (1,)) + j = torch.randint(self.width, (1,)) + v = torch.randint(2, (2,)) + vi = v[0] * (v[1] * 2 - 1) + vj = (1 - v[0]) * (v[1] * 2 - 1) + if ( + i + vi >= 0 + and i + vi < self.height + and j + vj >= 0 + and j + vj < self.width + and result[n, 0, i, j] == self.token_conductor + and result[n, 0, i + vi, j + vj] == self.token_conductor + ): + result[n, 0, i, j] = self.token_head + result[n, 0, i + vi, j + vj] = self.token_tail + break + + if torch.rand(1) < 0.75: break weight = torch.full((1, 1, 3, 3), 1.0) - mask = (torch.rand(result[:, 0].size()) < 0.01).long() - rand = torch.randint(4, mask.size()) - result[:, 0] = mask * rand + (1 - mask) * result[:, 0] + # mask = (torch.rand(result[:, 0].size()) < 0.01).long() + # rand = torch.randint(4, mask.size()) + # result[:, 0] = mask * rand + (1 - mask) * result[:, 0] # empty->empty # head->tail @@ -109,12 +145,15 @@ class Wireworld(problem.Problem): ) ) - i = (result[:, -1] == self.token_head).flatten(1).max(dim=1).values > 0 - result = result[ - torch.arange(self.nb_iterations, device=result.device) * self.speed + :, torch.arange(self.nb_iterations, device=result.device) * self.speed ] + i = (result[:, -1] == self.token_head).flatten(1).max(dim=1).values > 0 + result = result[i] + + print(f"{result.size(0)=} {nb=}") + if result.size(0) < nb: # print(result.size(0)) result = torch.cat( @@ -266,7 +305,7 @@ class Wireworld(problem.Problem): if __name__ == "__main__": import time - wireworld = Wireworld(height=10, width=15, nb_iterations=2, speed=1) + wireworld = Wireworld(height=10, width=15, nb_iterations=2, speed=5) start_time = time.perf_counter() frame_sequences = wireworld.generate_frame_sequences(nb=96) -- 2.39.5