From: François Fleuret Date: Sun, 24 Mar 2024 08:05:51 +0000 (+0100) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=545b06257dd69285c6321d7bd713d819da74953b;p=picoclvr.git Update. --- diff --git a/evasion.py b/evasion.py index a06213e..5d9547f 100755 --- a/evasion.py +++ b/evasion.py @@ -6,8 +6,19 @@ from torch.nn import functional as F ###################################################################### +nb_state_codes = 4 +nb_rewards_codes = 3 +nb_actions_codes = 5 -def generate_sequence(nb, height=6, width=6, T=10): +first_state_code = 0 +first_rewards_code = first_state_code + nb_state_codes +first_actions_code = first_rewards_code + nb_rewards_codes +nb_codes = first_actions_code + nb_actions_codes + +###################################################################### + + +def generate_episodes(nb, height=6, width=6, T=10): rnd = torch.rand(nb, height, width) rnd[:, 0, :] = 0 rnd[:, -1, :] = 0 @@ -22,14 +33,14 @@ def generate_sequence(nb, height=6, width=6, T=10): ).long().reshape(rnd.size()) rnd = rnd * (1 - wall.clamp(max=1)) - seq = wall[:, None, :, :].expand(-1, T, -1, -1).clone() + states = wall[:, None, :, :].expand(-1, T, -1, -1).clone() - agent = torch.zeros(seq.size(), dtype=torch.int64) + agent = torch.zeros(states.size(), dtype=torch.int64) agent[:, 0, 0, 0] = 1 agent_actions = torch.randint(5, (nb, T)) rewards = torch.zeros(nb, T, dtype=torch.int64) - monster = torch.zeros(seq.size(), dtype=torch.int64) + monster = torch.zeros(states.size(), dtype=torch.int64) monster[:, 0, -1, -1] = 1 monster_actions = torch.randint(5, (nb, T)) @@ -77,47 +88,82 @@ def generate_sequence(nb, height=6, width=6, T=10): assert hit.min() == 0 and hit.max() <= 1 - rewards[:, t] = -hit + (1 - hit) * agent[:, t + 1, -1, -1] + rewards[:, t + 1] = -hit + (1 - hit) * agent[:, t + 1, -1, -1] - seq += 2 * agent + 3 * monster + states += 2 * agent + 3 * monster - return seq, agent_actions, rewards + return states, agent_actions, rewards ###################################################################### -def seq2str(seq, actions, rewards): - # symbols=" #@$" - # vert, hori, cross, thin_hori = "|", "-", "+", "-" +def episodes2seq(states, actions, rewards): + states = states.flatten(2) + first_state_code + actions = actions[:, :, None] + first_actions_code + rewards = (rewards[:, :, None] + 1) + first_rewards_code + + assert ( + states.min() >= first_state_code + and states.max() < first_state_code + nb_state_codes + ) + assert ( + actions.min() >= first_actions_code + and actions.max() < first_actions_code + nb_actions_codes + ) + assert ( + rewards.min() >= first_rewards_code + and rewards.max() < first_rewards_code + nb_rewards_codes + ) + + return torch.cat([states, actions, rewards], dim=2).flatten(1) + - symbols = " █@$" - vert, hori, cross, thin_hori = "║", "═", "╬", "─" - vert, hori, cross, thin_hori = "┃", "━", "╋", "─" +def seq2episodes(seq, height, width): + seq = seq.reshape(seq.size(0), -1, height * width + 2) + states = seq[:, :, : height * width] - first_state_code + states = states.reshape(states.size(0), states.size(1), height, width) + actions = seq[:, :, height * width] - first_actions_code + rewards = seq[:, :, height * width + 1] - first_rewards_code - 1 + return states, actions, rewards - # hline = ("+" + "-" * seq.size(-1)) * seq.size(1) + "+" + "\n" - hline = (cross + hori * seq.size(-1)) * seq.size(1) + cross + "\n" + +###################################################################### + + +def episodes2str(states, actions, rewards, unicode=False, ansi_colors=False): + if unicode: + symbols = " █@$" + # vert, hori, cross, thin_hori = "║", "═", "╬", "─" + vert, hori, cross, thin_hori = "┃", "━", "╋", "─" + else: + symbols = " #@$" + vert, hori, cross, thin_hori = "|", "-", "+", "-" + + hline = (cross + hori * states.size(-1)) * states.size(1) + cross + "\n" result = hline - for n in range(seq.size(0)): - for i in range(seq.size(2)): + for n in range(states.size(0)): + for i in range(states.size(2)): result += ( vert + vert.join( - ["".join([symbols[v.item()] for v in row]) for row in seq[n, :, i]] + [ + "".join([symbols[v.item()] for v in row]) + for row in states[n, :, i] + ] ) + vert + "\n" ) - # result += hline - result += (vert + thin_hori * seq.size(-1)) * seq.size(1) + vert + "\n" + result += (vert + thin_hori * states.size(-1)) * states.size(1) + vert + "\n" def status_bar(a, r): - a = "INESW"[a.item()] - r = f"{r.item()}" - return a + " " * (seq.size(-1) - len(a) - len(r)) + r + a = "ISNEW"[a.item()] + r = "" if r == 0 else f"{r.item()}" + return a + " " * (states.size(-1) - len(a) - len(r)) + r result += ( vert @@ -128,12 +174,18 @@ def seq2str(seq, actions, rewards): result += hline + if ansi_colors: + for u, c in [("$", 31), ("@", 32)]: + result = result.replace(u, f"\u001b[{c}m{u}\u001b[0m") + return result ###################################################################### if __name__ == "__main__": - seq, actions, rewards = generate_sequence(10, 4, 6, T=20) - - print(seq2str(seq, actions, rewards)) + height, width, T = 4, 6, 20 + states, actions, rewards = generate_episodes(3, height, width, T) + seq = episodes2seq(states, actions, rewards) + s, a, r = seq2episodes(seq, height, width) + print(episodes2str(s, a, r, unicode=True, ansi_colors=True))