From 488d9393dba8dcb18f1a7f1bfb88b5e867267787 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 29 Jun 2024 23:26:52 +0300 Subject: [PATCH] Update. --- main.py | 2 +- wireworld.py | 13 +++++++++++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/main.py b/main.py index b62b4c0..a6c482f 100755 --- a/main.py +++ b/main.py @@ -224,7 +224,7 @@ assert args.nb_test_samples % args.batch_size == 0 if args.problem == "sky": problem = sky.Sky(height=6, width=8, nb_birds=3, nb_iterations=2, speed=3) elif args.problem == "wireworld": - problem = wireworld.Wireworld(height=8, width=10, nb_iterations=4) + problem = wireworld.Wireworld(height=8, width=10, nb_iterations=2, speed=5) else: raise ValueError diff --git a/wireworld.py b/wireworld.py index 76c00e5..9e7d513 100755 --- a/wireworld.py +++ b/wireworld.py @@ -52,6 +52,15 @@ class Wireworld(problem.Problem): return self.token_forward, self.token_backward def generate_frame_sequences(self, nb): + result = [] + N = 100 + for _ in tqdm.tqdm( + range(0, nb + N, N), dynamic_ncols=True, desc="world generation" + ): + result.append(self.generate_frame_sequences_hard(100)) + return torch.cat(result, dim=0)[:nb] + + def generate_frame_sequences_hard(self, nb): frame_sequences = [] result = torch.full( @@ -152,7 +161,7 @@ class Wireworld(problem.Problem): i = (result[:, -1] == self.token_head).flatten(1).max(dim=1).values > 0 result = result[i] - print(f"{result.size(0)=} {nb=}") + # print(f"{result.size(0)=} {nb=}") if result.size(0) < nb: # print(result.size(0)) @@ -305,7 +314,7 @@ class Wireworld(problem.Problem): if __name__ == "__main__": import time - wireworld = Wireworld(height=10, width=15, nb_iterations=2, speed=5) + wireworld = Wireworld(height=8, width=10, nb_iterations=2, speed=5) start_time = time.perf_counter() frame_sequences = wireworld.generate_frame_sequences(nb=96) -- 2.39.5