From: François Fleuret Date: Sat, 29 Jun 2024 16:00:48 +0000 (+0300) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=29679cb42710602037fee650a5672f01a3292077;p=culture.git Update. --- diff --git a/main.py b/main.py index 30dcd4d..590bfa1 100755 --- a/main.py +++ b/main.py @@ -224,7 +224,7 @@ assert args.nb_test_samples % args.batch_size == 0 if args.problem == "sky": problem = sky.Sky(height=6, width=8, nb_birds=3, nb_iterations=2, speed=2) elif args.problem == "wireworld": - problem = wireworld.Wireworld(height=10, width=15, nb_iterations=4) + problem = wireworld.Wireworld(height=8, width=10, nb_iterations=4) else: raise ValueError diff --git a/sky.py b/sky.py index abcd394..6ba3882 100755 --- a/sky.py +++ b/sky.py @@ -112,10 +112,13 @@ class Sky(problem.Problem): break result = torch.zeros( - self.nb_iterations, self.height, self.width, dtype=torch.int64 + self.nb_iterations * self.speed, + self.height, + self.width, + dtype=torch.int64, ) - for l in range(self.nb_iterations): + for l in range(self.nb_iterations * self.speed): fine = collision_okay() for n in range(self.nb_birds): c = col[n] @@ -139,7 +142,11 @@ class Sky(problem.Problem): if fine: break - frame_sequences.append(result) + frame_sequences.append( + result[ + torch.arange(self.nb_iterations, device=result.device) * self.speed + ] + ) return frame_sequences diff --git a/wireworld.py b/wireworld.py index 219d7dd..aff236d 100755 --- a/wireworld.py +++ b/wireworld.py @@ -38,11 +38,14 @@ class Wireworld(problem.Problem): "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><" ) - def __init__(self, height=6, width=8, nb_objects=2, nb_walls=2, nb_iterations=4): + def __init__( + self, height=6, width=8, nb_objects=2, nb_walls=2, speed=1, nb_iterations=4 + ): self.height = height self.width = width self.nb_objects = nb_objects self.nb_walls = nb_walls + self.speed = speed self.nb_iterations = nb_iterations def direction_tokens(self): @@ -82,7 +85,7 @@ class Wireworld(problem.Problem): # tail->conductor # conductor->head if 1 or 2 head in the neighborhood, or remains conductor - for l in range(self.nb_iterations - 1): + for l in range(self.nb_iterations * self.speed - 1): nb_head_neighbors = ( F.conv2d( input=(result[:, l] == self.token_head).float()[:, None, :, :], @@ -108,7 +111,9 @@ class Wireworld(problem.Problem): i = (result[:, -1] == self.token_head).flatten(1).max(dim=1).values > 0 - result = result[i] + result = result[ + torch.arange(self.nb_iterations, device=result.device) * self.speed + ] if result.size(0) < nb: # print(result.size(0)) @@ -116,7 +121,7 @@ class Wireworld(problem.Problem): [result, self.generate_frame_sequences(nb - result.size(0))], dim=0 ) - return result + return result[:nb] def generate_token_sequences(self, nb): frame_sequences = self.generate_frame_sequences(nb) @@ -261,7 +266,7 @@ class Wireworld(problem.Problem): if __name__ == "__main__": import time - wireworld = Wireworld(height=10, width=15, nb_iterations=4) + wireworld = Wireworld(height=10, width=15, nb_iterations=2, speed=1) start_time = time.perf_counter() frame_sequences = wireworld.generate_frame_sequences(nb=96) @@ -270,19 +275,21 @@ if __name__ == "__main__": # print(wireworld.seq2str(seq[:4])) - for t in range(frame_sequences.size(1)): - img = wireworld.seq2img(frame_sequences[:, t]) - torchvision.utils.save_image( - img.float() / 255.0, - f"/tmp/frame_{t:03d}.png", - nrow=8, - padding=6, - pad_value=0, - ) + # for t in range(frame_sequences.size(1)): + # img = wireworld.seq2img(frame_sequences[:, t]) + # torchvision.utils.save_image( + # img.float() / 255.0, + # f"/tmp/frame_{t:03d}.png", + # nrow=8, + # padding=6, + # pad_value=0, + # ) # m = (torch.rand(seq.size()) < 0.05).long() # seq = (1 - m) * seq + m * 23 + token_sequences = wireworld.generate_token_sequences(32) + wireworld.save_quizzes(token_sequences, "/tmp", "seq") # img = wireworld.seq2img(frame_sequences[:60]) # torchvision.utils.save_image(