From: François Fleuret Date: Sun, 12 Nov 2023 07:14:52 +0000 (+0100) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=f420983d7ce3a2868508433d71a8e65266c71d98;p=pytorch.git Update. --- diff --git a/picocrafter.py b/picocrafter.py index 36088ac..5bd6a48 100755 --- a/picocrafter.py +++ b/picocrafter.py @@ -74,7 +74,9 @@ def to_unicode(s): def fusion_multi_lines(l, width_min=0): - l = [x if type(x) is list else [str(x)] for x in l] + l = [x if type(x) is str else str(x) for x in l] + + l = [x.split("\n") for x in l] def center(r, w): k = w - len(r) @@ -90,7 +92,7 @@ def fusion_multi_lines(l, width_min=0): return "\n".join(["|".join([o[k] for o in l]) for k in range(h)]) -class PicroCrafterEngine: +class PicroCrafterEnvironment: def __init__( self, world_height=27, @@ -246,7 +248,7 @@ class PicroCrafterEngine: else: return "?" - def nb_view_tiles(self): + def nb_state_token_values(self): return len(self.tiles) def min_max_reward(self): @@ -277,14 +279,18 @@ class PicroCrafterEngine: nb_hits = self.monster_moves() - alive_before = self.life_level_in_100th > 99 + alive_before = self.life_level_in_100th >= 100 + self.life_level_in_100th[alive_before] = ( self.life_level_in_100th[alive_before] + self.life_level_gain_100th - nb_hits[alive_before] * 100 ).clamp(max=self.life_level_max * 100 + 99) - alive_after = self.life_level_in_100th > 99 + + alive_after = self.life_level_in_100th >= 100 + self.worlds[torch.logical_not(alive_after)] = self.tile2id["#"] + reward = nb_hits * self.reward_per_hit for i in range(q.size(0)): @@ -311,6 +317,7 @@ class PicroCrafterEngine: ) reward[torch.logical_not(alive_before)] = 0 + return reward, inventory, self.life_level_in_100th // 100 def monster_moves(self): @@ -382,7 +389,10 @@ class PicroCrafterEngine: return nb_hits - def views(self): + def state_size(self): + return (self.view_height + 1) * self.view_width + + def state(self): i_height, i_width = ( self.view_height - 2 * self.world_margin, self.view_width - 2 * self.world_margin, @@ -418,9 +428,9 @@ class PicroCrafterEngine: device=v.device, ) - return v + return v.flatten(1), self.life_level_in_100th >= 100 - def seq2tiles(self, t, width=None): + def state2str(self, t, width=None): def tile(n): n = n.item() if n in self.id2tile: @@ -429,14 +439,14 @@ class PicroCrafterEngine: return "?" if t.dim() == 2: - return [self.seq2tiles(r, width) for r in t] + return [self.state2str(r, width) for r in t] if width is None: width = self.view_width t = t.reshape(-1, width) - t = ["".join([tile(n) for n in r]) for r in t] + t = "\n".join(["".join([tile(n) for n in r]) for r in t]) return t @@ -461,7 +471,7 @@ if __name__ == "__main__": char_conv = lambda x: to_ansi(to_unicode(x)) start_time = time.perf_counter() - engine = PicroCrafterEngine( + environment = PicroCrafterEnvironment( world_height=27, world_width=27, nb_walls=35, @@ -471,7 +481,7 @@ if __name__ == "__main__": device=device, ) - engine.reset(nb_agents) + environment.reset(nb_agents) print(f"timing {nb_agents/(time.perf_counter() - start_time)} init per s") @@ -487,24 +497,29 @@ if __name__ == "__main__": to_print = "" os.system("clear") - l = engine.seq2tiles(engine.worlds.flatten(1), width=engine.world_width) + l = environment.state2str( + environment.worlds.flatten(1), width=environment.world_width + ) to_print += char_conv(fusion_multi_lines(l)) + "\n\n" - views = engine.views() - action = torch.randint(engine.nb_actions(), (nb_agents,), device=device) + state, alive = environment.state() + action = alive * torch.randint( + environment.nb_actions(), (nb_agents,), device=device + ) - rewards, inventories, life_levels = engine.step(action) + rewards, inventories, life_levels = environment.step(action) if display: - l = engine.seq2tiles(views.flatten(1)) + l = environment.state2str(state) l = [ - v + [f"{engine.action2str(a.item())}/{r: 3d}"] + v + f"\n{environment.action2str(a.item())}/{r: 3d}" for (v, a, r) in zip(l, action, rewards) ] to_print += ( - char_conv(fusion_multi_lines(l, width_min=engine.world_width)) + "\n" + char_conv(fusion_multi_lines(l, width_min=environment.world_width)) + + "\n" ) print(to_print)