From: François Fleuret Date: Tue, 31 Oct 2023 08:14:35 +0000 (+0100) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=a2ccdd2f5e9fb3e7ed52492729b880f815ddfbcb;p=pytorch.git Update. --- diff --git a/picocrafter.py b/picocrafter.py new file mode 100755 index 0000000..33a00c1 --- /dev/null +++ b/picocrafter.py @@ -0,0 +1,437 @@ +#!/usr/bin/env python + +######################################################################### +# This program is free software: you can redistribute it and/or modify # +# it under the terms of the version 3 of the GNU General Public License # +# as published by the Free Software Foundation. # +# # +# This program is distributed in the hope that it will be useful, but # +# WITHOUT ANY WARRANTY; without even the implied warranty of # +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU # +# General Public License for more details. # +# # +# You should have received a copy of the GNU General Public License # +# along with this program. If not, see . # +# # +# Written by and Copyright (C) Francois Fleuret # +# Contact for comments & bug reports # +######################################################################### + +# This is a tiny rogue-like environment implemented with tensor +# operations, that runs in batches efficiently on a GPU. On a RTX4090 +# it can initialize ~20k environments per second and run ~40k +# iterations. +# +# The agent "@" moves in a maze-like grid with random walls "#". There +# are five actions: move NESW or do not move. +# +# There are monsters "$" moving randomly. The agent gets hit by every +# monster present in one of the 4 direct neighborhoods at the end of +# the moves, each hit results in a rewards of -1. +# +# The agent starts with 5 life points, each hit costs it 1pt, when it +# gets to 0 it dies, gets a reward of -10 and the episode is over. At +# every step it recovers 1/20th of a life point, with a maximum of +# 5pt. +# +# The agent can carry "keys" ("a", "b", "c") that open "vaults" ("A", +# "B", "C"). They keys can only be used in sequence: initially the +# agent can move only to free spaces, or to the "a", in which case it +# now carries it, and can move to free spaces or the "A". When it +# moves to the "A", it gets a reward and loses the "a", but can now +# move to the "b", etc. Rewards are 1 for "A" and "B" and 10 for "C". + +###################################################################### + +import torch + +from torch.nn.functional import conv2d + +###################################################################### + + +class PicroCrafterEngine: + def __init__( + self, + world_height=27, + world_width=27, + nb_walls=27, + margin=2, + view_height=5, + view_width=5, + device=torch.device("cpu"), + ): + assert (world_height - 2 * margin) % (view_height - 2 * margin) == 0 + assert (world_width - 2 * margin) % (view_width - 2 * margin) == 0 + + self.device = device + + self.world_height = world_height + self.world_width = world_width + self.margin = margin + self.view_height = view_height + self.view_width = view_width + self.nb_walls = nb_walls + self.life_level_max = 5 + self.life_level_gain_100th = 5 + self.reward_per_hit = -1 + self.reward_death = -10 + + self.tokens = " +#@$aAbBcC" + self.token2id = dict([(t, n) for n, t in enumerate(self.tokens)]) + self.id2token = dict([(n, t) for n, t in enumerate(self.tokens)]) + + self.next_object = dict( + [ + (self.token2id[s], self.token2id[t]) + for (s, t) in [ + ("a", "A"), + ("A", "b"), + ("b", "B"), + ("B", "c"), + ("c", "C"), + ] + ] + ) + + self.object_reward = dict( + [ + (self.token2id[t], r) + for (t, r) in [ + ("a", 0), + ("A", 1), + ("b", 0), + ("B", 1), + ("c", 0), + ("C", 10), + ] + ] + ) + + self.acessible_object_to_inventory = dict( + [ + (self.token2id[s], self.token2id[t]) + for (s, t) in [ + ("a", " "), + ("A", "a"), + ("b", " "), + ("B", "b"), + ("c", " "), + ("C", " "), + ] + ] + ) + + def reset(self, nb_agents): + self.worlds = self.create_worlds( + nb_agents, self.world_height, self.world_width, self.nb_walls, self.margin + ).to(self.device) + self.life_level_in_100th = torch.full( + (nb_agents,), self.life_level_max * 100, device=self.device + ) + self.accessible_object = torch.full( + (nb_agents,), self.token2id["a"], device=self.device + ) + + def create_mazes(self, nb, height, width, nb_walls): + m = torch.zeros(nb, height, width, dtype=torch.int64, device=self.device) + m[:, 0, :] = 1 + m[:, -1, :] = 1 + m[:, :, 0] = 1 + m[:, :, -1] = 1 + + i = torch.arange(height, device=m.device)[None, :, None] + j = torch.arange(width, device=m.device)[None, None, :] + + for _ in range(nb_walls): + q = torch.rand(m.size(), device=m.device).flatten(1).sort(-1).indices * ( + (1 - m) * (i % 2 == 0) * (j % 2 == 0) + ).flatten(1) + q = (q == q.max(dim=-1, keepdim=True).values).long().view(m.size()) + a = q[:, None].expand(-1, 4, -1, -1).clone() + a[:, 0, :-1, :] += q[:, 1:, :] + a[:, 0, :-2, :] += q[:, 2:, :] + a[:, 1, 1:, :] += q[:, :-1, :] + a[:, 1, 2:, :] += q[:, :-2, :] + a[:, 2, :, :-1] += q[:, :, 1:] + a[:, 2, :, :-2] += q[:, :, 2:] + a[:, 3, :, 1:] += q[:, :, :-1] + a[:, 3, :, 2:] += q[:, :, :-2] + a = a[ + torch.arange(a.size(0), device=a.device), + torch.randint(4, (a.size(0),), device=a.device), + ] + m = (m + q + a).clamp(max=1) + + return m + + def create_worlds(self, nb, height, width, nb_walls, margin=2): + margin -= 1 # The maze adds a wall all around + m = self.create_mazes(nb, height - 2 * margin, width - 2 * margin, nb_walls) + q = m.flatten(1) + z = "@aAbBcC$$$$$" # What to add to the maze + u = torch.rand(q.size(), device=q.device) * (1 - q) + r = u.sort(dim=-1, descending=True).indices[:, : len(z)] + + q *= self.token2id["#"] + q[ + torch.arange(q.size(0), device=q.device)[:, None].expand_as(r), r + ] = torch.tensor([self.token2id[c] for c in z], device=q.device)[None, :] + + if margin > 0: + r = m.new_full( + (m.size(0), m.size(1) + margin * 2, m.size(2) + margin * 2), + self.token2id["+"], + ) + r[:, margin:-margin, margin:-margin] = m + m = r + return m + + def nb_actions(self): + return 5 + + def nb_view_tokens(self): + return len(self.tokens) + + def min_max_reward(self): + return ( + min(4 * self.reward_per_hit, self.reward_death), + max(self.object_reward.values()), + ) + + def step(self, actions): + a = (self.worlds == self.token2id["@"]).nonzero() + self.worlds[a[:, 0], a[:, 1], a[:, 2]] = self.token2id[" "] + s = torch.tensor([[0, 0], [-1, 0], [0, 1], [1, 0], [0, -1]], device=self.device) + b = a.clone() + b[:, 1:] = b[:, 1:] + s[actions[b[:, 0]]] + + # position is empty + o = (self.worlds[b[:, 0], b[:, 1], b[:, 2]] == self.token2id[" "]).long() + # or it is the next accessible object + q = ( + self.worlds[b[:, 0], b[:, 1], b[:, 2]] == self.accessible_object[b[:, 0]] + ).long() + o = (o + q).clamp(max=1)[:, None] + b = (1 - o) * a + o * b + self.worlds[b[:, 0], b[:, 1], b[:, 2]] = self.token2id["@"] + + nb_hits = self.monster_moves() + + alive_before = self.life_level_in_100th > 0 + self.life_level_in_100th[alive_before] = ( + self.life_level_in_100th[alive_before] + + self.life_level_gain_100th + - nb_hits[alive_before] * 100 + ).clamp(max=self.life_level_max * 100) + alive_after = self.life_level_in_100th > 0 + self.worlds[torch.logical_not(alive_after)] = self.token2id["#"] + reward = nb_hits * self.reward_per_hit + + for i in range(q.size(0)): + if q[i] == 1: + reward[i] += self.object_reward[self.accessible_object[i].item()] + self.accessible_object[i] = self.next_object[ + self.accessible_object[i].item() + ] + + reward = ( + reward + alive_before.long() * (1 - alive_after.long()) * self.reward_death + ) + inventory = torch.tensor( + [ + self.acessible_object_to_inventory[s.item()] + for s in self.accessible_object + ] + ) + + reward[torch.logical_not(alive_before)] = 0 + return reward, inventory, self.life_level_in_100th // 100 + + def monster_moves(self): + # Current positions of the monsters + m = (self.worlds == self.token2id["$"]).long().flatten(1) + + # Total number of monsters + n = m.sum(-1).max() + + # Create a tensor with one channel per monster + r = ( + (torch.rand(m.size(), device=m.device) * m) + .sort(dim=-1, descending=True) + .indices[:, :n] + ) + o = m.new_zeros((m.size(0), n) + m.size()[1:]) + i = torch.arange(o.size(0), device=o.device)[:, None].expand(-1, o.size(1)) + j = torch.arange(o.size(1), device=o.device)[None, :].expand(o.size(0), -1) + o[i, j, r] = 1 + o = o * m[:, None] + + # Create the tensor of possible motions + o = o.view((self.worlds.size(0), n) + self.worlds.flatten(1).size()[1:]) + move_kernel = torch.tensor( + [[[[0.0, 1.0, 0.0], [1.0, 1.0, 1.0], [0.0, 1.0, 0.0]]]], device=o.device + ) + + p = ( + conv2d( + o.view( + o.size(0) * o.size(1), 1, self.worlds.size(-2), self.worlds.size(-1) + ).float(), + move_kernel, + padding=1, + ).view(o.size()) + == 1.0 + ).long() + + # Let's do the moves per say + i = torch.arange(self.worlds.size(0), device=self.worlds.device)[ + :, None + ].expand_as(r) + + for n in range(p.size(1)): + u = o[:, n].sort(dim=-1, descending=True).indices[:, :1] + q = p[:, n] * (self.worlds.flatten(1) == self.token2id[" "]) + o[:, n] + r = ( + (q * torch.rand(q.size(), device=q.device)) + .sort(dim=-1, descending=True) + .indices[:, :1] + ) + self.worlds.flatten(1)[i, u] = self.token2id[" "] + self.worlds.flatten(1)[i, r] = self.token2id["$"] + + nb_hits = ( + ( + conv2d( + (self.worlds == self.token2id["$"]).float()[:, None], + move_kernel, + padding=1, + ) + .long() + .squeeze(1) + * (self.worlds == self.token2id["@"]).long() + ) + .flatten(1) + .sum(-1) + ) + + return nb_hits + + def views(self): + i_height, i_width = ( + self.view_height - 2 * self.margin, + self.view_width - 2 * self.margin, + ) + a = (self.worlds == self.token2id["@"]).nonzero() + y = i_height * ((a[:, 1] - self.margin) // i_height) + x = i_width * ((a[:, 2] - self.margin) // i_width) + n = a[:, 0][:, None, None].expand(-1, self.view_height, self.view_width) + i = ( + torch.arange(self.view_height, device=a.device)[None, :, None] + + y[:, None, None] + ).expand_as(n) + j = ( + torch.arange(self.view_width, device=a.device)[None, None, :] + + x[:, None, None] + ).expand_as(n) + v = self.worlds.new_full( + (self.worlds.size(0), self.view_height, self.view_width), self.token2id["#"] + ) + + v[a[:, 0]] = self.worlds[n, i, j] + + return v + + def print_worlds( + self, src=None, comments=[], width=None, printer=print, ansi_term=False + ): + if src is None: + src = self.worlds + + if width is None: + width = src.size(2) + + def token(n): + n = n.item() + if n in self.id2token: + return self.id2token[n] + else: + return "?" + + for k in range(src.size(1)): + s = ["".join([token(n) for n in m[k]]) for m in src] + s = [r + " " * (width - len(r)) for r in s] + if ansi_term: + + def colorize(x): + for u, c in [("#", 40), ("$", 31), ("@", 32)] + [ + (x, 36) for x in "aAbBcC" + ]: + x = x.replace(u, f"\u001b[{c}m{u}\u001b[0m") + return x + + s = [colorize(x) for x in s] + printer(" | ".join(s)) + + s = [c + " " * (width - len(c)) for c in comments] + printer(" | ".join(s)) + + +###################################################################### + +if __name__ == "__main__": + import os, time + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + ansi_term = False + # nb_agents, nb_iter, display = 1000, 100, False + nb_agents, nb_iter, display = 3, 10000, True + ansi_term = True + + start_time = time.perf_counter() + engine = PicroCrafterEngine( + world_height=27, + world_width=27, + nb_walls=35, + view_height=9, + view_width=9, + margin=4, + device=device, + ) + + engine.reset(nb_agents) + + print(f"timing {nb_agents/(time.perf_counter() - start_time)} init per s") + + start_time = time.perf_counter() + + for k in range(nb_iter): + action = torch.randint(engine.nb_actions(), (nb_agents,), device=device) + rewards, inventories, life_levels = engine.step( + torch.randint(engine.nb_actions(), (nb_agents,), device=device) + ) + + if display: + os.system("clear") + engine.print_worlds( + ansi_term=ansi_term, + ) + print() + engine.print_worlds( + src=engine.views(), + comments=[ + f"L{p}I{engine.id2token[s.item()]}R{r}" + for p, s, r in zip(life_levels, inventories, rewards) + ], + width=engine.world_width, + ansi_term=ansi_term, + ) + time.sleep(0.5) + + if (life_levels > 0).long().sum() == 0: + break + + print( + f"timing {(nb_agents*nb_iter)/(time.perf_counter() - start_time)} iteration per s" + )