dtype=torch.int64,
)
+ fine = torch.empty(self.nb_iterations * self.speed)
+
+ t_to_keep = (
+ torch.arange(self.nb_iterations, device=result.device) * self.speed
+ )
+
for l in range(self.nb_iterations * self.speed):
- fine = collision_okay()
+ fine[l] = collision_okay()
for n in range(self.nb_birds):
c = col[n]
result[l, i[n], j[n]] = c
i[n] += vi[n]
j[n] += vj[n]
- if fine:
+ result = result[t_to_keep]
+ fine = fine[t_to_keep]
+
+ if fine[-1]:
break
- frame_sequences.append(
- result[
- torch.arange(self.nb_iterations, device=result.device) * self.speed
- ]
- )
+ frame_sequences.append(result)
return frame_sequences
if __name__ == "__main__":
import time
- sky = Sky(height=6, width=8, speed=2, nb_iterations=2)
+ sky = Sky(height=6, width=8, speed=4, nb_iterations=2)
start_time = time.perf_counter()
token_sequences = sky.generate_token_sequences(nb=64)
frame_sequences = []
result = torch.full(
- (nb * 4, self.nb_iterations, self.height, self.width), self.token_empty
+ (nb * 4, self.nb_iterations * self.speed, self.height, self.width),
+ self.token_empty,
)
for n in range(result.size(0)):
while True:
if i < 0 or i >= self.height or j < 0 or j >= self.width:
break
+ o = 0
+ if i > 0:
+ o += (result[n, 0, i - 1, j] == self.token_conductor).long()
+ if i < self.height - 1:
+ o += (result[n, 0, i + 1, j] == self.token_conductor).long()
+ if j > 0:
+ o += (result[n, 0, i, j - 1] == self.token_conductor).long()
+ if j < self.width - 1:
+ o += (result[n, 0, i, j + 1] == self.token_conductor).long()
+ if o > 1:
+ break
result[n, 0, i, j] = self.token_conductor
i += vi
j += vj
- if torch.rand(1) < 0.5:
+ if (
+ result[n, 0] == self.token_conductor
+ ).long().sum() > self.width and torch.rand(1) < 0.5:
+ break
+
+ while True:
+ for _ in range(self.height * self.width):
+ i = torch.randint(self.height, (1,))
+ j = torch.randint(self.width, (1,))
+ v = torch.randint(2, (2,))
+ vi = v[0] * (v[1] * 2 - 1)
+ vj = (1 - v[0]) * (v[1] * 2 - 1)
+ if (
+ i + vi >= 0
+ and i + vi < self.height
+ and j + vj >= 0
+ and j + vj < self.width
+ and result[n, 0, i, j] == self.token_conductor
+ and result[n, 0, i + vi, j + vj] == self.token_conductor
+ ):
+ result[n, 0, i, j] = self.token_head
+ result[n, 0, i + vi, j + vj] = self.token_tail
+ break
+
+ if torch.rand(1) < 0.75:
break
weight = torch.full((1, 1, 3, 3), 1.0)
- mask = (torch.rand(result[:, 0].size()) < 0.01).long()
- rand = torch.randint(4, mask.size())
- result[:, 0] = mask * rand + (1 - mask) * result[:, 0]
+ # mask = (torch.rand(result[:, 0].size()) < 0.01).long()
+ # rand = torch.randint(4, mask.size())
+ # result[:, 0] = mask * rand + (1 - mask) * result[:, 0]
# empty->empty
# head->tail
)
)
- i = (result[:, -1] == self.token_head).flatten(1).max(dim=1).values > 0
-
result = result[
- torch.arange(self.nb_iterations, device=result.device) * self.speed
+ :, torch.arange(self.nb_iterations, device=result.device) * self.speed
]
+ i = (result[:, -1] == self.token_head).flatten(1).max(dim=1).values > 0
+ result = result[i]
+
+ print(f"{result.size(0)=} {nb=}")
+
if result.size(0) < nb:
# print(result.size(0))
result = torch.cat(
if __name__ == "__main__":
import time
- wireworld = Wireworld(height=10, width=15, nb_iterations=2, speed=1)
+ wireworld = Wireworld(height=10, width=15, nb_iterations=2, speed=5)
start_time = time.perf_counter()
frame_sequences = wireworld.generate_frame_sequences(nb=96)