Update.
authorFrançois Fleuret <francois@fleuret.org>
Sat, 29 Jun 2024 20:26:52 +0000 (23:26 +0300)
committerFrançois Fleuret <francois@fleuret.org>
Sat, 29 Jun 2024 20:26:52 +0000 (23:26 +0300)
main.py
wireworld.py

diff --git a/main.py b/main.py
index b62b4c0..a6c482f 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -224,7 +224,7 @@ assert args.nb_test_samples % args.batch_size == 0
 if args.problem == "sky":
     problem = sky.Sky(height=6, width=8, nb_birds=3, nb_iterations=2, speed=3)
 elif args.problem == "wireworld":
-    problem = wireworld.Wireworld(height=8, width=10, nb_iterations=4)
+    problem = wireworld.Wireworld(height=8, width=10, nb_iterations=2, speed=5)
 else:
     raise ValueError
 
index 76c00e5..9e7d513 100755 (executable)
@@ -52,6 +52,15 @@ class Wireworld(problem.Problem):
         return self.token_forward, self.token_backward
 
     def generate_frame_sequences(self, nb):
+        result = []
+        N = 100
+        for _ in tqdm.tqdm(
+            range(0, nb + N, N), dynamic_ncols=True, desc="world generation"
+        ):
+            result.append(self.generate_frame_sequences_hard(100))
+        return torch.cat(result, dim=0)[:nb]
+
+    def generate_frame_sequences_hard(self, nb):
         frame_sequences = []
 
         result = torch.full(
@@ -152,7 +161,7 @@ class Wireworld(problem.Problem):
         i = (result[:, -1] == self.token_head).flatten(1).max(dim=1).values > 0
         result = result[i]
 
-        print(f"{result.size(0)=} {nb=}")
+        print(f"{result.size(0)=} {nb=}")
 
         if result.size(0) < nb:
             # print(result.size(0))
@@ -305,7 +314,7 @@ class Wireworld(problem.Problem):
 if __name__ == "__main__":
     import time
 
-    wireworld = Wireworld(height=10, width=15, nb_iterations=2, speed=5)
+    wireworld = Wireworld(height=8, width=10, nb_iterations=2, speed=5)
 
     start_time = time.perf_counter()
     frame_sequences = wireworld.generate_frame_sequences(nb=96)