Update.
author François Fleuret <francois@fleuret.org>
Sat, 22 Jun 2024 07:16:54 +0000 (09:16 +0200)
committer François Fleuret <francois@fleuret.org>
Sat, 22 Jun 2024 07:16:54 +0000 (09:16 +0200)
12 files changed:
expr.py [deleted file]
greed.py [deleted file]
grid.py [deleted file]
main.py
maze.py [deleted file]
picoclvr.py [deleted file]
qmlp.py [deleted file]
rpl.py [deleted file]
snake.py [deleted file]
stack.py [deleted file]
tasks.py
turing.py [deleted file]

diff --git a/expr.py b/expr.py
deleted file mode 100755 (executable)
index 685efd3..0000000
--- a/expr.py
+++ /dev/null
@@ -1,105 +0,0 @@
-#!/usr/bin/env python
-
-# Any copyright is dedicated to the Public Domain.
-# https://creativecommons.org/publicdomain/zero/1.0/
-
-# Written by Francois Fleuret <francois@fleuret.org>
-
-import math, re
-
-import torch, torchvision
-
-from torch import nn
-from torch.nn import functional as F
-
-
-def random_var(nb_variables=None, variables=None):
-    if variables is None:
-        return chr(ord("A") + torch.randint(nb_variables, (1,)).item())
-    else:
-        l = list(variables)
-        return l[torch.randint(len(l), (1,)).item()]
-
-
-def random_expr(variables, operand_max, budget):
-    if budget <= 5:
-        op = torch.randint(2, (1,)).item()
-        if op == 0 and len(variables) > 0:
-            return random_var(variables=variables)
-        else:
-            return str(torch.randint(operand_max + 1, (1,)).item())
-    else:
-        op = torch.randint(3, (1,)).item()
-        if op == 0:
-            e = random_expr(variables, operand_max, budget - 2)
-            if ("+" in e or "-" in e or "*" in e) and (e[0] != "(" or e[-1] != ")"):
-                return "(" + e + ")"
-            else:
-                return e
-        else:
-            b = 2 + torch.randint(budget - 5, (1,)).item()
-            e1 = random_expr(variables, operand_max, b)
-            e2 = random_expr(variables, operand_max, budget - b - 1)
-            if op == 1:
-                return e1 + "+" + e2
-            elif op == 2:
-                return e1 + "*" + e2
-
-
-def generate_program(nb_variables, operand_max, length):
-    s = ""
-    variables = set()
-
-    while len(s) < length:
-        v = random_var(nb_variables=nb_variables)
-        s += v + "=" + random_expr(variables, operand_max, budget=20) + ";"
-        variables.add(v)
-
-    return s, variables
-
-
-def generate_sequences(nb, nb_variables=5, length=20, operand_max=9, result_max=99):
-    assert nb_variables <= 26
-    sequences = []
-
-    for n in range(nb):
-        # We take length itself half of the time, and uniform between
-        # 1 and length otherwise. The actual length can be slightly
-        # greater
-
-        l = min(length, 1 + torch.randint(length * 2, (1,)).item())
-        result = None
-        while result == None or max(result.values()) > result_max:
-            p, v = generate_program(nb_variables, operand_max, l)
-            v = ", ".join(['"' + v + '": ' + v for v in v])
-            ldict = {}
-            exec(p + "result={" + v + "}", globals(), ldict)
-            result = ldict["result"]
-
-        k = list(result.keys())
-        k.sort()
-        sequences.append(p + " " + "".join([v + ":" + str(result[v]) + ";" for v in k]))
-
-    return sequences
-
-
-def extract_results(seq):
-    f = lambda a: (a[0], -1 if a[1] == "" else int(a[1]))
-    results = [
-        dict([f(tuple(x.split(":"))) for x in re.findall("[A-Z]:[0-9]*", s)])
-        for s in seq
-    ]
-    return results
-
-
-if __name__ == "__main__":
-    import time
-
-    start_time = time.perf_counter()
-    sequences = generate_sequences(1000, length=40)
-    end_time = time.perf_counter()
-    for s in sequences[:10]:
-        print(s)
-    print(f"{len(sequences) / (end_time - start_time):.02f} samples per second")
-
-    print(extract_results(sequences[:10]))
diff --git a/greed.py b/greed.py
deleted file mode 100755 (executable)
index 1025d7c..0000000
--- a/greed.py
+++ /dev/null
@@ -1,358 +0,0 @@
-#!/usr/bin/env python
-
-# Any copyright is dedicated to the Public Domain.
-# https://creativecommons.org/publicdomain/zero/1.0/
-
-# Written by Francois Fleuret <francois@fleuret.org>
-
-import torch
-
-from torch.nn import functional as F
-
-######################################################################
-
-REWARD_PLUS = 1
-REWARD_NONE = 0
-REWARD_MINUS = -1
-REWARD_UNKNOWN = 2
-
-
-class GreedWorld:
-    def __init__(self, height=6, width=6, T=10, nb_walls=3, nb_coins=2):
-        self.height = height
-        self.width = width
-        self.T = T
-        self.nb_walls = nb_walls
-        self.nb_coins = nb_coins
-
-        self.nb_states_codes = 5
-        self.nb_actions_codes = 5
-        self.nb_rewards_codes = 3
-        self.nb_lookahead_rewards_codes = 4  # stands for -1, 0, +1, and UNKNOWN
-
-        self.first_states_code = 0
-        self.first_actions_code = self.first_states_code + self.nb_states_codes
-        self.first_rewards_code = self.first_actions_code + self.nb_actions_codes
-        self.first_lookahead_rewards_code = (
-            self.first_rewards_code + self.nb_rewards_codes
-        )
-        self.nb_codes = (
-            self.first_lookahead_rewards_code + self.nb_lookahead_rewards_codes
-        )
-
-        self.state_len = self.height * self.width
-        self.index_lookahead_reward = 0
-        self.index_states = 1
-        self.index_reward = self.state_len + 1
-        self.index_action = self.state_len + 2
-        self.it_len = self.state_len + 3  # lookahead_reward / state / reward / action
-
-    def state2code(self, r):
-        return r + self.first_states_code
-
-    def code2state(self, r):
-        return r - self.first_states_code
-
-    def action2code(self, r):
-        return r + self.first_actions_code
-
-    def code2action(self, r):
-        return r - self.first_actions_code
-
-    def reward2code(self, r):
-        return r + 1 + self.first_rewards_code
-
-    def code2reward(self, r):
-        return r - self.first_rewards_code - 1
-
-    def lookahead_reward2code(self, r):
-        # -1, 0, +1 or 2 for UNKNOWN
-        return r + 1 + self.first_lookahead_rewards_code
-
-    def code2lookahead_reward(self, r):
-        return r - self.first_lookahead_rewards_code - 1
-
-    ######################################################################
-
-    def generate_episodes(self, nb):
-        rnd = torch.rand(nb, self.height, self.width)
-        rnd[:, 0, :] = 0
-        rnd[:, -1, :] = 0
-        rnd[:, :, 0] = 0
-        rnd[:, :, -1] = 0
-        wall = 0
-        for k in range(self.nb_walls):
-            wall = wall + (
-                rnd.flatten(1).argmax(dim=1)[:, None]
-                == torch.arange(rnd.flatten(1).size(1))[None, :]
-            ).long().reshape(rnd.size())
-
-            rnd = rnd * (1 - wall.clamp(max=1))
-
-        rnd = torch.rand(nb, self.height, self.width)
-        rnd[:, 0, 0] = 0  # Do not put coin at the agent's starting
-        # position
-        coins = torch.zeros(nb, self.T, self.height, self.width, dtype=torch.int64)
-        rnd = rnd * (1 - wall.clamp(max=1))
-        for k in range(self.nb_coins):
-            coins[:, 0] = coins[:, 0] + (
-                rnd.flatten(1).argmax(dim=1)[:, None]
-                == torch.arange(rnd.flatten(1).size(1))[None, :]
-            ).long().reshape(rnd.size())
-
-            rnd = rnd * (1 - coins[:, 0].clamp(max=1))
-
-        states = wall[:, None, :, :].expand(-1, self.T, -1, -1).clone()
-
-        agent = torch.zeros(states.size(), dtype=torch.int64)
-        agent[:, 0, 0, 0] = 1
-        agent_actions = torch.randint(5, (nb, self.T))
-        rewards = torch.zeros(nb, self.T, dtype=torch.int64)
-
-        troll = torch.zeros(states.size(), dtype=torch.int64)
-        troll[:, 0, -1, -1] = 1
-        troll_actions = torch.randint(5, (nb, self.T))
-
-        all_moves = agent.new(nb, 5, self.height, self.width)
-        for t in range(self.T - 1):
-            all_moves.zero_()
-            all_moves[:, 0] = agent[:, t]
-            all_moves[:, 1, 1:, :] = agent[:, t, :-1, :]
-            all_moves[:, 2, :-1, :] = agent[:, t, 1:, :]
-            all_moves[:, 3, :, 1:] = agent[:, t, :, :-1]
-            all_moves[:, 4, :, :-1] = agent[:, t, :, 1:]
-            a = F.one_hot(agent_actions[:, t], num_classes=5)[:, :, None, None]
-            after_move = (all_moves * a).sum(dim=1)
-            collision = (
-                (after_move * (1 - wall) * (1 - troll[:, t]))
-                .flatten(1)
-                .sum(dim=1)[:, None, None]
-                == 0
-            ).long()
-            agent[:, t + 1] = collision * agent[:, t] + (1 - collision) * after_move
-
-            all_moves.zero_()
-            all_moves[:, 0] = troll[:, t]
-            all_moves[:, 1, 1:, :] = troll[:, t, :-1, :]
-            all_moves[:, 2, :-1, :] = troll[:, t, 1:, :]
-            all_moves[:, 3, :, 1:] = troll[:, t, :, :-1]
-            all_moves[:, 4, :, :-1] = troll[:, t, :, 1:]
-            a = F.one_hot(troll_actions[:, t], num_classes=5)[:, :, None, None]
-            after_move = (all_moves * a).sum(dim=1)
-            collision = (
-                (after_move * (1 - wall) * (1 - agent[:, t + 1]))
-                .flatten(1)
-                .sum(dim=1)[:, None, None]
-                == 0
-            ).long()
-            troll[:, t + 1] = collision * troll[:, t] + (1 - collision) * after_move
-
-            hit = (
-                (agent[:, t + 1, 1:, :] * troll[:, t + 1, :-1, :]).flatten(1).sum(dim=1)
-                + (agent[:, t + 1, :-1, :] * troll[:, t + 1, 1:, :])
-                .flatten(1)
-                .sum(dim=1)
-                + (agent[:, t + 1, :, 1:] * troll[:, t + 1, :, :-1])
-                .flatten(1)
-                .sum(dim=1)
-                + (agent[:, t + 1, :, :-1] * troll[:, t + 1, :, 1:])
-                .flatten(1)
-                .sum(dim=1)
-            )
-            hit = (hit > 0).long()
-
-            # assert hit.min() == 0 and hit.max() <= 1
-
-            got_coin = (agent[:, t + 1] * coins[:, t]).flatten(1).sum(dim=1)
-            coins[:, t + 1] = coins[:, t] * (1 - agent[:, t + 1])
-
-            rewards[:, t + 1] = -hit + (1 - hit) * got_coin
-
-        states = states + 2 * agent + 3 * troll + 4 * coins * (1 - troll)
-
-        return states, agent_actions, rewards
-
-    ######################################################################
-
-    def episodes2seq(self, states, actions, rewards):
-        neg = rewards.new_zeros(rewards.size())
-        pos = rewards.new_zeros(rewards.size())
-        for t in range(neg.size(1)):
-            neg[:, t] = rewards[:, t:].min(dim=-1).values
-            pos[:, t] = rewards[:, t:].max(dim=-1).values
-        s = (neg < 0).long() * neg + (neg >= 0).long() * pos
-
-        return torch.cat(
-            [
-                self.lookahead_reward2code(s[:, :, None]),
-                self.state2code(states.flatten(2)),
-                self.reward2code(rewards[:, :, None]),
-                self.action2code(actions[:, :, None]),
-            ],
-            dim=2,
-        ).flatten(1)
-
-    def seq2episodes(self, seq):
-        seq = seq.reshape(seq.size(0), -1, self.height * self.width + 3)
-        lookahead_rewards = self.code2lookahead_reward(
-            seq[:, :, self.index_lookahead_reward]
-        )
-        states = self.code2state(
-            seq[:, :, self.index_states : self.height * self.width + self.index_states]
-        )
-        states = states.reshape(states.size(0), states.size(1), self.height, self.width)
-        actions = self.code2action(seq[:, :, self.index_action])
-        rewards = self.code2reward(seq[:, :, self.index_reward])
-        return lookahead_rewards, states, actions, rewards
-
-    def seq2str(self, seq):
-        def token2str(t):
-            if (
-                t >= self.first_states_code
-                and t < self.first_states_code + self.nb_states_codes
-            ):
-                return "_#@T$"[t - self.first_states_code]
-            elif (
-                t >= self.first_actions_code
-                and t < self.first_actions_code + self.nb_actions_codes
-            ):
-                return "ISNEW"[t - self.first_actions_code]
-            elif (
-                t >= self.first_rewards_code
-                and t < self.first_rewards_code + self.nb_rewards_codes
-            ):
-                return "-0+"[t - self.first_rewards_code]
-            elif (
-                t >= self.first_lookahead_rewards_code
-                and t
-                < self.first_lookahead_rewards_code + self.nb_lookahead_rewards_codes
-            ):
-                return "n.pU"[t - self.first_lookahead_rewards_code]
-            else:
-                return "?"
-
-        return ["".join([token2str(x.item()) for x in row]) for row in seq]
-
-    ######################################################################
-
-    def episodes2str(
-        self,
-        lookahead_rewards,
-        states,
-        actions,
-        rewards,
-        unicode=False,
-        ansi_colors=False,
-    ):
-        if unicode:
-            symbols = "·█@T$"
-            # vert, hori, cross, thin_hori = "║", "═", "╬", "─"
-            vert, hori, cross, thin_vert, thin_hori = "┃", "━", "╋", "│", "─"
-        else:
-            symbols = " #@T$"
-            vert, hori, cross, thin_vert, thin_hori = "|", "-", "+", "|", "-"
-
-        hline = (cross + hori * states.size(-1)) * states.size(1) + cross + "\n"
-
-        result = hline
-
-        for n in range(states.size(0)):
-
-            def state_symbol(v):
-                v = v.item()
-                return "?" if v < 0 or v >= len(symbols) else symbols[v]
-
-            for i in range(states.size(2)):
-                result += (
-                    vert
-                    + vert.join(
-                        [
-                            "".join([state_symbol(v) for v in row])
-                            for row in states[n, :, i]
-                        ]
-                    )
-                    + vert
-                    + "\n"
-                )
-
-            # result += (vert + thin_hori * states.size(-1)) * states.size(1) + vert + "\n"
-
-            def status_bar(a, r, lr=None):
-                a, r = a.item(), r.item()
-                sb_a = "ISNEW"[a] if a >= 0 and a < 5 else "?"
-                sb_r = "- +"[r + 1] if r in {-1, 0, 1} else "?"
-                if lr is None:
-                    sb_lr = ""
-                else:
-                    lr = lr.item()
-                    sb_lr = "n pU"[lr + 1] if lr in {-1, 0, 1, 2} else "?"
-                return (
-                    sb_a
-                    + "/"
-                    + sb_r
-                    + " " * (states.size(-1) - 1 - len(sb_a + sb_r + sb_lr))
-                    + sb_lr
-                )
-
-            result += (
-                vert
-                + vert.join(
-                    [
-                        status_bar(a, r, lr)
-                        for a, r, lr in zip(
-                            actions[n], rewards[n], lookahead_rewards[n]
-                        )
-                    ]
-                )
-                + vert
-                + "\n"
-            )
-
-            result += hline
-
-        if ansi_colors:
-            for u, c in [("T", 31), ("@", 32), ("$", 34)]:
-                result = result.replace(u, f"\u001b[{c}m{u}\u001b[0m")
-
-        return result
-
-    ######################################################################
-
-    def save_seq_as_anim_script(self, seq, filename):
-        it_len = self.height * self.width + 3
-
-        seq = (
-            seq.reshape(seq.size(0), -1, it_len)
-            .permute(1, 0, 2)
-            .reshape(self.T, seq.size(0), -1)
-        )
-
-        with open(filename, "w") as f:
-            for t in range(self.T):
-                # f.write("clear\n")
-                f.write("cat << EOF\n")
-                f.write("\u001b[H")
-                # for i in range(seq.size(2)):
-                # lr, s, a, r = seq2episodes(seq[t : t + 1, :, i], self.height, self.width)
-                lr, s, a, r = self.seq2episodes(seq[t : t + 1, :].reshape(8, -1))
-                f.write(self.episodes2str(lr, s, a, r, unicode=True, ansi_colors=True))
-                f.write("EOF\n")
-                f.write("sleep 0.25\n")
-            print(f"Saved {filename}")
-
-
-if __name__ == "__main__":
-    gw = GreedWorld(height=5, width=7, T=10, nb_walls=4, nb_coins=2)
-    states, actions, rewards = gw.generate_episodes(nb=6)
-    seq = gw.episodes2seq(states, actions, rewards)
-    lr, s, a, r = gw.seq2episodes(seq)
-    print(gw.episodes2str(lr, s, a, r, unicode=True, ansi_colors=True))
-
-    print()
-    for s in gw.seq2str(seq):
-        print(s)
-
-    gw = GreedWorld(height=5, width=7, T=100, nb_walls=4, nb_coins=2)
-    states, actions, rewards = gw.generate_episodes(nb=128)
-    seq = gw.episodes2seq(states, actions, rewards)
-    gw.save_seq_as_anim_script(seq, "anim.sh")
diff --git a/grid.py b/grid.py
deleted file mode 100755 (executable)
index 1287ad5..0000000
--- a/grid.py
+++ /dev/null
@@ -1,323 +0,0 @@
-#!/usr/bin/env python
-
-# Any copyright is dedicated to the Public Domain.
-# https://creativecommons.org/publicdomain/zero/1.0/
-
-# Written by Francois Fleuret <francois@fleuret.org>
-
-import math
-import torch, torchvision
-import torch.nn.functional as F
-
-######################################################################
-
-
-class GridFactory:
-    def __init__(
-        self,
-        size=6,
-        max_nb_items=4,
-        max_nb_transformations=3,
-        nb_questions=4,
-        nb_shapes=6,
-        nb_colors=6,
-        nb_play_steps=3,
-    ):
-        assert size % 2 == 0
-        self.size = size
-        self.max_nb_items = max_nb_items
-        self.max_nb_transformations = max_nb_transformations
-        self.nb_questions = nb_questions
-        self.nb_play_steps = nb_play_steps
-        self.name_shapes = ["A", "B", "C", "D", "E", "F"]
-        self.name_colors = ["red", "yellow", "blue", "green", "white", "purple"]
-        self.vname_shapes = ["vA", "vB", "vC", "vD", "vE", "vF"]
-        self.vname_colors = ["vred", "vyellow", "vblue", "vgreen", "vwhite", "vpurple"]
-
-    def generate_scene(self):
-        nb_items = torch.randint(self.max_nb_items - 1, (1,)).item() + 2
-        col = torch.full((self.size * self.size,), -1)
-        shp = torch.full((self.size * self.size,), -1)
-        a = torch.randperm(len(self.name_colors) * len(self.name_shapes))[:nb_items]
-        col[:nb_items] = a % len(self.name_colors)
-        shp[:nb_items] = a // len(self.name_colors)
-        i = torch.randperm(self.size * self.size)
-        col = col[i]
-        shp = shp[i]
-        return col.reshape(self.size, self.size), shp.reshape(self.size, self.size)
-
-    def random_object_move(self, scene):
-        col, shp = scene
-        while True:
-            a = (col.flatten() >= 0).nonzero()
-            a = a[torch.randint(a.size(0), (1,)).item()]
-            i, j = a // self.size, a % self.size
-            assert col[i, j] >= 0
-            dst = [(i, j), (i - 1, j), (i + 1, j), (i, j - 1), (i, j + 1)]
-            dst = list(
-                filter(
-                    lambda x: x[0] >= 0
-                    and x[1] >= 0
-                    and x[0] < self.size
-                    and x[1] < self.size
-                    and col[x[0], x[1]] < 0,
-                    dst,
-                )
-            )
-            if len(dst) > 0:
-                ni, nj = dst[torch.randint(len(dst), (1,)).item()]
-                col[ni, nj] = col[i, j]
-                shp[ni, nj] = shp[i, j]
-                col[i, j] = -1
-                shp[i, j] = -1
-                break
-
-        return col, shp
-
-    def transformation(self, t, scene):
-        col, shp = scene
-        if t == 0:
-            col, shp = col.flip(0), shp.flip(0)
-            description = "<chg> vertical flip"
-        elif t == 1:
-            col, shp = col.flip(1), shp.flip(1)
-            description = "<chg> horizontal flip"
-        elif t == 2:
-            col, shp = col.flip(0).t(), shp.flip(0).t()
-            description = "<chg> rotate 90 degrees"
-        elif t == 3:
-            col, shp = col.flip(0).flip(1), shp.flip(0).flip(1)
-            description = "<chg> rotate 180 degrees"
-        elif t == 4:
-            col, shp = col.flip(1).t(), shp.flip(1).t()
-            description = "<chg> rotate 270 degrees"
-
-        return (col.contiguous(), shp.contiguous()), description
-
-    def random_transformations(self, scene):
-        descriptions = []
-        nb_transformations = torch.randint(self.max_nb_transformations + 1, (1,)).item()
-        transformations = torch.randint(5, (nb_transformations,))
-
-        for t in transformations:
-            scene, description = self.transformation(t, scene)
-            descriptions += [description]
-
-        return scene, descriptions
-
-    def visual_scene2str(self, scene):
-        col, shp = scene
-        r = []
-        for i in range(self.size):
-            s = []
-            for j in range(self.size):
-                if col[i, j] >= 0:
-                    s += [self.vname_colors[col[i, j]], self.vname_shapes[shp[i, j]]]
-                else:
-                    s += ["v_", "v+"]
-            r += s  # .append(" ".join(s))
-        return " ".join(r)
-
-    def print_scene(self, scene):
-        col, shp = scene
-
-        # for i in range(self.size):
-        # for j in range(self.size):
-        # if col[i,j] >= 0:
-        # print(f"at ({i},{j}) {self.name_colors[col[i,j]]} {self.name_shapes[shp[i,j]]}")
-
-        for i in range(self.size):
-            for j in range(self.size):
-                if col[i, j] >= 0:
-                    print(
-                        f"{self.name_colors[col[i,j]][0]}{self.name_shapes[shp[i,j]]}",
-                        end="",
-                    )
-                elif j == 0:
-                    print(" +", end="")
-                else:
-                    print("-+", end="")
-                if j < self.size - 1:
-                    print("--", end="")
-                else:
-                    print("")
-            if i < self.size - 1:
-                for j in range(self.size - 1):
-                    print(" |  ", end="")
-                print(" |")
-
-    def grid_positions(self, scene):
-        col, shp = scene
-
-        properties = []
-
-        for i in range(self.size):
-            for j in range(self.size):
-                if col[i, j] >= 0:
-                    n = f"{self.name_colors[col[i,j]]} {self.name_shapes[shp[i,j]]}"
-                    properties += [f"a {n} at {i} {j}"]
-
-        return properties
-
-    def all_properties(self, scene):
-        col, shp = scene
-
-        properties = []
-
-        for i1 in range(self.size):
-            for j1 in range(self.size):
-                if col[i1, j1] >= 0:
-                    n1 = (
-                        f"{self.name_colors[col[i1,j1]]} {self.name_shapes[shp[i1,j1]]}"
-                    )
-                    properties += [f"there is a {n1}"]
-                    if i1 < self.size // 2:
-                        properties += [f"a {n1} is in the top half"]
-                    if i1 >= self.size // 2:
-                        properties += [f"a {n1} is in the bottom half"]
-                    if j1 < self.size // 2:
-                        properties += [f"a {n1} is in the left half"]
-                    if j1 >= self.size // 2:
-                        properties += [f"a {n1} is in the right half"]
-                    for i2 in range(self.size):
-                        for j2 in range(self.size):
-                            if col[i2, j2] >= 0:
-                                n2 = f"{self.name_colors[col[i2,j2]]} {self.name_shapes[shp[i2,j2]]}"
-                                if i1 > i2:
-                                    properties += [f"a {n1} is below a {n2}"]
-                                if i1 < i2:
-                                    properties += [f"a {n1} is above a {n2}"]
-                                if j1 > j2:
-                                    properties += [f"a {n1} is right of a {n2}"]
-                                if j1 < j2:
-                                    properties += [f"a {n1} is left of a {n2}"]
-                                if abs(i1 - i2) + abs(j1 - j2) == 1:
-                                    properties += [f"a {n1} is next to a {n2}"]
-
-        return properties
-
-    def generate_scene_and_play(self):
-        scene = self.generate_scene()
-        steps = [self.visual_scene2str(scene)]
-        for t in range(self.nb_play_steps - 1):
-            if torch.randint(4, (1,)).item() == 0:
-                scene, _ = self.transformation(torch.randint(5, (1,)), scene)
-            else:
-                scene = self.random_object_move(scene)
-            steps.append(self.visual_scene2str(scene))
-        return " | ".join(steps)
-
-    def generate_scene_and_questions(self):
-        while True:
-            # We generate scenes until we get one with enough
-            # properties
-
-            while True:
-                start_scene = self.generate_scene()
-                scene, transformations = self.random_transformations(start_scene)
-                true = self.all_properties(scene)
-                if len(true) >= self.nb_questions:
-                    break
-
-            # We generate a bunch of false properties by shuffling the
-            # scene and sometimes adding properties from totally
-            # different scenes. We try ten times to get enough false
-            # properties and go back to generating the scene if we do
-            # not succeed
-
-            for a in range(10):
-                col, shp = scene
-                col, shp = col.view(-1), shp.view(-1)
-                p = torch.randperm(col.size(0))
-                col, shp = col[p], shp[p]
-                other_scene = (
-                    col.view(self.size, self.size),
-                    shp.view(self.size, self.size),
-                )
-
-                false = self.all_properties(other_scene)
-
-                # We sometime add properties from a totally different
-                # scene to have negative "there is a xxx xxx"
-                # properties
-
-                if torch.rand(1).item() < 0.2:
-                    other_scene = self.generate_scene()
-                    false += self.all_properties(other_scene)
-
-                false = list(set(false) - set(true))
-                if len(false) >= self.nb_questions:
-                    break
-
-            if a < 10:
-                break
-
-        true = [true[k] for k in torch.randperm(len(true))[: self.nb_questions]]
-        false = [false[k] for k in torch.randperm(len(false))[: self.nb_questions]]
-        true = ["<prop> " + q + " <ans> true" for q in true]
-        false = ["<prop> " + q + " <ans> false" for q in false]
-
-        union = true + false
-        questions = [union[k] for k in torch.randperm(len(union))[: self.nb_questions]]
-
-        result = " ".join(
-            ["<obj> " + x for x in self.grid_positions(start_scene)]
-            + transformations
-            + questions
-        )
-
-        return start_scene, scene, result
-
-    def generate_samples(self, nb, fraction_play=0.0, progress_bar=None):
-        result = []
-
-        play = torch.rand(nb) < fraction_play
-        if progress_bar is not None:
-            play = progress_bar(play)
-
-        for p in play:
-            if p:
-                result.append(self.generate_scene_and_play())
-            else:
-                result.append(self.generate_scene_and_questions()[2])
-
-        return result
-
-
-######################################################################
-
-if __name__ == "__main__":
-    import time
-
-    grid_factory = GridFactory()
-
-    # start_time = time.perf_counter()
-    # samples = grid_factory.generate_samples(10000)
-    # end_time = time.perf_counter()
-    # print(f"{len(samples) / (end_time - start_time):.02f} samples per second")
-
-    start_scene, scene, questions = grid_factory.generate_scene_and_questions()
-    print()
-    print("-- Original scene -----------------------------")
-    print()
-    grid_factory.print_scene(start_scene)
-    print()
-    print("-- Transformed scene --------------------------")
-    print()
-    grid_factory.print_scene(scene)
-    print()
-    print("-- Sequence -----------------------------------")
-    print()
-    print(questions)
-
-    # print(grid_factory.visual_scene2str(scene))
-
-    # grid_factory.print_scene(scene)
-    # for t in range(5):
-    # scene = grid_factory.random_object_move(scene)
-    # print()
-    # grid_factory.print_scene(scene)
-
-    print(grid_factory.generate_scene_and_play())
-
-######################################################################
diff --git a/main.py b/main.py
index 6c27599..97c7130 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -29,12 +29,7 @@ parser = argparse.ArgumentParser(
     formatter_class=argparse.ArgumentDefaultsHelpFormatter,
 )
 
-parser.add_argument(
-    "--task",
-    type=str,
-    default="world",
-    help="file, byheart, learnop, guessop, mixing, memory, twotargets, addition, picoclvr, mnist, maze, snake, stack, expr, rpl, grid, qmlp, greed",
-)
+parser.add_argument("--task", type=str, default="world", help="world")
 
 parser.add_argument("--log_filename", type=str, default="train.log", help=" ")
 
@@ -78,119 +73,10 @@ parser.add_argument("--dropout", type=float, default=0.1)
 
 parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
 
-##############################
-# filetask
-
-parser.add_argument("--filetask_train_file", type=str, default=None)
-
-parser.add_argument("--filetask_test_file", type=str, default=None)
-
-##############################
-# rpl options
-
-parser.add_argument("--rpl_nb_starting_values", type=int, default=3)
-
-parser.add_argument("--rpl_max_input", type=int, default=9)
-
-parser.add_argument("--rpl_prog_len", type=int, default=8)
-
-parser.add_argument("--rpl_nb_runs", type=int, default=5)
-
-parser.add_argument("--rpl_no_prog", action="store_true", default=False)
-
-##############################
-# grid options
-
-parser.add_argument("--grid_size", type=int, default=6)
-
-parser.add_argument("--grid_fraction_play", type=float, default=0)
-
-##############################
-# picoclvr options
-
-parser.add_argument("--picoclvr_nb_colors", type=int, default=5)
-
-parser.add_argument("--picoclvr_height", type=int, default=12)
-
-parser.add_argument("--picoclvr_width", type=int, default=16)
-
-parser.add_argument("--picocvlr_prune_properties", type=str, default="none")
-
-##############################
-# Maze options
-
-parser.add_argument("--maze_height", type=int, default=13)
-
-parser.add_argument("--maze_width", type=int, default=21)
-
-parser.add_argument("--maze_nb_walls", type=int, default=15)
-
-##############################
-# Snake options
-
-parser.add_argument("--snake_height", type=int, default=9)
-
-parser.add_argument("--snake_width", type=int, default=12)
-
-parser.add_argument("--snake_nb_colors", type=int, default=5)
-
-parser.add_argument("--snake_length", type=int, default=200)
-
-##############################
-# ByHeart options
-
-parser.add_argument("--byheart_separation", type=int, default=1)
-
-##############################
-# Stack options
-
-parser.add_argument("--stack_nb_steps", type=int, default=100)
-
-parser.add_argument("--stack_nb_stacks", type=int, default=3)
-
-parser.add_argument("--stack_nb_digits", type=int, default=3)
-
-parser.add_argument("--stack_fraction_values_for_train", type=float, default=None)
-
-##############################
-# Expr options
-
-parser.add_argument("--expr_nb_variables", type=int, default=5)
-
-parser.add_argument("--expr_sequence_length", type=int, default=40)
-
-parser.add_argument("--expr_operand_max", type=int, default=9)
-
-parser.add_argument("--expr_result_max", type=int, default=99)
-
-parser.add_argument("--expr_input_file", type=str, default=None)
-
-##############################
-# Mixing
-
-parser.add_argument("--mixing_hard", action="store_true", default=False)
-
-parser.add_argument("--mixing_deterministic_start", action="store_true", default=False)
-
-##############################
-# greed options
-
-parser.add_argument("--greed_height", type=int, default=5)
-
-parser.add_argument("--greed_width", type=int, default=7)
-
-parser.add_argument("--greed_T", type=int, default=25)
-
-parser.add_argument("--greed_nb_walls", type=int, default=5)
-
-parser.add_argument("--greed_nb_coins", type=int, default=2)
-
 ######################################################################
 
 args = parser.parse_args()
 
-assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"}
-
 if args.result_dir is None:
     args.result_dir = f"results_{args.task}"
 
@@ -203,114 +89,6 @@ default_task_args = {
         "nb_train_samples": 250000,
         "nb_test_samples": 10000,
     },
-    "file": {
-        "model": "37M",
-        "batch_size": 25,
-        "nb_train_samples": 250000,
-        "nb_test_samples": 10000,
-    },
-    "addition": {
-        "model": "352M",
-        "batch_size": 25,
-        "nb_train_samples": 250000,
-        "nb_test_samples": 10000,
-    },
-    "byheart": {
-        "model": "37M",
-        "batch_size": 25,
-        "nb_train_samples": 50000,
-        "nb_test_samples": 10000,
-    },
-    "expr": {
-        "model": "352M",
-        "batch_size": 25,
-        "nb_train_samples": 2500000,
-        "nb_test_samples": 10000,
-    },
-    "grid": {
-        "model": "37M",
-        "batch_size": 25,
-        "nb_train_samples": 250000,
-        "nb_test_samples": 10000,
-    },
-    "qmlp": {
-        "model": "37M",
-        "batch_size": 10,
-        "nb_train_samples": 100000,
-        "nb_test_samples": 1000,
-    },
-    "guessop": {
-        "model": "352M",
-        "batch_size": 25,
-        "nb_train_samples": 1000000,
-        "nb_test_samples": 10000,
-    },
-    "learnop": {
-        "model": "37M",
-        "batch_size": 25,
-        "nb_train_samples": 50000,
-        "nb_test_samples": 10000,
-    },
-    "maze": {
-        "model": "37M",
-        "batch_size": 5,
-        "nb_train_samples": 100000,
-        "nb_test_samples": 10000,
-    },
-    "picoclvr": {
-        "model": "37M",
-        "batch_size": 25,
-        "nb_train_samples": 250000,
-        "nb_test_samples": 10000,
-    },
-    "rpl": {
-        "model": "352M",
-        "batch_size": 5,
-        "nb_train_samples": 2500000,
-        "nb_test_samples": 10000,
-    },
-    "snake": {
-        "model": "37M",
-        "batch_size": 25,
-        "nb_train_samples": 250000,
-        "nb_test_samples": 10000,
-    },
-    "stack": {
-        "model": "37M",
-        "batch_size": 25,
-        "nb_train_samples": 100000,
-        "nb_test_samples": 1000,
-    },
-    "twotargets": {
-        "model": "37M",
-        "batch_size": 25,
-        "nb_train_samples": 50000,
-        "nb_test_samples": 10000,
-    },
-    "memory": {
-        "model": "37M",
-        "batch_size": 100,
-        "nb_train_samples": 25000,
-        "nb_test_samples": 1000,
-    },
-    "mixing": {
-        "model": "37M",
-        "batch_size": 25,
-        "nb_train_samples": 250000,
-        "nb_test_samples": 10000,
-    },
-    "mnist": {
-        "model": "37M",
-        "batch_size": 10,
-        "nb_train_samples": 60000,
-        "nb_test_samples": 10000,
-    },
-    "greed": {
-        "model": "37M",
-        "batch_size": 25,
-        "nb_train_samples": 25000,
-        "nb_test_samples": 10000,
-    },
 }
 
 if args.task in default_task_args:
@@ -406,24 +184,6 @@ for n in vars(args):
 ######################################################################
 
 
-def picoclvr_pruner_horizontal_green(p):
-    return not ("green" in p and ("left" in p or "right" in p))
-
-
-picoclvr_pruner_train = (
-    picoclvr_pruner_horizontal_green
-    if args.picocvlr_prune_properties in {"train+eval"}
-    else None
-)
-
-picoclvr_pruner_eval = (
-    (lambda p: not picoclvr_pruner_horizontal_green(p))
-    if args.picocvlr_prune_properties in {"train+eval", "eval"}
-    else None
-)
-
-######################################################################
-
 if args.physical_batch_size is None:
     args.physical_batch_size = args.batch_size
 else:
@@ -848,20 +608,22 @@ log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
 accuracy_to_make_quizzes = 0.975
 
 for n_epoch in range(args.nb_epochs):
+    # select the model with lowest accuracy
     models.sort(key=lambda model: model.main_test_accuracy)
-
     model = models[0]
 
     log_string(
         f"training model {model.id} main_test_accuracy {model.main_test_accuracy}"
     )
 
+    # improve it
     one_epoch(model, task)
 
     log_string(
         f"train_set_composition world {task.nb_batch_samples_world} quizzes {task.nb_batch_samples_quizzes}"
     )
 
+    # test it
     run_tests(model, task, deterministic_synthesis=False)
 
     if model.main_test_accuracy >= accuracy_to_make_quizzes:
diff --git a/maze.py b/maze.py
deleted file mode 100755 (executable)
index d5662f0..0000000
--- a/maze.py
+++ /dev/null
@@ -1,317 +0,0 @@
-#!/usr/bin/env python
-
-# Any copyright is dedicated to the Public Domain.
-# https://creativecommons.org/publicdomain/zero/1.0/
-
-# Written by Francois Fleuret <francois@fleuret.org>
-
-import torch, torchvision
-
-######################################################################
-
-v_empty, v_wall, v_start, v_goal, v_path = 0, 1, 2, 3, 4
-
-
-def create_maze(h=11, w=17, nb_walls=8):
-    assert h % 2 == 1 and w % 2 == 1
-
-    nb_attempts, nb_added_walls = 0, 0
-
-    while nb_added_walls < nb_walls:
-        while True:
-            if nb_attempts == 0:
-                m = torch.zeros(h, w, dtype=torch.int64)
-                m[0, :] = 1
-                m[-1, :] = 1
-                m[:, 0] = 1
-                m[:, -1] = 1
-
-            r = torch.rand(4)
-
-            if r[0] <= 0.5:
-                # Add a vertical wall
-                i1, i2, j = (
-                    int((r[1] * h).item()),
-                    int((r[2] * h).item()),
-                    int((r[3] * w).item()),
-                )
-                i1, i2, j = i1 - i1 % 2, i2 - i2 % 2, j - j % 2
-                i1, i2 = min(i1, i2), max(i1, i2)
-
-                # If this wall does not hit another one, add it
-                if i2 - i1 > 1 and i2 - i1 <= h / 2 and m[i1 : i2 + 1, j].sum() <= 1:
-                    m[i1 : i2 + 1, j] = 1
-                    break
-
-            else:
-                # Add an horizontal wall
-                i, j1, j2 = (
-                    int((r[1] * h).item()),
-                    int((r[2] * w).item()),
-                    int((r[3] * w).item()),
-                )
-                i, j1, j2 = i - i % 2, j1 - j1 % 2, j2 - j2 % 2
-                j1, j2 = min(j1, j2), max(j1, j2)
-
-                # If this wall does not hit another one, add it
-                if j2 - j1 > 1 and j2 - j1 <= w / 2 and m[i, j1 : j2 + 1].sum() <= 1:
-                    m[i, j1 : j2 + 1] = 1
-                    break
-
-            nb_attempts += 1
-
-            if nb_attempts > 10 * nb_walls:
-                nb_attempts, nb_added_walls = 0, 0
-
-        nb_added_walls += 1
-
-    return m
-
-
-######################################################################
-
-
-def compute_distance(walls, goal_i, goal_j):
-    max_length = walls.numel()
-    dist = torch.full_like(walls, max_length)
-
-    dist[goal_i, goal_j] = 0
-    pred_dist = torch.empty_like(dist)
-
-    while True:
-        pred_dist.copy_(dist)
-        d = (
-            torch.cat(
-                (
-                    dist[None, 1:-1, 0:-2],
-                    dist[None, 2:, 1:-1],
-                    dist[None, 1:-1, 2:],
-                    dist[None, 0:-2, 1:-1],
-                ),
-                0,
-            ).min(dim=0)[0]
-            + 1
-        )
-
-        dist[1:-1, 1:-1] = torch.min(dist[1:-1, 1:-1], d)
-        dist = walls * max_length + (1 - walls) * dist
-
-        if dist.equal(pred_dist):
-            return dist * (1 - walls)
-
-
-######################################################################
-
-
-def compute_policy(walls, goal_i, goal_j):
-    distance = compute_distance(walls, goal_i, goal_j)
-    distance = distance + walls.numel() * walls
-
-    value = distance.new_full((4,) + distance.size(), walls.numel())
-    value[0, :, 1:] = distance[:, :-1]  # <
-    value[1, :, :-1] = distance[:, 1:]  # >
-    value[2, 1:, :] = distance[:-1, :]  # ^
-    value[3, :-1, :] = distance[1:, :]  # v
-
-    proba = (value.min(dim=0)[0][None] == value).float()
-    proba = proba / proba.sum(dim=0)[None]
-    proba = proba * (1 - walls) + walls.float() / 4
-
-    return proba
-
-
-def stationary_densities(mazes, policies):
-    policies = policies * (mazes != v_goal)[:, None]
-    start = (mazes == v_start).nonzero(as_tuple=True)
-    probas = mazes.new_zeros(mazes.size(), dtype=torch.float32)
-    pred_probas = probas.clone()
-    probas[start] = 1.0
-
-    while not pred_probas.equal(probas):
-        pred_probas.copy_(probas)
-        probas.zero_()
-        probas[:, 1:, :] += pred_probas[:, :-1, :] * policies[:, 3, :-1, :]
-        probas[:, :-1, :] += pred_probas[:, 1:, :] * policies[:, 2, 1:, :]
-        probas[:, :, 1:] += pred_probas[:, :, :-1] * policies[:, 1, :, :-1]
-        probas[:, :, :-1] += pred_probas[:, :, 1:] * policies[:, 0, :, 1:]
-        probas[start] = 1.0
-
-    return probas
-
-
-######################################################################
-
-
-def mark_path(walls, i, j, goal_i, goal_j, policy):
-    action = torch.distributions.categorical.Categorical(
-        policy.permute(1, 2, 0)
-    ).sample()
-    n, nmax = 0, walls.numel()
-    while i != goal_i or j != goal_j:
-        di, dj = [(0, -1), (0, 1), (-1, 0), (1, 0)][action[i, j]]
-        i, j = i + di, j + dj
-        assert walls[i, j] == 0
-        walls[i, j] = v_path
-        n += 1
-        assert n < nmax
-
-
-def path_optimality(ref_paths, paths):
-    return (ref_paths == v_path).long().flatten(1).sum(1) == (
-        paths == v_path
-    ).long().flatten(1).sum(1)
-
-
-def path_correctness(mazes, paths):
-    still_ok = (mazes - (paths * (paths != v_path))).view(mazes.size(0), -1).abs().sum(
-        1
-    ) == 0
-    reached = still_ok.new_zeros(still_ok.size())
-    current, pred_current = paths.clone(), paths.new_zeros(paths.size())
-    goal = (mazes == v_goal).long()
-    while not pred_current.equal(current):
-        pred_current.copy_(current)
-        u = (current == v_start).long()
-        possible_next = (
-            u[:, 2:, 1:-1] + u[:, 0:-2, 1:-1] + u[:, 1:-1, 2:] + u[:, 1:-1, 0:-2] > 0
-        ).long()
-        u = u[:, 1:-1, 1:-1]
-        reached += ((goal[:, 1:-1, 1:-1] * possible_next).sum((1, 2)) == 1) * (
-            (current == v_path).sum((1, 2)) == 0
-        )
-        current[:, 1:-1, 1:-1] = (1 - u) * current[:, 1:-1, 1:-1] + (
-            v_start - v_path
-        ) * (possible_next * (current[:, 1:-1, 1:-1] == v_path))
-        still_ok *= (current == v_start).sum((1, 2)) <= 1
-
-    return still_ok * reached
-
-
-######################################################################
-
-
-def create_maze_data(
-    nb, height=11, width=17, nb_walls=8, dist_min=10, progress_bar=lambda x: x
-):
-    mazes = torch.empty(nb, height, width, dtype=torch.int64)
-    paths = torch.empty(nb, height, width, dtype=torch.int64)
-    policies = torch.empty(nb, 4, height, width)
-
-    for n in progress_bar(range(nb)):
-        maze = create_maze(height, width, nb_walls)
-        i = (maze == v_empty).nonzero()
-        while True:
-            start, goal = i[torch.randperm(i.size(0))[:2]]
-            if (start - goal).abs().sum() >= dist_min:
-                break
-        start_i, start_j, goal_i, goal_j = start[0], start[1], goal[0], goal[1]
-
-        policy = compute_policy(maze, goal_i, goal_j)
-        path = maze.clone()
-        mark_path(path, start_i, start_j, goal_i, goal_j, policy)
-        maze[start_i, start_j] = v_start
-        maze[goal_i, goal_j] = v_goal
-        path[start_i, start_j] = v_start
-        path[goal_i, goal_j] = v_goal
-
-        mazes[n] = maze
-        paths[n] = path
-        policies[n] = policy
-
-    return mazes, paths, policies
-
-
-######################################################################
-
-
-def save_image(
-    name,
-    mazes,
-    target_paths=None,
-    predicted_paths=None,
-    path_correct=None,
-    path_optimal=None,
-):
-    colors = torch.tensor(
-        [
-            [255, 255, 255],  # empty
-            [0, 0, 0],  # wall
-            [0, 255, 0],  # start
-            [127, 127, 255],  # goal
-            [255, 0, 0],  # path
-        ]
-    )
-
-    mazes = mazes.cpu()
-
-    c_mazes = (
-        colors[mazes.reshape(-1)].reshape(mazes.size() + (-1,)).permute(0, 3, 1, 2)
-    )
-
-    imgs = c_mazes.unsqueeze(1)
-
-    if target_paths is not None:
-        target_paths = target_paths.cpu()
-
-        c_target_paths = (
-            colors[target_paths.reshape(-1)]
-            .reshape(target_paths.size() + (-1,))
-            .permute(0, 3, 1, 2)
-        )
-
-        imgs = torch.cat((imgs, c_target_paths.unsqueeze(1)), 1)
-
-    if predicted_paths is not None:
-        predicted_paths = predicted_paths.cpu()
-        c_predicted_paths = (
-            colors[predicted_paths.reshape(-1)]
-            .reshape(predicted_paths.size() + (-1,))
-            .permute(0, 3, 1, 2)
-        )
-        imgs = torch.cat((imgs, c_predicted_paths.unsqueeze(1)), 1)
-
-    img = torch.tensor([255, 255, 0]).view(1, -1, 1, 1)
-
-    # NxKxCxHxW
-    if path_optimal is not None:
-        path_optimal = path_optimal.cpu().long().view(-1, 1, 1, 1)
-        img = (
-            img * (1 - path_optimal)
-            + torch.tensor([0, 255, 0]).view(1, -1, 1, 1) * path_optimal
-        )
-
-    if path_correct is not None:
-        path_correct = path_correct.cpu().long().view(-1, 1, 1, 1)
-        img = img * path_correct + torch.tensor([255, 0, 0]).view(1, -1, 1, 1) * (
-            1 - path_correct
-        )
-
-    img = img.expand(
-        -1, -1, imgs.size(3) + 2, 1 + imgs.size(1) * (1 + imgs.size(4))
-    ).clone()
-
-    print(f"{img.size()=} {imgs.size()=}")
-
-    for k in range(imgs.size(1)):
-        img[
-            :,
-            :,
-            1 : 1 + imgs.size(3),
-            1 + k * (1 + imgs.size(4)) : 1 + k * (1 + imgs.size(4)) + imgs.size(4),
-        ] = imgs[:, k]
-
-    img = img.float() / 255.0
-
-    torchvision.utils.save_image(img, name, nrow=4, padding=1, pad_value=224.0 / 256)
-
-
-######################################################################
-
-if __name__ == "__main__":
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    mazes, paths, policies = create_maze_data(8)
-    mazes, paths = mazes.to(device), paths.to(device)
-    save_image("test.png", mazes=mazes, target_paths=paths, predicted_paths=paths)
-    print(path_correctness(mazes, paths))
-
-######################################################################
diff --git a/picoclvr.py b/picoclvr.py
deleted file mode 100755 (executable)
index 0cd3062..0000000
+++ /dev/null
@@ -1,370 +0,0 @@
-#!/usr/bin/env python
-
-# Any copyright is dedicated to the Public Domain.
-# https://creativecommons.org/publicdomain/zero/1.0/
-
-# Written by Francois Fleuret <francois@fleuret.org>
-
-import math
-import torch, torchvision
-import torch.nn.functional as F
-
-color_name2rgb = {
-    "white": [255, 255, 255],
-    "red": [255, 0, 0],
-    "green": [0, 128, 0],
-    "blue": [0, 0, 255],
-    "yellow": [255, 255, 0],
-    "black": [0, 0, 0],
-    "maroon": [128, 0, 0],
-    "dark_red": [139, 0, 0],
-    "brown": [165, 42, 42],
-    "firebrick": [178, 34, 34],
-    "crimson": [220, 20, 60],
-    "tomato": [255, 99, 71],
-    "coral": [255, 127, 80],
-    "indian_red": [205, 92, 92],
-    "light_coral": [240, 128, 128],
-    "dark_salmon": [233, 150, 122],
-    "salmon": [250, 128, 114],
-    "light_salmon": [255, 160, 122],
-    "orange_red": [255, 69, 0],
-    "dark_orange": [255, 140, 0],
-    "orange": [255, 165, 0],
-    "gold": [255, 215, 0],
-    "dark_golden_rod": [184, 134, 11],
-    "golden_rod": [218, 165, 32],
-    "pale_golden_rod": [238, 232, 170],
-    "dark_khaki": [189, 183, 107],
-    "khaki": [240, 230, 140],
-    "olive": [128, 128, 0],
-    "yellow_green": [154, 205, 50],
-    "dark_olive_green": [85, 107, 47],
-    "olive_drab": [107, 142, 35],
-    "lawn_green": [124, 252, 0],
-    "chartreuse": [127, 255, 0],
-    "green_yellow": [173, 255, 47],
-    "dark_green": [0, 100, 0],
-    "forest_green": [34, 139, 34],
-    "lime": [0, 255, 0],
-    "lime_green": [50, 205, 50],
-    "light_green": [144, 238, 144],
-    "pale_green": [152, 251, 152],
-    "dark_sea_green": [143, 188, 143],
-    "medium_spring_green": [0, 250, 154],
-    "spring_green": [0, 255, 127],
-    "sea_green": [46, 139, 87],
-    "medium_aqua_marine": [102, 205, 170],
-    "medium_sea_green": [60, 179, 113],
-    "light_sea_green": [32, 178, 170],
-    "dark_slate_gray": [47, 79, 79],
-    "teal": [0, 128, 128],
-    "dark_cyan": [0, 139, 139],
-    "aqua": [0, 255, 255],
-    "cyan": [0, 255, 255],
-    "light_cyan": [224, 255, 255],
-    "dark_turquoise": [0, 206, 209],
-    "turquoise": [64, 224, 208],
-    "medium_turquoise": [72, 209, 204],
-    "pale_turquoise": [175, 238, 238],
-    "aqua_marine": [127, 255, 212],
-    "powder_blue": [176, 224, 230],
-    "cadet_blue": [95, 158, 160],
-    "steel_blue": [70, 130, 180],
-    "corn_flower_blue": [100, 149, 237],
-    "deep_sky_blue": [0, 191, 255],
-    "dodger_blue": [30, 144, 255],
-    "light_blue": [173, 216, 230],
-    "sky_blue": [135, 206, 235],
-    "light_sky_blue": [135, 206, 250],
-    "midnight_blue": [25, 25, 112],
-    "navy": [0, 0, 128],
-    "dark_blue": [0, 0, 139],
-    "medium_blue": [0, 0, 205],
-    "royal_blue": [65, 105, 225],
-    "blue_violet": [138, 43, 226],
-    "indigo": [75, 0, 130],
-    "dark_slate_blue": [72, 61, 139],
-    "slate_blue": [106, 90, 205],
-    "medium_slate_blue": [123, 104, 238],
-    "medium_purple": [147, 112, 219],
-    "dark_magenta": [139, 0, 139],
-    "dark_violet": [148, 0, 211],
-    "dark_orchid": [153, 50, 204],
-    "medium_orchid": [186, 85, 211],
-    "purple": [128, 0, 128],
-    "thistle": [216, 191, 216],
-    "plum": [221, 160, 221],
-    "violet": [238, 130, 238],
-    "magenta": [255, 0, 255],
-    "orchid": [218, 112, 214],
-    "medium_violet_red": [199, 21, 133],
-    "pale_violet_red": [219, 112, 147],
-    "deep_pink": [255, 20, 147],
-    "hot_pink": [255, 105, 180],
-    "light_pink": [255, 182, 193],
-    "pink": [255, 192, 203],
-    "antique_white": [250, 235, 215],
-    "beige": [245, 245, 220],
-    "bisque": [255, 228, 196],
-    "blanched_almond": [255, 235, 205],
-    "wheat": [245, 222, 179],
-    "corn_silk": [255, 248, 220],
-    "lemon_chiffon": [255, 250, 205],
-    "light_golden_rod_yellow": [250, 250, 210],
-    "light_yellow": [255, 255, 224],
-    "saddle_brown": [139, 69, 19],
-    "sienna": [160, 82, 45],
-    "chocolate": [210, 105, 30],
-    "peru": [205, 133, 63],
-    "sandy_brown": [244, 164, 96],
-    "burly_wood": [222, 184, 135],
-    "tan": [210, 180, 140],
-    "rosy_brown": [188, 143, 143],
-    "moccasin": [255, 228, 181],
-    "navajo_white": [255, 222, 173],
-    "peach_puff": [255, 218, 185],
-    "misty_rose": [255, 228, 225],
-    "lavender_blush": [255, 240, 245],
-    "linen": [250, 240, 230],
-    "old_lace": [253, 245, 230],
-    "papaya_whip": [255, 239, 213],
-    "sea_shell": [255, 245, 238],
-    "mint_cream": [245, 255, 250],
-    "slate_gray": [112, 128, 144],
-    "light_slate_gray": [119, 136, 153],
-    "light_steel_blue": [176, 196, 222],
-    "lavender": [230, 230, 250],
-    "floral_white": [255, 250, 240],
-    "alice_blue": [240, 248, 255],
-    "ghost_white": [248, 248, 255],
-    "honeydew": [240, 255, 240],
-    "ivory": [255, 255, 240],
-    "azure": [240, 255, 255],
-    "snow": [255, 250, 250],
-    "silver": [192, 192, 192],
-    "gainsboro": [220, 220, 220],
-    "white_smoke": [245, 245, 245],
-}
-
-color_name2id = dict([(n, k) for k, n in enumerate(color_name2rgb.keys())])
-color_id2name = dict([(k, n) for k, n in enumerate(color_name2rgb.keys())])
-
-######################################################################
-
-
-def all_properties(height, width, nb_squares, square_i, square_j, square_c):
-    s = []
-
-    for r, c_r in [(k, color_id2name[square_c[k].item()]) for k in range(nb_squares)]:
-        s += [f"there is {c_r}"]
-
-        if square_i[r] >= height - height // 3:
-            s += [f"{c_r} bottom"]
-        if square_i[r] < height // 3:
-            s += [f"{c_r} top"]
-        if square_j[r] >= width - width // 3:
-            s += [f"{c_r} right"]
-        if square_j[r] < width // 3:
-            s += [f"{c_r} left"]
-
-        for t, c_t in [
-            (k, color_id2name[square_c[k].item()]) for k in range(nb_squares)
-        ]:
-            if square_i[r] > square_i[t]:
-                s += [f"{c_r} below {c_t}"]
-            if square_i[r] < square_i[t]:
-                s += [f"{c_r} above {c_t}"]
-            if square_j[r] > square_j[t]:
-                s += [f"{c_r} right of {c_t}"]
-            if square_j[r] < square_j[t]:
-                s += [f"{c_r} left of {c_t}"]
-
-    return s
-
-
-######################################################################
-
-# Generates sequences
-
-
-def generate(
-    nb,
-    height,
-    width,
-    max_nb_squares=5,
-    max_nb_properties=10,
-    nb_colors=5,
-    pruner=None,
-):
-    assert nb_colors >= max_nb_squares and nb_colors <= len(color_name2rgb) - 1
-
-    descr = []
-
-    for n in range(nb):
-        # we want uniform over the combinations of 1 to max_nb_squares
-        # pixels of nb_colors
-        logits = math.log(nb_colors) * torch.arange(1, max_nb_squares + 1).float()
-        dist = torch.distributions.categorical.Categorical(logits=logits)
-        nb_squares = dist.sample((1,)) + 1
-        # nb_squares = torch.randint(max_nb_squares, (1,)) + 1
-        square_position = torch.randperm(height * width)[:nb_squares]
-
-        # color 0 is white and reserved for the background
-        square_c = torch.randperm(nb_colors)[:nb_squares] + 1
-        square_i = square_position.div(width, rounding_mode="floor")
-        square_j = square_position % width
-
-        img = torch.zeros(height * width, dtype=torch.int64)
-        for k in range(nb_squares):
-            img[square_position[k]] = square_c[k]
-
-        # generates all the true properties
-
-        s = all_properties(height, width, nb_squares, square_i, square_j, square_c)
-
-        if pruner is not None:
-            s = list(filter(pruner, s))
-
-        # picks at most max_nb_properties at random
-
-        nb_properties = torch.randint(max_nb_properties, (1,)) + 1
-        s = (
-            " <sep> ".join([s[k] for k in torch.randperm(len(s))[:nb_properties]])
-            + " <img> "
-            + " ".join([f"{color_id2name[n.item()]}" for n in img])
-        )
-
-        descr += [s]
-
-    return descr
-
-
-######################################################################
-
-# Extracts the image after <img> in descr as a 1x3xHxW tensor
-
-
-def descr2img(descr, height, width):
-    result = []
-
-    def token2color(t):
-        try:
-            return color_name2rgb[t]
-        except KeyError:
-            return [128, 128, 128]
-
-    for d in descr:
-        d = d.split("<img>")[1]
-        d = d.strip().split(" ")[: height * width]
-        d = d + ["<unk>"] * (height * width - len(d))
-        d = [token2color(t) for t in d]
-        img = torch.tensor(d).permute(1, 0).reshape(1, 3, height, width)
-        result.append(img)
-
-    return torch.cat(result, 0)
-
-
-######################################################################
-
-# Returns all the properties of the image after <img> in descr
-
-
-def descr2properties(descr, height, width):
-    if type(descr) == list:
-        return [descr2properties(d, height, width) for d in descr]
-
-    d = descr.split("<img>")
-    img_tokens = d[-1] if len(d) > 1 else ""
-    img_tokens = img_tokens.strip().split(" ")[: height * width]
-    if len(img_tokens) != height * width:
-        return []
-
-    seen = {}
-    for k, x in enumerate(img_tokens):
-        if x != color_id2name[0]:
-            if x in color_name2rgb:
-                if x in seen:
-                    return []
-            else:
-                return []
-            seen[x] = (color_name2id[x], k // width, k % width)
-
-    square_infos = tuple(zip(*seen.values()))
-
-    if square_infos:
-        square_c = torch.tensor(square_infos[0])
-        square_i = torch.tensor(square_infos[1])
-        square_j = torch.tensor(square_infos[2])
-    else:
-        square_c = torch.tensor([])
-        square_i = torch.tensor([])
-        square_j = torch.tensor([])
-
-    s = all_properties(height, width, len(seen), square_i, square_j, square_c)
-
-    return s
-
-
-######################################################################
-
-# Returns a triplet composed of (1) the total number of properties
-# before <img> in descr, (2) the total number of properties the image
-# after <img> verifies, and (3) the number of properties in (1) not in
-# (2)
-
-
-def nb_properties(descr, height, width, pruner=None):
-    if type(descr) == list:
-        return [nb_properties(d, height, width, pruner) for d in descr]
-
-    d = descr.split("<img>", 1)
-    if len(d) == 0:
-        return 0
-    d = d[0].strip().split("<sep>")
-    d = [x.strip() for x in d]
-
-    all_properties = set(descr2properties(descr, height, width))
-
-    if pruner is None:
-        requested_properties = set(d)
-    else:
-        requested_properties = set(filter(pruner, d))
-
-    missing_properties = requested_properties - all_properties
-
-    return (len(requested_properties), len(all_properties), len(missing_properties))
-
-
-######################################################################
-
-if __name__ == "__main__":
-    for n in range(16):
-        descr = generate(nb=1, height=12, width=16)
-
-        print(nb_properties(descr, height=12, width=16))
-
-        with open(f"picoclvr_example_{n:02d}.txt", "w") as f:
-            for d in descr:
-                f.write(f"{d}\n\n")
-
-        img = descr2img(descr, height=12, width=16)
-        if img.size(0) == 1:
-            img = F.pad(img, (1, 1, 1, 1), value=64)
-
-        torchvision.utils.save_image(
-            img / 255.0,
-            f"picoclvr_example_{n:02d}.png",
-            padding=1,
-            nrow=4,
-            pad_value=0.8,
-        )
-
-    import time
-
-    start_time = time.perf_counter()
-    descr = generate(nb=1000, height=12, width=16)
-    end_time = time.perf_counter()
-    print(f"{len(descr) / (end_time - start_time):.02f} samples per second")
-
-######################################################################
diff --git a/qmlp.py b/qmlp.py
deleted file mode 100755 (executable)
index abebfc1..0000000
--- a/qmlp.py
+++ /dev/null
@@ -1,378 +0,0 @@
-#!/usr/bin/env python
-
-# @XREMOTE_HOST: elk.fleuret.org
-# @XREMOTE_EXEC: python
-# @XREMOTE_PRE: source ${HOME}/misc/venv/pytorch/bin/activate
-# @XREMOTE_PRE: killall -u ${USER} -q -9 python || true
-# @XREMOTE_PRE: ln -sf ${HOME}/data/pytorch ./data
-# @XREMOTE_SEND: *.py *.sh
-
-# Any copyright is dedicated to the Public Domain.
-# https://creativecommons.org/publicdomain/zero/1.0/
-
-# Written by Francois Fleuret <francois@fleuret.org>
-
-import math, sys
-
-import torch, torchvision
-
-from torch import nn
-from torch.nn import functional as F
-
-######################################################################
-
-nb_quantization_levels = 101
-
-
-def quantize(x, xmin, xmax):
-    return (
-        ((x - xmin) / (xmax - xmin) * nb_quantization_levels)
-        .long()
-        .clamp(min=0, max=nb_quantization_levels - 1)
-    )
-
-
-def dequantize(q, xmin, xmax):
-    return q / nb_quantization_levels * (xmax - xmin) + xmin
-
-
-######################################################################
-
-
-def generate_sets_and_params(
-    batch_nb_mlps,
-    nb_samples,
-    batch_size,
-    nb_epochs,
-    device=torch.device("cpu"),
-    print_log=False,
-    save_as_examples=False,
-):
-    data_input = torch.zeros(batch_nb_mlps, 2 * nb_samples, 2, device=device)
-    data_targets = torch.zeros(
-        batch_nb_mlps, 2 * nb_samples, dtype=torch.int64, device=device
-    )
-
-    nb_rec = 8
-    nb_values = 2  # more increases the min-max gap
-
-    rec_support = torch.empty(batch_nb_mlps, nb_rec, 4, device=device)
-
-    while (data_targets.float().mean(-1) - 0.5).abs().max() > 0.1:
-        i = (data_targets.float().mean(-1) - 0.5).abs() > 0.1
-        nb = i.sum()
-        support = torch.rand(nb, nb_rec, 2, nb_values, device=device) * 2 - 1
-        support = support.sort(-1).values
-        support = support[:, :, :, torch.tensor([0, nb_values - 1])].view(nb, nb_rec, 4)
-
-        x = torch.rand(nb, 2 * nb_samples, 2, device=device) * 2 - 1
-        y = (
-            (
-                (x[:, None, :, 0] >= support[:, :, None, 0]).long()
-                * (x[:, None, :, 0] <= support[:, :, None, 1]).long()
-                * (x[:, None, :, 1] >= support[:, :, None, 2]).long()
-                * (x[:, None, :, 1] <= support[:, :, None, 3]).long()
-            )
-            .max(dim=1)
-            .values
-        )
-
-        data_input[i], data_targets[i], rec_support[i] = x, y, support
-
-    train_input, train_targets = (
-        data_input[:, :nb_samples],
-        data_targets[:, :nb_samples],
-    )
-    test_input, test_targets = data_input[:, nb_samples:], data_targets[:, nb_samples:]
-
-    q_train_input = quantize(train_input, -1, 1)
-    train_input = dequantize(q_train_input, -1, 1)
-
-    q_test_input = quantize(test_input, -1, 1)
-    test_input = dequantize(q_test_input, -1, 1)
-
-    if save_as_examples:
-        a = (
-            2
-            * torch.arange(nb_quantization_levels).float()
-            / (nb_quantization_levels - 1)
-            - 1
-        )
-        xf = torch.cat(
-            [
-                a[:, None, None].expand(
-                    nb_quantization_levels, nb_quantization_levels, 1
-                ),
-                a[None, :, None].expand(
-                    nb_quantization_levels, nb_quantization_levels, 1
-                ),
-            ],
-            2,
-        )
-        xf = xf.reshape(1, -1, 2).expand(min(q_train_input.size(0), 10), -1, -1)
-        print(f"{xf.size()=} {x.size()=}")
-        yf = (
-            (
-                (xf[:, None, :, 0] >= rec_support[: xf.size(0), :, None, 0]).long()
-                * (xf[:, None, :, 0] <= rec_support[: xf.size(0), :, None, 1]).long()
-                * (xf[:, None, :, 1] >= rec_support[: xf.size(0), :, None, 2]).long()
-                * (xf[:, None, :, 1] <= rec_support[: xf.size(0), :, None, 3]).long()
-            )
-            .max(dim=1)
-            .values
-        )
-
-        full_input, full_targets = xf, yf
-
-        q_full_input = quantize(full_input, -1, 1)
-        full_input = dequantize(q_full_input, -1, 1)
-
-        for k in range(q_full_input[:10].size(0)):
-            with open(f"example_full_{k:04d}.dat", "w") as f:
-                for u, c in zip(full_input[k], full_targets[k]):
-                    f.write(f"{c} {u[0].item()} {u[1].item()}\n")
-
-        for k in range(q_train_input[:10].size(0)):
-            with open(f"example_train_{k:04d}.dat", "w") as f:
-                for u, c in zip(train_input[k], train_targets[k]):
-                    f.write(f"{c} {u[0].item()} {u[1].item()}\n")
-
-    hidden_dim = 32
-    w1 = torch.randn(batch_nb_mlps, hidden_dim, 2, device=device) / math.sqrt(2)
-    b1 = torch.zeros(batch_nb_mlps, hidden_dim, device=device)
-    w2 = torch.randn(batch_nb_mlps, 2, hidden_dim, device=device) / math.sqrt(
-        hidden_dim
-    )
-    b2 = torch.zeros(batch_nb_mlps, 2, device=device)
-
-    w1.requires_grad_()
-    b1.requires_grad_()
-    w2.requires_grad_()
-    b2.requires_grad_()
-    optimizer = torch.optim.Adam([w1, b1, w2, b2], lr=1e-2)
-
-    criterion = nn.CrossEntropyLoss()
-    criterion.to(device)
-
-    for k in range(nb_epochs):
-        acc_train_loss = 0.0
-        nb_train_errors = 0
-
-        for input, targets in zip(
-            train_input.split(batch_size, dim=1), train_targets.split(batch_size, dim=1)
-        ):
-            h = torch.einsum("mij,mnj->mni", w1, input) + b1[:, None, :]
-            h = F.relu(h)
-            output = torch.einsum("mij,mnj->mni", w2, h) + b2[:, None, :]
-            loss = F.cross_entropy(
-                output.reshape(-1, output.size(-1)), targets.reshape(-1)
-            )
-            acc_train_loss += loss.item() * input.size(0)
-
-            wta = output.argmax(-1)
-            nb_train_errors += (wta != targets).long().sum(-1)
-
-            optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
-
-        with torch.no_grad():
-            for p in [w1, b1, w2, b2]:
-                m = (
-                    torch.rand(p.size(), device=p.device) <= k / (nb_epochs - 1)
-                ).long()
-                pq = quantize(p, -2, 2)
-                p[...] = (1 - m) * p + m * dequantize(pq, -2, 2)
-
-        train_error = nb_train_errors / train_input.size(1)
-        acc_train_loss = acc_train_loss / train_input.size(1)
-
-        # print(f"{k=} {acc_train_loss=} {train_error=}")
-
-    acc_test_loss = 0
-    nb_test_errors = 0
-
-    for input, targets in zip(
-        test_input.split(batch_size, dim=1), test_targets.split(batch_size, dim=1)
-    ):
-        h = torch.einsum("mij,mnj->mni", w1, input) + b1[:, None, :]
-        h = F.relu(h)
-        output = torch.einsum("mij,mnj->mni", w2, h) + b2[:, None, :]
-        loss = F.cross_entropy(output.reshape(-1, output.size(-1)), targets.reshape(-1))
-        acc_test_loss += loss.item() * input.size(0)
-
-        wta = output.argmax(-1)
-        nb_test_errors += (wta != targets).long().sum(-1)
-
-    test_error = nb_test_errors / test_input.size(1)
-    q_params = torch.cat(
-        [quantize(p.view(batch_nb_mlps, -1), -2, 2) for p in [w1, b1, w2, b2]], dim=1
-    )
-    q_train_set = torch.cat([q_train_input, train_targets[:, :, None]], -1).reshape(
-        batch_nb_mlps, -1
-    )
-    q_test_set = torch.cat([q_test_input, test_targets[:, :, None]], -1).reshape(
-        batch_nb_mlps, -1
-    )
-
-    return q_train_set, q_test_set, q_params, test_error
-
-
-######################################################################
-
-
-def evaluate_q_params(
-    q_params,
-    q_set,
-    batch_size=25,
-    device=torch.device("cpu"),
-    nb_mlps_per_batch=1024,
-    save_as_examples=False,
-):
-    errors = []
-    nb_mlps = q_params.size(0)
-
-    for n in range(0, nb_mlps, nb_mlps_per_batch):
-        batch_nb_mlps = min(nb_mlps_per_batch, nb_mlps - n)
-        batch_q_params = q_params[n : n + batch_nb_mlps]
-        batch_q_set = q_set[n : n + batch_nb_mlps]
-        hidden_dim = 32
-        w1 = torch.empty(batch_nb_mlps, hidden_dim, 2, device=device)
-        b1 = torch.empty(batch_nb_mlps, hidden_dim, device=device)
-        w2 = torch.empty(batch_nb_mlps, 2, hidden_dim, device=device)
-        b2 = torch.empty(batch_nb_mlps, 2, device=device)
-
-        with torch.no_grad():
-            k = 0
-            for p in [w1, b1, w2, b2]:
-                print(f"{p.size()=}")
-                x = dequantize(
-                    batch_q_params[:, k : k + p.numel() // batch_nb_mlps], -2, 2
-                ).view(p.size())
-                p.copy_(x)
-                k += p.numel() // batch_nb_mlps
-
-        batch_q_set = batch_q_set.view(batch_nb_mlps, -1, 3)
-        data_input = dequantize(batch_q_set[:, :, :2], -1, 1).to(device)
-        data_targets = batch_q_set[:, :, 2].to(device)
-
-        print(f"{data_input.size()=} {data_targets.size()=}")
-
-        criterion = nn.CrossEntropyLoss()
-        criterion.to(device)
-
-        acc_loss = 0.0
-        nb_errors = 0
-
-        for input, targets in zip(
-            data_input.split(batch_size, dim=1), data_targets.split(batch_size, dim=1)
-        ):
-            h = torch.einsum("mij,mnj->mni", w1, input) + b1[:, None, :]
-            h = F.relu(h)
-            output = torch.einsum("mij,mnj->mni", w2, h) + b2[:, None, :]
-            loss = F.cross_entropy(
-                output.reshape(-1, output.size(-1)), targets.reshape(-1)
-            )
-            acc_loss += loss.item() * input.size(0)
-            wta = output.argmax(-1)
-            nb_errors += (wta != targets).long().sum(-1)
-
-        errors.append(nb_errors / data_input.size(1))
-        acc_loss = acc_loss / data_input.size(1)
-
-    return torch.cat(errors)
-
-
-######################################################################
-
-
-def generate_sequence_and_test_set(
-    nb_mlps,
-    nb_samples,
-    batch_size,
-    nb_epochs,
-    device,
-    nb_mlps_per_batch=1024,
-):
-    seqs, q_test_sets, test_errors = [], [], []
-
-    for n in range(0, nb_mlps, nb_mlps_per_batch):
-        q_train_set, q_test_set, q_params, test_error = generate_sets_and_params(
-            batch_nb_mlps=min(nb_mlps_per_batch, nb_mlps - n),
-            nb_samples=nb_samples,
-            batch_size=batch_size,
-            nb_epochs=nb_epochs,
-            device=device,
-        )
-
-        seqs.append(
-            torch.cat(
-                [
-                    q_train_set,
-                    q_train_set.new_full(
-                        (
-                            q_train_set.size(0),
-                            1,
-                        ),
-                        nb_quantization_levels,
-                    ),
-                    q_params,
-                ],
-                dim=-1,
-            )
-        )
-
-        q_test_sets.append(q_test_set)
-        test_errors.append(test_error)
-
-    seq = torch.cat(seqs)
-    q_test_set = torch.cat(q_test_sets)
-    test_error = torch.cat(test_errors)
-
-    return seq, q_test_set, test_error
-
-
-######################################################################
-
-if __name__ == "__main__":
-    import time
-
-    batch_nb_mlps, nb_samples = 128, 250
-
-    generate_sets_and_params(
-        batch_nb_mlps=10,
-        nb_samples=nb_samples,
-        batch_size=25,
-        nb_epochs=100,
-        device=torch.device("cpu"),
-        print_log=False,
-        save_as_examples=True,
-    )
-
-    exit(0)
-
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-    start_time = time.perf_counter()
-
-    data = []
-
-    seq, q_test_set, test_error = generate_sequence_and_test_set(
-        nb_mlps=batch_nb_mlps,
-        nb_samples=nb_samples,
-        device=device,
-        batch_size=25,
-        nb_epochs=250,
-        nb_mlps_per_batch=17,
-    )
-
-    end_time = time.perf_counter()
-    print(f"{seq.size(0) / (end_time - start_time):.02f} samples per second")
-
-    q_train_set = seq[:, : nb_samples * 3]
-    q_params = seq[:, nb_samples * 3 + 1 :]
-    print(f"SANITY #2 {q_train_set.size()=} {q_params.size()=} {seq.size()=}")
-    error_train = evaluate_q_params(q_params, q_train_set, nb_mlps_per_batch=17)
-    print(f"train {error_train*100}%")
-    error_test = evaluate_q_params(q_params, q_test_set, nb_mlps_per_batch=17)
-    print(f"test {error_test*100}%")
diff --git a/rpl.py b/rpl.py
deleted file mode 100755 (executable)
index b848afa..0000000
--- a/rpl.py
+++ /dev/null
@@ -1,177 +0,0 @@
-#!/usr/bin/env python
-
-# Any copyright is dedicated to the Public Domain.
-# https://creativecommons.org/publicdomain/zero/1.0/
-
-# Written by Francois Fleuret <francois@fleuret.org>
-
-import math
-
-import torch, torchvision
-
-from torch import nn
-from torch.nn import functional as F
-
-######################################################################
-
-
-def rpl_exec(program, stack):
-    stack = stack.copy()
-    for op in program:
-        if op == "add":
-            if len(stack) > 1:
-                a, b = stack.pop(), stack.pop()
-                stack.append(a + b)
-        elif op == "min":
-            if len(stack) > 1:
-                a, b = stack.pop(), stack.pop()
-                stack.append(min(a, b))
-        elif op == "max":
-            if len(stack) > 1:
-                a, b = stack.pop(), stack.pop()
-                stack.append(max(a, b))
-        elif op == "swp":
-            if len(stack) > 1:
-                a, b = stack.pop(), stack.pop()
-                stack.append(a)
-                stack.append(b)
-        elif op == "rep":
-            if len(stack) > 1:
-                a, b = stack.pop(), stack.pop()
-                stack += [b] * a
-        elif op == "dup":
-            if len(stack) > 0:
-                a = stack.pop()
-                stack.append(a)
-                stack.append(a)
-        elif op == "del":
-            if len(stack) > 0:
-                a = stack.pop()
-        else:
-            raise ValueError(f"Unknown instruction {op}")
-
-    return stack
-
-
-rpl_ops = ["add", "min", "max", "swp", "rep", "dup", "del"]
-
-######################################################################
-
-
-def generate(
-    nb_starting_values=3, nb_result_values_max=None, max_input=9, prog_len=6, nb_runs=5
-):
-    prog_len = (1 + torch.randint(2 * prog_len, (1,))).clamp(max=prog_len).item()
-
-    while True:
-        no_empty_stack = True
-        prog = [rpl_ops[k] for k in torch.randint(len(rpl_ops), (prog_len,))]
-
-        result = []
-        for _ in range(nb_runs):
-            stack = [
-                x.item() for x in torch.randint(max_input + 1, (nb_starting_values,))
-            ]
-            result_stack = rpl_exec(prog, stack)
-            if len(result_stack) == 0:
-                no_empty_stack = False
-            result = result + ["<in>"] + stack + ["<out>"] + result_stack
-
-        result = result + ["<prg>"] + prog
-        result = result + ["<end>"]
-
-        if no_empty_stack and (
-            nb_result_values_max is None or len(result_stack) <= nb_result_values_max
-        ):
-            break
-
-    return result
-
-
-def next_marker(seq, tokens, start=0):
-    pos = None
-    for t in tokens:
-        try:
-            i = seq.index(t, start)
-            if pos is None or i < pos:
-                pos = i
-        except ValueError:
-            pass
-    return pos
-
-
-def decompose(seq):
-    io = []
-    k = 0
-    while seq[k] == "<in>":
-        o = next_marker(seq, ["<out>"], start=k + 1)
-        if o is None:
-            raise ValueError("Missing output markers (should be correct in the prompt)")
-        e = next_marker(seq, ["<in>", "<prg>"], start=o)
-        if e is None:
-            raise ValueError(
-                "Missing input/output markers (should be correct in the prompt)"
-            )
-        try:
-            io.append(
-                ([int(x) for x in seq[k + 1 : o]], [int(x) for x in seq[o + 1 : e]])
-            )
-        except ValueError:
-            raise ValueError(
-                "Invalid input/output value (should be correct in the prompt)"
-            )
-
-        k = e
-
-    if seq[k] == "<prg>":
-        e = next_marker(seq, ["<end>"], start=k)
-        if e is None:
-            prog = []
-        else:
-            prog = seq[k + 1 : e]
-    else:
-        raise ValueError("Missing <prg> (it should be in the prompt)")
-
-    return prog, io
-
-
-def stack_distance(target_stack, result_stack):
-    return abs(len(result_stack) - len(target_stack)) + sum(
-        [0 if x == y else 1 for x, y in zip(result_stack, target_stack)]
-    )
-
-
-def compute_nb_errors(seq):
-    prog, io = decompose(seq)
-
-    nb_total, nb_errors = 0, 0
-
-    stacks = []
-
-    if len(set(prog) - set(rpl_ops)) > 0:
-        # Program is not valid, we count 100% error
-        for start_stack, target_stack in io:
-            stacks.append((start_stack, target_stack, ["N/A"], False))
-            nb_total += len(target_stack)
-            nb_errors += len(target_stack)
-
-    else:
-        # Program is valid
-        for start_stack, target_stack in io:
-            result_stack = rpl_exec(prog, start_stack)
-            nb_total += len(target_stack)
-            e = stack_distance(target_stack, result_stack)
-            nb_errors += e
-            stacks.append((start_stack, target_stack, result_stack, e == 0))
-
-    return nb_total, nb_errors, prog, stacks
-
-
-######################################################################
-
-if __name__ == "__main__":
-    seq = generate()
-    print(seq)
-    seq[3] = 7
-    print(seq)
-    print(compute_nb_errors(seq))
diff --git a/snake.py b/snake.py
deleted file mode 100755 (executable)
index 8a16f9f..0000000
--- a/snake.py
+++ /dev/null
@@ -1,132 +0,0 @@
-#!/usr/bin/env python
-
-# Any copyright is dedicated to the Public Domain.
-# https://creativecommons.org/publicdomain/zero/1.0/
-
-# Written by Francois Fleuret <francois@fleuret.org>
-
-import torch, torchvision
-import torch.nn.functional as F
-
-
-def generate_sequences(
-    nb, height, width, nb_colors, length, prompt_length, device=torch.device("cpu")
-):
-    worlds = torch.randint(nb_colors, (nb, height, width), device=device)
-    world_prior_visits = torch.zeros(nb, height, width, device=device)
-
-    # nb x 2
-    snake_position = torch.cat(
-        (
-            torch.randint(height, (nb, 1), device=device),
-            torch.randint(width, (nb, 1), device=device),
-        ),
-        1,
-    )
-    snake_direction = torch.randint(4, (nb,), device=device)
-    sequences = torch.empty(nb, 2 * length, device=device, dtype=torch.int64)
-    sequences_prior_visits = torch.zeros(
-        nb, 2 * length, device=device, dtype=torch.int64
-    )
-    i = torch.arange(nb, device=device)  # [:,None]
-
-    for l in range(length):
-        # nb x 3
-        snake_next_direction = torch.cat(
-            (
-                (snake_direction[:, None] - 1) % 4,
-                snake_direction[:, None],
-                (snake_direction[:, None] + 1) % 4,
-            ),
-            1,
-        )
-
-        # nb x 3
-        vh = (snake_next_direction + 1) % 2 * (snake_next_direction - 1)
-        vw = snake_next_direction % 2 * (snake_next_direction - 2)
-
-        # nb x 3 x 2
-        snake_next_speed = torch.cat((vh[:, :, None], vw[:, :, None]), 2)
-        snake_next_position = snake_position[:, None, :] + snake_next_speed
-
-        # nb x 3
-        val = torch.logical_and(
-            torch.logical_and(
-                snake_next_position[:, :, 0] >= 0, snake_next_position[:, :, 0] < height
-            ),
-            torch.logical_and(
-                snake_next_position[:, :, 1] >= 0, snake_next_position[:, :, 1] < width
-            ),
-        ).float()
-        val = (
-            # The multiplicative factors bias toward moving forward
-            torch.rand_like(val)
-            * val
-            * torch.tensor([[1.0, 2.0, 1.0]], device=device)
-        )
-
-        # nb
-        j = val.argmax(1)
-        snake_direction = snake_next_direction[i, j]
-
-        sequences[:, 2 * l] = worlds[i, snake_position[:, 0], snake_position[:, 1]] + 4
-        sequences_prior_visits[:, 2 * l] = world_prior_visits[
-            i, snake_position[:, 0], snake_position[:, 1]
-        ]
-        if l < prompt_length:
-            world_prior_visits[i, snake_position[:, 0], snake_position[:, 1]] += 1
-        sequences[:, 2 * l + 1] = snake_direction
-
-        # nb x 2
-        snake_position = snake_next_position[i, j]
-
-    return sequences, sequences_prior_visits, worlds, world_prior_visits
-
-
-# generate_snake_sequences(nb=1, height=4, width=6, nb_colors=3, length=20)
-# exit(0)
-
-
-def solver(input, ar_mask):
-    for n in range(input.size(0)):
-        i, j, memory = 0, 0, {}
-        # print(input[n])
-        # print(ar_mask[n])
-        for l in range(input.size(1) // 2):
-            if ar_mask[n, 2 * l] == 1:
-                if memory.get((i, j)) is None:
-                    input[n, 2 * l] = -1
-                else:
-                    input[n, 2 * l] = memory[(i, j)]
-            else:
-                # print(f'@3 {memory=}')
-                if memory.get((i, j)) is None:
-                    memory[(i, j)] = input[n, 2 * l]
-                else:
-                    assert memory[(i, j)] == input[n, 2 * l], f"n={n} l={l}"
-            # print(f'@1 {i=} {j=}')
-            d = input[n, 2 * l + 1].item()
-            i += (d + 1) % 2 * (d - 1)
-            j += d % 2 * (d - 2)
-            # print(f'@2 {i=} {j=}')
-
-
-def seq2str(seq):
-    return "".join(["NESW123456789"[i] for i in seq])
-
-
-######################################################################
-
-if __name__ == "__main__":
-    train_input, train_prior_visits, _, _ = generate_sequences(
-        nb=20,
-        height=9,
-        width=12,
-        nb_colors=5,
-        length=50,
-        prompt_length=100,
-    )
-
-    print([seq2str(s) for s in train_input])
-
-######################################################################
diff --git a/stack.py b/stack.py
deleted file mode 100755 (executable)
index 543f04e..0000000
--- a/stack.py
+++ /dev/null
@@ -1,107 +0,0 @@
-#!/usr/bin/env python
-
-# Any copyright is dedicated to the Public Domain.
-# https://creativecommons.org/publicdomain/zero/1.0/
-
-# Written by Francois Fleuret <francois@fleuret.org>
-
-import torch, torchvision
-
-######################################################################
-
-# CODE_OP=[0 for push, 1 for pop] + 2 * n_stack
-# CODE_VAL=val + 2 * nb_stacks
-
-
-def generate_sequences(
-    nb, nb_steps, nb_stacks, nb_digits, values=None, device=torch.device("cpu")
-):
-    stack = torch.empty(nb, nb_stacks, nb_steps, dtype=torch.int64)
-    stack_counts = torch.zeros(nb, nb_stacks, dtype=torch.int64)
-    k = torch.arange(nb)
-    result = torch.empty(nb, (1 + nb_digits) * nb_steps, dtype=torch.int64)
-    recorded_stack_counts = torch.zeros(
-        nb, (1 + nb_digits) * nb_steps, dtype=torch.int64
-    )
-
-    for t in range(nb_steps):
-        op = torch.randint(2, (nb,))
-        st = torch.randint(nb_stacks, (nb,))
-        op = op * (stack_counts[k, st] > 0)
-        if values is None:
-            val_push = torch.randint(10**nb_digits, (nb,))
-        else:
-            val_push = values[torch.randint(values.size(0), (nb,))]
-        val_pop = stack[
-            k,
-            st,
-            (stack_counts[k, st] - 1).clamp(min=0),
-        ]
-        stack[k, st, stack_counts[k, st]] = val_push
-        recorded_stack_counts[:, (1 + nb_digits) * t] = stack_counts[k, st]
-        stack_counts[k[op == 0], st[op == 0]] += 1
-        stack_counts[k[op == 1], st[op == 1]] -= 1
-        result[:, (1 + nb_digits) * t] = st * 2 + op
-        for d in range(nb_digits):
-            result[:, (1 + nb_digits) * t + 1 + d] = (
-                (op * val_pop + (1 - op) * val_push) // (10**d)
-            ) % 10 + 2 * nb_stacks
-
-    return result.to(device), recorded_stack_counts.to(device)
-
-
-def remove_popped_values(seq, nb_stacks, nb_digits):
-    m = torch.logical_and(seq % 2 == 1, seq < 2 * nb_stacks).long()
-    for d in range(nb_digits):
-        k = d + 1
-        seq[:, k:] = -m[:, :-k] + (1 - m[:, :-k]) * seq[:, k:]
-
-
-def seq_to_str(seq, nb_stacks, nb_digits, recorded_stack_counts=None):
-    assert seq.size(0) % (1 + nb_digits) == 0
-    s = ""
-    for t in range(seq.size(0) // (1 + nb_digits)):
-        n_op = seq[(1 + nb_digits) * t]
-        if t > 0:
-            s += " "
-        if recorded_stack_counts is not None:
-            s += f"[{recorded_stack_counts[(1 + nb_digits)*t]}] "
-        s += f"POP" if n_op % 2 == 1 else f"PSH"
-        if nb_stacks > 1:
-            s += f"_{n_op//2}"
-        for d in range(nb_digits):
-            if seq[(1 + nb_digits) * t + 1 + d] == -1:
-                s += " ?"
-            else:
-                s += f" {seq[(1 + nb_digits) * t + 1 + d] - 2 * nb_stacks:1d}"
-    return s
-
-
-######################################################################
-
-if __name__ == "__main__":
-    nb, nb_steps, nb_stacks, nb_digits = 150000, 20, 2, 1
-    seq, recorded_stack_counts = generate_sequences(
-        nb=nb,
-        nb_steps=nb_steps,
-        nb_stacks=nb_stacks,
-        nb_digits=nb_digits,
-    )
-
-    for n in range(min(10, seq.size(0))):
-        print(
-            seq_to_str(
-                seq[n],
-                nb_stacks=nb_stacks,
-                nb_digits=nb_digits,
-                recorded_stack_counts=recorded_stack_counts[n],
-            )
-        )
-        # print(seq_to_str(seq[n], nb_stacks=nb_stacks, nb_digits=nb_digits))
-
-    print("-- PREPARED FOR TEST -----------------")
-
-    remove_popped_values(seq, nb_stacks, nb_digits)
-
-    for n in range(min(10, seq.size(0))):
-        print(seq_to_str(seq[n], nb_stacks=nb_stacks, nb_digits=nb_digits))
index 5d9a018..77493a8 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -75,2024 +75,6 @@ class Task:
         pass
 
 
-class TaskFromFile(Task):
-    def tensorize(self, pairs, shuffle):
-        len_max = max([len(x[0]) for x in pairs])
-
-        input = torch.cat(
-            [
-                torch.tensor(
-                    [
-                        [self.char2id[c] for c in s[0] + "#" * (len_max - len(s[0]))]
-                        for s in pairs
-                    ]
-                )
-            ],
-            0,
-        ).to("cpu")
-
-        pred_mask = torch.cat(
-            [
-                torch.tensor(
-                    [
-                        [int(c) for c in s[1] + "0" * (len_max - len(s[1]))]
-                        for s in pairs
-                    ]
-                )
-            ],
-            0,
-        ).to("cpu")
-
-        if shuffle:
-            i = torch.randperm(input.size(0))
-            input = input[i].contiguous()
-            pred_mask = pred_mask[i].contiguous()
-
-        return input, pred_mask
-
-    # trim all the tensors in the tuple z to remove as much token from
-    # left and right in the first tensor. If z is a tuple, all its
-    # elements are trimed according to the triming for the first
-    def trim(self, z, token="#"):
-        n = self.char2id[token]
-        if type(z) == tuple:
-            x = z[0]
-            i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
-            a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
-            return tuple([t[:, a:b] for t in z])
-        else:
-            i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
-            a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
-            return z[:, a:b]
-
-    def __init__(
-        self,
-        train_filename,
-        test_filename,
-        nb_train_samples,
-        nb_test_samples,
-        batch_size,
-        shuffle=False,
-        device=torch.device("cpu"),
-    ):
-        self.batch_size = batch_size
-        self.device = device
-
-        def read_file(filename, nb=-1):
-            pairs = []
-            with open(filename, "r") as f:
-                while True:
-                    sequence = f.readline().strip()
-                    if not sequence:
-                        break
-                    pred_mask = f.readline().strip()
-                    assert len(sequence) == len(pred_mask)
-                    assert set(pred_mask).issubset({"0", "1", "2"}), f"{set(pred_mask)}"
-                    pairs.append((sequence, pred_mask))
-                    if len(pairs) == nb:
-                        break
-
-            if nb > 0:
-                pairs = pairs[:nb]
-                assert len(pairs) == nb
-
-            return pairs
-
-        train_pairs = read_file(train_filename, nb_train_samples)
-        test_pairs = read_file(test_filename, nb_test_samples)
-
-        symbols = ["#"] + list(
-            set("".join([x[0] for x in train_pairs + test_pairs])) - set(["#"])
-        )
-        self.char2id = dict([(c, n) for n, c in enumerate(symbols)])
-        self.id2char = dict([(n, c) for c, n in self.char2id.items()])
-
-        self.train_input, self.train_pred_masks = self.tensorize(
-            train_pairs, shuffle=shuffle
-        )
-        self.test_input, self.test_pred_masks = self.tensorize(
-            test_pairs, shuffle=shuffle
-        )
-
-    def batches(self, split="train", nb_to_use=-1, desc=None):
-        assert split in {"train", "test"}
-        input = self.train_input if split == "train" else self.test_input
-        if nb_to_use > 0:
-            input = input[:nb_to_use]
-        if desc is None:
-            desc = f"epoch-{split}"
-        for batch in tqdm.tqdm(
-            input.split(self.batch_size), dynamic_ncols=True, desc=desc
-        ):
-            yield self.trim(batch).to(self.device)
-
-    def vocabulary_size(self):
-        return len(self.char2id)
-
-    def tensor2str(self, t):
-        return ["".join([self.id2char[x.item()] for x in s]) for s in t]
-
-    def produce_results(
-        self, n_epoch, model, result_dir, logger, deterministic_synthesis
-    ):
-        correct = self.trim(self.test_input[:1000]).to(self.device)
-        result = correct.clone()
-        pred_mask = self.test_pred_masks[:1000, : result.size(1)].to(self.device)
-        ar_mask = (pred_mask > 0).long()
-        result *= 1 - ar_mask  # paraaaaanoiaaaaaaa
-
-        logger(f"----------------------------------------------------------")
-
-        for e in self.tensor2str(result[:50]):
-            logger(f"test_before {e}")
-
-        masked_inplace_autoregression(
-            model,
-            self.batch_size,
-            result,
-            ar_mask,
-            deterministic_synthesis,
-            device=self.device,
-        )
-
-        logger(f"----------------------------------------------------------")
-
-        for e, c in zip(self.tensor2str(result[:50]), self.tensor2str(correct[:50])):
-            logger(f"test_after  {e}")
-            logger(f"correct     {c}")
-
-        logger(f"----------------------------------------------------------")
-
-        err_mask = (pred_mask == 2).long()
-        nb_total = err_mask.sum().item()
-        nb_correct = ((correct == result).long() * err_mask).sum().item()
-
-        logger(f"test_performance {n_epoch} {nb_total=} {nb_correct=}")
-        logger(f"main_test_accuracy {n_epoch} {nb_correct / nb_total}")
-
-
-####################
-
-import problems
-
-
-class SandBox(Task):
-    def __init__(
-        self,
-        problem,
-        nb_train_samples,
-        nb_test_samples,
-        batch_size,
-        logger=None,
-        device=torch.device("cpu"),
-        max_nb_codes=1024,
-    ):
-        super().__init__()
-
-        self.batch_size = batch_size
-        self.device = device
-        self.problem = problem
-
-        self.train_input, self.train_ar_mask = self.problem.generate_sequences(
-            nb_train_samples
-        )
-        self.test_input, self.test_ar_mask = self.problem.generate_sequences(
-            nb_test_samples
-        )
-
-        self.train_input, self.train_ar_mask = self.train_input.to(
-            device
-        ), self.train_ar_mask.to(device)
-        self.test_input, self.test_ar_mask = self.test_input.to(
-            device
-        ), self.test_ar_mask.to(device)
-
-        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
-
-        # A bit of paranoia never hurts
-        assert self.nb_codes <= max_nb_codes
-        assert self.train_input.min() >= 0
-        assert self.test_input.min() >= 0
-        assert tuple(x.item() for x in self.train_ar_mask.unique()) in {
-            (0,),
-            (1,),
-            (0, 1),
-        }
-        assert tuple(x.item() for x in self.test_ar_mask.unique()) in {
-            (0,),
-            (1,),
-            (0, 1),
-        }
-
-        if logger is not None:
-            for s, a in zip(self.train_input[:100], self.train_ar_mask[:100]):
-                logger(f"train_sequences {self.problem.seq2str(s)}")
-                a = "".join(["01"[x.item()] for x in a])
-                logger(f"                {a}")
-
-    def batches(self, split="train", nb_to_use=-1, desc=None):
-        assert split in {"train", "test"}
-        input = self.train_input if split == "train" else self.test_input
-        if nb_to_use > 0:
-            input = input[:nb_to_use]
-        if desc is None:
-            desc = f"epoch-{split}"
-        for batch in tqdm.tqdm(
-            input.split(self.batch_size), dynamic_ncols=True, desc=desc
-        ):
-            yield batch
-
-    def vocabulary_size(self):
-        return self.nb_codes
-
-    def produce_results(
-        self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000
-    ):
-        def compute_accuracy(input, ar_mask, logger=None):
-            input, ar_mask = input[:nmax], ar_mask[:nmax]
-            result = input.clone() * (1 - ar_mask)
-
-            masked_inplace_autoregression(
-                model,
-                self.batch_size,
-                result,
-                ar_mask,
-                deterministic_synthesis,
-                progress_bar_desc=None,
-                device=self.device,
-            )
-
-            log_ground_truth = ar_mask.min() == 0
-
-            if logger is not None:
-                for sp, st in zip(result[:10], input[:10]):
-                    logger(
-                        f"test_sequences {n_epoch} prediction   {self.problem.seq2str(sp)}"
-                    )
-                    if log_ground_truth:
-                        logger(
-                            f"               {n_epoch} ground truth {self.problem.seq2str(st)}"
-                        )
-
-            nb_total, nb_correct = self.problem.compute_nb_correct(
-                input, ar_mask, result
-            )
-
-            # nb_total = ar_mask.sum().item()
-            # nb_correct = ((result == input).long() * ar_mask).sum().item()
-
-            return nb_total, nb_correct
-
-        train_nb_total, train_nb_correct = compute_accuracy(
-            self.train_input, self.train_ar_mask
-        )
-
-        logger(
-            f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
-        )
-
-        test_nb_total, test_nb_correct = compute_accuracy(
-            self.test_input, self.test_ar_mask, logger
-        )
-
-        logger(
-            f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
-        )
-
-        logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
-
-        if save_attention_image is not None:
-            for k in range(10):
-                ns = torch.randint(self.test_input.size(0), (1,)).item()
-                input = self.test_input[ns : ns + 1].clone()
-
-                with torch.autograd.no_grad():
-                    t = model.training
-                    model.eval()
-                    # model.record_attention(True)
-                    model(BracketedSequence(input))
-                    model.train(t)
-                    # ram = model.retrieve_attention()
-                    # model.record_attention(False)
-
-                # tokens_output = [c for c in self.problem.seq2str(input[0])]
-                # tokens_input = ["n/a"] + tokens_output[:-1]
-                # for n_head in range(ram[0].size(1)):
-                # filename = os.path.join(
-                # result_dir, f"sandbox_attention_{k}_h{n_head}.pdf"
-                # )
-                # attention_matrices = [m[0, n_head] for m in ram]
-                # save_attention_image(
-                # filename,
-                # tokens_input,
-                # tokens_output,
-                # attention_matrices,
-                # k_top=10,
-                ##min_total_attention=0.9,
-                # token_gap=12,
-                # layer_gap=50,
-                # )
-                # logger(f"wrote {filename}")
-
-
-######################################################################
-
-import picoclvr
-
-
-class PicoCLVR(Task):
-    # Make a tensor from a list of strings
-    def tensorize(self, descr):
-        token_descr = [s.strip().split(" ") for s in descr]
-        l = max([len(s) for s in token_descr])
-        token_descr = [s + ["<nul>"] * (l - len(s)) for s in token_descr]
-        id_descr = [[self.token2id[u] for u in s] for s in token_descr]
-        return torch.tensor(id_descr, device=self.device)
-
-    # Make a list of strings from a tensor
-    def detensorize(self, x):
-        return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
-
-    # trim all the tensors in the tuple z to remove as much token from
-    # left and right in the first tensor. If z is a tuple, all its
-    # elements are trimed according to the triming for the first
-    def trim(self, z, token="<nul>"):
-        n = self.token2id[token]
-        if type(z) == tuple:
-            x = z[0]
-            i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
-            a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
-            return tuple([t[:, a:b] for t in z])
-        else:
-            i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
-            a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
-            return z[:, a:b]
-
-    ######################
-
-    def __init__(
-        self,
-        nb_train_samples,
-        nb_test_samples,
-        batch_size,
-        height,
-        width,
-        nb_colors=5,
-        logger=None,
-        device=torch.device("cpu"),
-        pruner_train=None,
-        pruner_eval=None,
-    ):
-        super().__init__()
-
-        def generate_descr(nb, cache_suffix, pruner):
-            return picoclvr.generate(
-                nb,
-                height=self.height,
-                width=self.width,
-                nb_colors=nb_colors,
-                pruner=pruner,
-            )
-
-        self.height = height
-        self.width = width
-        self.batch_size = batch_size
-        self.device = device
-        self.pruner_train = pruner_train
-        self.pruner_eval = pruner_eval
-
-        if logger is not None:
-            logger(
-                f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
-            )
-
-        self.train_descr = generate_descr(
-            nb_train_samples, "train", pruner=self.pruner_train
-        )
-        self.test_descr = generate_descr(nb_test_samples, "test", pruner=None)
-
-        # Build the tokenizer
-        tokens = {"<nul>", "<img>"}
-        for d in [self.train_descr, self.test_descr]:
-            for s in d:
-                for t in s.strip().split(" "):
-                    tokens.add(t)
-        # make this set a sorted list to get the same tensors given
-        # the same descr
-        tokens = list(tokens)
-        tokens.sort()
-        self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
-        self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
-        self.t_img, self.t_nul = self.token2id["<img>"], self.token2id["<nul>"]
-
-        # Tokenize the train and test sets
-        self.train_input = self.tensorize(self.train_descr)
-        self.test_input = self.tensorize(self.test_descr)
-
-    def batches(self, split="train", nb_to_use=-1, desc=None):
-        assert split in {"train", "test"}
-        input = self.train_input if split == "train" else self.test_input
-        for batch in tqdm.tqdm(
-            input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
-        ):
-            yield self.trim(batch)
-
-    def vocabulary_size(self):
-        return len(self.token2id)
-
-    def compute_missing_properties(
-        self, n_epoch, model, logger, deterministic_synthesis, pruner=None
-    ):
-        acc_nb_requested_properties = []
-        acc_nb_missing_properties = []
-        acc_nb_results = 0
-
-        for input in tqdm.tqdm(
-            self.test_input.split(self.batch_size),
-            dynamic_ncols=True,
-            desc=f"test-properties",
-        ):
-            result = input.clone()
-            ar_mask = (result == self.t_img).long().cumsum(dim=1).clamp(max=1)
-            result = (1 - ar_mask) * result + ar_mask * self.t_nul
-            masked_inplace_autoregression(
-                model,
-                self.batch_size,
-                result,
-                ar_mask,
-                deterministic_synthesis,
-                progress_bar_desc=None,
-                device=self.device,
-            )
-
-            result_descr = self.detensorize(result)
-            np = picoclvr.nb_properties(
-                result_descr,
-                height=self.height,
-                width=self.width,
-                pruner=pruner,
-            )
-            nb_requested_properties, _, nb_missing_properties = zip(*np)
-            acc_nb_requested_properties += nb_requested_properties
-            acc_nb_missing_properties += nb_missing_properties
-            acc_nb_results += len(result_descr)
-
-        nb_requested_properties = sum(acc_nb_requested_properties)
-        nb_missing_properties = sum(acc_nb_missing_properties)
-
-        prefix = "" if pruner is None else "pruned_"
-        logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
-        logger(
-            f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
-        )
-        logger(
-            f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
-        )
-
-        logger(
-            f"main_test_accuracy {n_epoch} {1-nb_missing_properties/nb_requested_properties}"
-        )
-
-    ######################################################################
-
-    def produce_results(
-        self, n_epoch, model, result_dir, logger, deterministic_synthesis
-    ):
-        self.compute_missing_properties(n_epoch, model, logger, deterministic_synthesis)
-
-        if self.pruner_eval is not None:
-            self.compute_missing_properties(n_epoch, model, self.pruner_eval)
-
-        nb_tokens_to_generate = self.height * self.width + 3
-        result_descr = []
-        nb_per_primer = 8
-        primer = []
-
-        for primer_descr in [
-            "red above green <sep> green top <sep> blue right of red",
-            "there is red <sep> there is yellow <sep> there is blue",
-            "red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left",
-            "green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top",
-        ]:
-            primer += [primer_descr + " <img>"] * nb_per_primer
-
-        result = self.tensorize(primer)
-        fill = result.new_full(
-            result.size()[:-1] + (self.height * self.width + 1,), self.t_nul
-        )
-        result = torch.cat((result, fill), 1)
-        ar_mask = (result == self.t_nul).long()
-        masked_inplace_autoregression(
-            model,
-            self.batch_size,
-            result,
-            ar_mask,
-            deterministic_synthesis,
-            device=self.device,
-        )
-        result_descr = self.detensorize(result)
-
-        np = picoclvr.nb_properties(result_descr, height=self.height, width=self.width)
-
-        acc_nb_requested_properties, _, acc_nb_missing_properties = zip(*np)
-        acc_nb_results = len(result_descr)
-
-        nb_requested_properties = sum(acc_nb_requested_properties)
-        nb_missing_properties = sum(acc_nb_missing_properties)
-
-        prefix = "demo_"
-        logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
-        logger(
-            f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
-        )
-        logger(
-            f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
-        )
-
-        img = picoclvr.descr2img(result_descr, height=self.height, width=self.width)
-
-        if img.dim() == 5:
-            if img.size(1) == 1:
-                img = F.pad(img.squeeze(1), pad=(1, 1, 1, 1), value=64)
-            else:
-                img = torch.cat(
-                    [
-                        torchvision.utils.make_grid(x, padding=1, pad_value=64)[None]
-                        for x in img
-                    ],
-                    0,
-                )
-
-        image_name = os.path.join(result_dir, f"picoclvr_result_{n_epoch:04d}.png")
-        torchvision.utils.save_image(
-            img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=0.0
-        )
-        logger(f"wrote {image_name}")
-
-
-######################################################################
-
-
-class MNIST(Task):
-    def __init__(
-        self, nb_train_samples, nb_test_samples, batch_size, device=torch.device("cpu")
-    ):
-        super().__init__()
-
-        self.nb_train_samples = (nb_train_samples,)
-        self.nb_test_samples = (nb_test_samples,)
-        self.batch_size = batch_size
-        self.device = device
-        data_set = torchvision.datasets.MNIST(root="./data", train=True, download=True)
-        self.train_input = data_set.data[:nb_train_samples].view(-1, 28 * 28).long()
-        data_set = torchvision.datasets.MNIST(root="./data", train=False, download=True)
-        self.test_input = data_set.data[:nb_test_samples].view(-1, 28 * 28).long()
-
-    def batches(self, split="train", nb_to_use=-1, desc=None):
-        assert split in {"train", "test"}
-        input = self.train_input if split == "train" else self.test_input
-        if nb_to_use > 0:
-            input = input[:nb_to_use]
-        if desc is None:
-            desc = f"epoch-{split}"
-        for batch in tqdm.tqdm(
-            input.split(self.batch_size), dynamic_ncols=True, desc=desc
-        ):
-            yield batch
-
-    def vocabulary_size(self):
-        return 256
-
-    def produce_results(
-        self, n_epoch, model, result_dir, logger, deterministic_synthesis
-    ):
-        results = torch.empty(64, 28 * 28, device=self.device, dtype=torch.int64)
-        ar_mask = torch.full_like(results, 1)
-        masked_inplace_autoregression(
-            model,
-            self.batch_size,
-            results,
-            ar_mask,
-            deterministic_synthesis,
-            device=self.device,
-        )
-        image_name = os.path.join(result_dir, f"mnist_result_{n_epoch:04d}.png")
-        torchvision.utils.save_image(
-            1 - results.reshape(-1, 1, 28, 28) / 255.0,
-            image_name,
-            nrow=16,
-            pad_value=0.8,
-        )
-        logger(f"wrote {image_name}")
-
-
-######################################################################
-
-import maze
-
-
-class Maze(Task):
-    def map2seq(self, *m):
-        return torch.cat([x.flatten(1) for x in m], 1)
-
-    def seq2map(self, s):
-        s = s.reshape(s.size(0), -1, self.height, self.width)
-        return (s[:, k] for k in range(s.size(1)))
-
-    def __init__(
-        self,
-        nb_train_samples,
-        nb_test_samples,
-        batch_size,
-        height,
-        width,
-        nb_walls,
-        device=torch.device("cpu"),
-    ):
-        super().__init__()
-
-        self.batch_size = batch_size
-        self.height = height
-        self.width = width
-        self.device = device
-
-        train_mazes, train_paths, _ = maze.create_maze_data(
-            nb_train_samples,
-            height=height,
-            width=width,
-            nb_walls=nb_walls,
-            progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-train"),
-        )
-        self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device))
-
-        test_mazes, test_paths, _ = maze.create_maze_data(
-            nb_test_samples,
-            height=height,
-            width=width,
-            nb_walls=nb_walls,
-            progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-test"),
-        )
-        self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))
-
-        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
-
-    def batches(self, split="train", nb_to_use=-1, desc=None):
-        assert split in {"train", "test"}
-        input = self.train_input if split == "train" else self.test_input
-        if nb_to_use > 0:
-            input = input[:nb_to_use]
-        if desc is None:
-            desc = f"epoch-{split}"
-        for batch in tqdm.tqdm(
-            input.split(self.batch_size), dynamic_ncols=True, desc=desc
-        ):
-            yield batch
-
-    def vocabulary_size(self):
-        return self.nb_codes
-
-    def compute_error(
-        self, model, split="train", nb_to_use=-1, deterministic_synthesis=False
-    ):
-        model_device = next(model.parameters()).device
-        nb_total, nb_correct = 0, 0
-        count = torch.zeros(
-            self.width * self.height,
-            self.width * self.height,
-            device=model_device,
-            dtype=torch.int64,
-        )
-
-        for input in self.batches(split, nb_to_use):
-            input = input.to(model_device)
-            result = input.clone()
-            ar_mask = result.new_zeros(result.size())
-            ar_mask[:, self.height * self.width :] = 1
-            result *= 1 - ar_mask
-            masked_inplace_autoregression(
-                model,
-                self.batch_size,
-                result,
-                ar_mask,
-                deterministic_synthesis,
-                progress_bar_desc=None,
-                device=self.device,
-            )
-            mazes, paths = self.seq2map(result)
-            path_correctness = maze.path_correctness(mazes, paths)
-            nb_correct += path_correctness.long().sum()
-            nb_total += mazes.size(0)
-
-            optimal_path_lengths = (
-                (input[:, self.height * self.width :] == maze.v_path).long().sum(1)
-            )
-            predicted_path_lengths = (
-                (result[:, self.height * self.width :] == maze.v_path).long().sum(1)
-            )
-            optimal_path_lengths = optimal_path_lengths[path_correctness]
-            predicted_path_lengths = predicted_path_lengths[path_correctness]
-            count[optimal_path_lengths, predicted_path_lengths] += 1
-
-        if count.max() == 0:
-            count = None
-        else:
-            count = count[
-                : count.sum(1).nonzero().max() + 1, : count.sum(0).nonzero().max() + 1
-            ]
-
-        return nb_total, nb_correct, count
-
-    def produce_results(
-        self, n_epoch, model, result_dir, logger, deterministic_synthesis
-    ):
-        train_nb_total, train_nb_correct, count = self.compute_error(
-            model,
-            "train",
-            nb_to_use=1000,
-            deterministic_synthesis=deterministic_synthesis,
-        )
-        logger(
-            f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
-        )
-
-        test_nb_total, test_nb_correct, count = self.compute_error(
-            model,
-            "test",
-            nb_to_use=1000,
-            deterministic_synthesis=deterministic_synthesis,
-        )
-        logger(
-            f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
-        )
-
-        logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
-
-        if count is not None:
-            proportion_optimal = count.diagonal().sum().float() / count.sum()
-            logger(f"proportion_optimal_test {proportion_optimal*100:.02f}%")
-            with open(
-                os.path.join(result_dir, f"maze_result_{n_epoch:04d}.txt"), "w"
-            ) as f:
-                for i in range(count.size(0)):
-                    for j in range(count.size(1)):
-                        eol = " " if j < count.size(1) - 1 else "\n"
-                        f.write(f"{count[i,j]}{eol}")
-
-        input = self.test_input[:48].to(next(model.parameters()).device)
-        result = input.clone()
-        ar_mask = result.new_zeros(result.size())
-        ar_mask[:, self.height * self.width :] = 1
-        result *= 1 - ar_mask
-        masked_inplace_autoregression(
-            model,
-            self.batch_size,
-            result,
-            ar_mask,
-            deterministic_synthesis,
-            device=self.device,
-        )
-
-        mazes, paths = self.seq2map(input)
-        _, predicted_paths = self.seq2map(result)
-
-        filename = os.path.join(result_dir, f"maze_result_{n_epoch:04d}.png")
-        maze.save_image(
-            filename,
-            mazes=mazes,
-            target_paths=paths,
-            predicted_paths=predicted_paths,
-            path_correct=maze.path_correctness(mazes, predicted_paths),
-            path_optimal=maze.path_optimality(paths, predicted_paths),
-        )
-        logger(f"wrote {filename}")
-
-
-######################################################################
-
-
-import snake
-
-
-class Snake(Task):
-    def __init__(
-        self,
-        nb_train_samples,
-        nb_test_samples,
-        batch_size,
-        height,
-        width,
-        nb_colors,
-        length,
-        prompt_length,
-        device=torch.device("cpu"),
-    ):
-        super().__init__()
-
-        self.batch_size = batch_size
-        self.height = height
-        self.width = width
-        self.device = device
-        self.prompt_length = prompt_length
-
-        self.train_input, self.train_prior_visits, _, _ = snake.generate_sequences(
-            nb_train_samples,
-            height,
-            width,
-            nb_colors,
-            length,
-            prompt_length,
-            self.device,
-        )
-        self.test_input, self.test_prior_visits, _, _ = snake.generate_sequences(
-            nb_test_samples,
-            height,
-            width,
-            nb_colors,
-            length,
-            prompt_length,
-            self.device,
-        )
-
-        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
-
-    def batches(self, split="train", nb_to_use=-1, desc=None):
-        assert split in {"train", "test"}
-        input = self.train_input if split == "train" else self.test_input
-        if nb_to_use > 0:
-            input = input[:nb_to_use]
-        if desc is None:
-            desc = f"epoch-{split}"
-        for batch in tqdm.tqdm(
-            input.split(self.batch_size), dynamic_ncols=True, desc=desc
-        ):
-            yield batch
-
-    def vocabulary_size(self):
-        return self.nb_codes
-
-    def produce_results(
-        self, n_epoch, model, result_dir, logger, deterministic_synthesis
-    ):
-        def compute_nb_correct(input, prior_visits):
-            result = input.clone()
-            i = torch.arange(result.size(1), device=result.device)[None, :]
-            ar_mask = (
-                torch.logical_and(i >= self.prompt_length * 2, i % 2 == 0)
-                .long()
-                .expand_as(result)
-            )
-            result *= 1 - ar_mask
-
-            masked_inplace_autoregression(
-                model,
-                self.batch_size,
-                result,
-                ar_mask,
-                deterministic_synthesis,
-                device=self.device,
-            )
-
-            nb_total = ((prior_visits > 0) * ar_mask).sum()
-
-            nb_correct = ((result == input).long() * (prior_visits > 0) * ar_mask).sum()
-
-            return nb_total, nb_correct
-
-        test_nb_total, test_nb_correct = compute_nb_correct(
-            self.test_input[:1000], self.test_prior_visits[:1000]
-        )
-
-        logger(
-            f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
-        )
-
-        logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
-
-
-######################################################################
-
-
-import stack
-
-
-class Stack(Task):
-    def __init__(
-        self,
-        nb_train_samples,
-        nb_test_samples,
-        batch_size,
-        logger,
-        nb_steps,
-        nb_stacks,
-        nb_digits,
-        fraction_values_for_train=None,
-        device=torch.device("cpu"),
-    ):
-        super().__init__()
-
-        self.batch_size = batch_size
-        self.nb_steps = nb_steps
-        self.nb_stacks = nb_stacks
-        self.nb_digits = nb_digits
-        self.device = device
-
-        if fraction_values_for_train is None:
-            values_for_train = None
-            values_for_test = None
-        else:
-            all = torch.randperm(10**nb_digits)
-            nb_for_train = int(all.size(0) * fraction_values_for_train)
-            values_for_train = all[:nb_for_train]
-            values_for_test = all[nb_for_train:]
-
-        self.train_input, self.train_stack_counts = stack.generate_sequences(
-            nb_train_samples,
-            nb_steps,
-            nb_stacks,
-            nb_digits,
-            values_for_train,
-            self.device,
-        )
-
-        self.test_input, self.test_stack_counts = stack.generate_sequences(
-            nb_test_samples,
-            nb_steps,
-            nb_stacks,
-            nb_digits,
-            values_for_test,
-            self.device,
-        )
-
-        i = torch.logical_and(self.test_input % 2 == 1, self.test_input < 2 * nb_stacks)
-        counts = self.test_stack_counts.flatten()[i.flatten()]
-        counts = F.one_hot(counts).sum(0)
-        logger(f"test_pop_stack_counts {counts}")
-
-        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
-
-    def batches(self, split="train", nb_to_use=-1, desc=None):
-        assert split in {"train", "test"}
-        input = self.train_input if split == "train" else self.test_input
-        if nb_to_use > 0:
-            input = input[:nb_to_use]
-        if desc is None:
-            desc = f"epoch-{split}"
-        for batch in tqdm.tqdm(
-            input.split(self.batch_size), dynamic_ncols=True, desc=desc
-        ):
-            yield batch
-
-    def vocabulary_size(self):
-        return self.nb_codes
-
-    def produce_results(
-        self, n_epoch, model, result_dir, logger, deterministic_synthesis
-    ):
-        def compute_nb_correct(input):
-            result = input.clone()
-            stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
-            ar_mask = (result != input).long()
-            masked_inplace_autoregression(
-                model,
-                self.batch_size,
-                result,
-                ar_mask,
-                deterministic_synthesis,
-                device=self.device,
-            )
-
-            errors = ((result != input).long() * ar_mask).reshape(
-                -1, 1 + self.nb_digits
-            )
-            ar_mask = ar_mask.reshape(-1, 1 + self.nb_digits)
-
-            nb_total = ar_mask.max(1).values.sum()
-            nb_correct = nb_total - errors.max(1).values.sum()
-
-            return nb_total, nb_correct
-
-        test_nb_total, test_nb_correct = compute_nb_correct(self.test_input[:1000])
-
-        logger(
-            f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
-        )
-
-        logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
-
-        ##############################################################
-        # Log a few generated sequences
-        input = self.test_input[:10, : 12 * (1 + self.nb_digits)]
-        result = input.clone()
-        stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
-        ar_mask = (result != input).long()
-
-        # for n in range(result.size(0)):
-        # logger(
-        # f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
-        # )
-
-        masked_inplace_autoregression(
-            model,
-            self.batch_size,
-            result,
-            ar_mask,
-            deterministic_synthesis,
-            device=self.device,
-        )
-
-        #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-        for label, input in [
-            ("train", self.train_input[:32]),
-            ("test", self.test_input[:32]),
-        ]:
-            output = model(BracketedSequence(input)).x
-            output = output.log_softmax(dim=-1)
-            filename = os.path.join(
-                result_dir, f"stack_with_crossentropy_{n_epoch:04d}_{label}.txt"
-            )
-            with open(filename, "w") as f:
-                for n in range(input.size(0)):
-                    s = stack.seq_to_str(
-                        input[n], nb_stacks=self.nb_stacks, nb_digits=self.nb_digits
-                    )
-                    for t, k, w in zip(range(input[n].size(0)), input[n], s.split(" ")):
-                        u = (
-                            " " * (10 - len(w))
-                            + w
-                            + " "
-                            + str(output[n][t][k].exp().item())
-                            + "\n"
-                        )
-                        f.write(u)
-                    f.write("\n")
-            logger(f"wrote {filename}")
-        #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-
-        for n in range(result.size(0)):
-            logger(
-                f"test_after  {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
-            )
-        ##############################################################
-
-
-######################################################################
-
-import rpl
-
-
-class RPL(Task):
-    def tensorize(self, sequences):
-        len_max = max([len(x) for x in sequences])
-        return torch.cat(
-            [
-                torch.tensor(
-                    [
-                        [
-                            self.token2id[str(c)]
-                            for c in s + ["<nul>"] * (len_max - len(s))
-                        ]
-                        for s in sequences
-                    ]
-                )
-            ],
-            0,
-        )
-
-    def seq2str(self, seq):
-        return " ".join([self.id2token[i] for i in seq])
-
-    def __init__(
-        self,
-        nb_train_samples,
-        nb_test_samples,
-        batch_size,
-        nb_starting_values=3,
-        max_input=9,
-        prog_len=6,
-        nb_runs=5,
-        no_prog=False,
-        logger=None,
-        device=torch.device("cpu"),
-    ):
-        super().__init__()
-
-        self.batch_size = batch_size
-        self.device = device
-        self.no_prog = no_prog
-
-        train_sequences = [
-            rpl.generate(
-                nb_starting_values=nb_starting_values,
-                nb_result_values_max=4 * nb_starting_values,
-                max_input=max_input,
-                prog_len=prog_len,
-                nb_runs=nb_runs,
-            )
-            for _ in tqdm.tqdm(range(nb_train_samples), desc="train-data")
-        ]
-
-        test_sequences = [
-            rpl.generate(
-                nb_starting_values=nb_starting_values,
-                nb_result_values_max=4 * nb_starting_values,
-                max_input=max_input,
-                prog_len=prog_len,
-                nb_runs=nb_runs,
-            )
-            for _ in tqdm.tqdm(range(nb_test_samples), desc="test-data")
-        ]
-
-        symbols = list(
-            set(["<nul>"] + [x for l in train_sequences + test_sequences for x in l])
-        )
-        val_max = max([x if type(x) is int else 0 for x in symbols])
-        symbols = list(filter(lambda x: type(x) is str, symbols))
-        symbols.sort()
-        symbols += [str(n) for n in range(val_max + 1)]
-        self.token2id = dict([(c, n) for n, c in enumerate(symbols)])
-        self.id2token = dict([(n, c) for c, n in self.token2id.items()])
-
-        self.t_nul = self.token2id["<nul>"]
-        self.t_input = self.token2id["<in>"]
-        self.t_output = self.token2id["<out>"]
-        self.t_prog = self.token2id["<prg>"]
-        self.t_end = self.token2id["<end>"]
-
-        self.train_input = self.tensorize(train_sequences)
-        self.test_input = self.tensorize(test_sequences)
-
-        if no_prog:
-            # Excise the program from every train and test example
-            k = torch.arange(self.train_input.size(1), device=self.train_input.device)[
-                None, :
-            ]
-            p = (
-                ((self.train_input == self.t_prog).long() * k)
-                .max(1, keepdim=True)
-                .values
-            )
-            self.train_input = (
-                self.train_input * (k <= p).long()
-                + self.t_end * (k == p + 1).long()
-                + self.t_nul * (k > p + 1).long()
-            )
-            k = torch.arange(self.test_input.size(1), device=self.test_input.device)[
-                None, :
-            ]
-            p = (
-                ((self.test_input == self.t_prog).long() * k)
-                .max(1, keepdim=True)
-                .values
-            )
-            self.test_input = (
-                self.test_input * (k <= p).long()
-                + self.t_end * (k == p + 1).long()
-                + self.t_nul * (k > p + 1).long()
-            )
-
-        if logger is not None:
-            logger(f"value_max {val_max}")
-            for x in self.train_input[:25]:
-                end = (x != self.t_nul).nonzero().max().item() + 1
-                seq = [self.id2token[i.item()] for i in x[:end]]
-                s = " ".join(seq)
-                logger(f"example_seq {s}")
-
-        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
-
-    def batches(self, split="train", nb_to_use=-1, desc=None):
-        assert split in {"train", "test"}
-        input = self.train_input if split == "train" else self.test_input
-        if nb_to_use > 0:
-            input = input[:nb_to_use]
-        if desc is None:
-            desc = f"epoch-{split}"
-        for batch in tqdm.tqdm(
-            input.split(self.batch_size), dynamic_ncols=True, desc=desc
-        ):
-            last = (batch != self.t_nul).max(0).values.nonzero().max() + 3
-            batch = batch[:, :last].to(self.device)
-            yield batch
-
-    def vocabulary_size(self):
-        return self.nb_codes
-
-    def produce_results(
-        self, n_epoch, model, result_dir, logger, deterministic_synthesis
-    ):
-        # --------------------------------------------------------------------
-        def compute_nb_errors_prog(input, nb_to_log=0):
-            result = input.clone()
-            s = (result == self.t_prog).long()
-            ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
-            result = (1 - ar_mask) * result + ar_mask * self.t_nul
-
-            masked_inplace_autoregression(
-                model,
-                self.batch_size,
-                result,
-                ar_mask,
-                deterministic_synthesis,
-                device=self.device,
-            )
-
-            sum_nb_total, sum_nb_errors = 0, 0
-            for one_input, one_result in zip(input, result):
-                seq = [self.id2token[i.item()] for i in one_result]
-                nb_total, nb_errors, prog, stacks = rpl.compute_nb_errors(seq)
-                sum_nb_total += 1
-                sum_nb_errors += 0 if nb_errors == 0 else 1
-                if nb_to_log > 0:
-                    gt_seq = [self.id2token[i.item()] for i in one_input]
-                    _, _, gt_prog, _ = rpl.compute_nb_errors(gt_seq)
-                    gt_prog = " ".join([str(x) for x in gt_prog])
-                    prog = " ".join([str(x) for x in prog])
-                    comment = "*" if nb_errors == 0 else "-"
-                    logger(f"{comment} PROG [{gt_prog}] PREDICTED [{prog}]")
-                    for start_stack, target_stack, result_stack, correct in stacks:
-                        comment = "*" if correct else "-"
-                        start_stack = " ".join([str(x) for x in start_stack])
-                        target_stack = " ".join([str(x) for x in target_stack])
-                        result_stack = " ".join([str(x) for x in result_stack])
-                        logger(
-                            f"  {comment} [{start_stack}] -> [{target_stack}] PREDICTED [{result_stack}]"
-                        )
-                    nb_to_log -= 1
-
-            return sum_nb_total, sum_nb_errors
-
-        # --------------------------------------------------------------------
-        def compute_nb_errors_output(input, nb_to_log=0):
-            result = input.clone()
-            k = torch.arange(result.size(1), device=result.device)[None, :]
-            last_output_idx = (
-                ((result == self.t_output) * k).max(dim=1, keepdim=True).values
-            )
-            first_prog_idx = (
-                ((result == self.t_prog) * k).max(dim=1, keepdim=True).values
-            )
-            ar_mask = (k > last_output_idx).long() * (k < first_prog_idx).long()
-            result = (1 - ar_mask) * result + ar_mask * self.t_nul
-
-            masked_inplace_autoregression(
-                model,
-                self.batch_size,
-                result,
-                ar_mask,
-                deterministic_synthesis,
-                device=self.device,
-            )
-
-            sum_nb_total, sum_nb_errors = 0, 0
-            for one_input, one_result, i, j in zip(
-                input, result, last_output_idx, first_prog_idx
-            ):
-                seq = [self.id2token[i.item()] for i in one_result]
-                sum_nb_total += 1
-                correct = (one_input - one_result).abs().max() == 0
-                sum_nb_errors += 0 if correct else 1
-                if nb_to_log > 0:
-                    result_stack = [
-                        self.id2token[i.item()] for i in one_result[i : j + 1]
-                    ]
-                    target_stack = [
-                        self.id2token[i.item()] for i in one_input[i : j + 1]
-                    ]
-                    comment = "*" if correct else "-"
-                    result_stack = " ".join([str(x) for x in result_stack])
-                    target_stack = " ".join([str(x) for x in target_stack])
-                    logger(
-                        f"output_test {comment} [{target_stack}] PREDICTED [{result_stack}]"
-                    )
-                    nb_to_log -= 1
-
-            return sum_nb_total, sum_nb_errors
-
-        # --------------------------------------------------------------------
-
-        if not self.no_prog:
-            test_nb_total, test_nb_errors = compute_nb_errors_prog(
-                self.test_input[:1000].to(self.device), nb_to_log=10
-            )
-
-            logger(
-                f"accuracy_prog_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%"
-            )
-
-            logger(f"main_test_accuracy {n_epoch} {1-test_nb_errors/test_nb_total}")
-
-        test_nb_total, test_nb_errors = compute_nb_errors_output(
-            self.test_input[:1000].to(self.device), nb_to_log=10
-        )
-
-        logger(
-            f"accuracy_output_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%"
-        )
-
-        if save_attention_image is None:
-            logger("no save_attention_image (is pycairo installed?)")
-        else:
-            ns = torch.randint(self.test_input.size(0), (1,)).item()
-            input = self.test_input[ns : ns + 1].clone()
-            last = (input != self.t_nul).max(0).values.nonzero().max() + 3
-            input = input[:, :last].to(self.device)
-
-            with torch.autograd.no_grad():
-                t = model.training
-                model.eval()
-                model.record_attention(True)
-                model(BracketedSequence(input))
-                model.train(t)
-                ram = model.retrieve_attention()
-                model.record_attention(False)
-
-            tokens_output = [self.id2token[i.item()] for i in input[0]]
-            tokens_input = ["n/a"] + tokens_output[:-1]
-            for n_head in range(ram[0].size(1)):
-                filename = os.path.join(
-                    result_dir, f"rpl_attention_{n_epoch}_h{n_head}.pdf"
-                )
-                attention_matrices = [m[0, n_head] for m in ram]
-                save_attention_image(
-                    filename,
-                    tokens_input,
-                    tokens_output,
-                    attention_matrices,
-                    k_top=10,
-                    # min_total_attention=0.9,
-                    token_gap=12,
-                    layer_gap=50,
-                )
-                logger(f"wrote {filename}")
-
-
-######################################################################
-
-
-import expr
-
-
-class Expr(Task):
-    def tensorize(self, sequences):
-        len_max = max([len(x) for x in sequences])
-        return torch.cat(
-            [
-                torch.tensor(
-                    [
-                        [self.char2id[c] for c in s + "#" * (len_max - len(s))]
-                        for s in sequences
-                    ]
-                )
-            ],
-            0,
-        ).to(self.device)
-
-    def __init__(
-        self,
-        nb_train_samples,
-        nb_test_samples,
-        nb_variables,
-        sequence_length,
-        operand_max,
-        result_max,
-        batch_size,
-        device=torch.device("cpu"),
-    ):
-        super().__init__()
-
-        self.batch_size = batch_size
-        self.device = device
-
-        train_sequences = expr.generate_sequences(
-            nb_train_samples,
-            nb_variables=nb_variables,
-            length=sequence_length,
-            operand_max=operand_max,
-            result_max=result_max,
-        )
-
-        test_sequences = expr.generate_sequences(
-            nb_test_samples,
-            nb_variables=nb_variables,
-            length=sequence_length,
-            operand_max=operand_max,
-            result_max=result_max,
-        )
-
-        symbols = list(set("#" + "".join(train_sequences + test_sequences)))
-        symbols.sort()
-
-        self.char2id = dict([(c, n) for n, c in enumerate(symbols)])
-        self.id2char = dict([(n, c) for c, n in self.char2id.items()])
-
-        self.filler, self.space = self.char2id["#"], self.char2id[" "]
-
-        self.train_input = self.tensorize(train_sequences)
-        self.test_input = self.tensorize(test_sequences)
-
-        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
-
-    def batches(self, split="train", nb_to_use=-1, desc=None):
-        assert split in {"train", "test"}
-        input = self.train_input if split == "train" else self.test_input
-        if nb_to_use > 0:
-            input = input[:nb_to_use]
-        if desc is None:
-            desc = f"epoch-{split}"
-        for batch in tqdm.tqdm(
-            input.split(self.batch_size), dynamic_ncols=True, desc=desc
-        ):
-            last = (batch != self.filler).max(0).values.nonzero().max() + 3
-            batch = batch[:, :last]
-            yield batch
-
-    def vocabulary_size(self):
-        return self.nb_codes
-
-    def seq2str(self, s):
-        return "".join([self.id2char[k.item()] for k in s])
-
-    def produce_results(
-        self,
-        n_epoch,
-        model,
-        result_dir,
-        logger,
-        deterministic_synthesis,
-        input_file=None,
-    ):
-        def compute_nb_correct(input):
-            result = input.clone()
-            s = (result == self.space).long()
-            ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
-            result = (1 - ar_mask) * result + ar_mask * self.filler
-            masked_inplace_autoregression(
-                model,
-                self.batch_size,
-                result,
-                ar_mask,
-                deterministic_synthesis,
-                device=self.device,
-            )
-
-            nb_total = input.size(0)
-            nb_correct = (input == result).long().min(1).values.sum()
-
-            #######################################################################
-            # Comput predicted vs. true variable values
-
-            nb_delta = torch.zeros(5, dtype=torch.int64)
-            nb_missed = 0
-
-            values_input = expr.extract_results([self.seq2str(s) for s in input])
-            values_result = expr.extract_results([self.seq2str(s) for s in result])
-
-            filename = os.path.join(result_dir, f"expr_result_{n_epoch:04d}.txt")
-
-            with open(filename, "w") as f:
-                for i, r in zip(values_input, values_result):
-                    for n, vi in i.items():
-                        vr = r.get(n)
-                        f.write(f"{vi} {-1 if vr is None else vr}\n")
-
-                        if vr is None or vr < 0:
-                            nb_missed += 1
-                        else:
-                            d = abs(vr - vi)
-                            if d >= nb_delta.size(0):
-                                nb_missed += 1
-                            else:
-                                nb_delta[d] += 1
-
-            ######################################################################
-
-            return nb_total, nb_correct, nb_delta, nb_missed
-
-        (
-            test_nb_total,
-            test_nb_correct,
-            test_nb_delta,
-            test_nb_missed,
-        ) = compute_nb_correct(self.test_input[:10000])
-
-        logger(
-            f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
-        )
-
-        logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
-
-        nb_total = test_nb_delta.sum() + test_nb_missed
-        for d in range(test_nb_delta.size(0)):
-            logger(
-                f"error_value {n_epoch} delta {d} {test_nb_delta[d]} {test_nb_delta[d]*100/nb_total:.02f}%"
-            )
-        logger(
-            f"error_value {n_epoch} missed {test_nb_missed} {test_nb_missed*100/nb_total:.02f}%"
-        )
-
-        ##############################################################
-        # Log a few generated sequences
-        if input_file is None:
-            input = self.test_input[:10]
-        else:
-            with open(input_file, "r") as f:
-                sequences = [e.strip() for e in f.readlines()]
-                sequences = [s + " " + "#" * 50 for s in sequences]
-                input = self.tensorize(sequences)
-
-        result = input.clone()
-        s = (result == self.space).long()
-        ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
-        result = (1 - ar_mask) * result + ar_mask * self.filler
-
-        for n in range(result.size(0)):
-            logger(f"test_before {self.seq2str(result[n])}")
-
-        masked_inplace_autoregression(
-            model,
-            self.batch_size,
-            result,
-            ar_mask,
-            deterministic_synthesis,
-            device=self.device,
-        )
-
-        correct = (1 - ar_mask) * self.space + ar_mask * input
-        for n in range(result.size(0)):
-            comment = "GOOD" if (result[n] - input[n]).abs().max() == 0 else ""
-            logger(f"test_after  {self.seq2str(result[n])} {comment}")
-            logger(f"truth       {self.seq2str(correct[n])}")
-        ##############################################################
-
-
-######################################################################
-
-import grid
-
-
-class Grid(Task):
-    # Make a tensor from a list of strings
-    def str2tensor(self, descr):
-        token_descr = [s.strip().split(" ") for s in descr]
-        l = max([len(s) for s in token_descr])
-        token_descr = [s + ["#"] * (l - len(s)) for s in token_descr]
-        id_descr = [[self.token2id[u] for u in s] for s in token_descr]
-        return torch.tensor(id_descr, device=self.device)
-
-    # Make a list of strings from a tensor
-    def tensor2str(self, x):
-        return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
-
-    # trim all the tensors in the tuple z to remove as much token from
-    # left and right in the first tensor. If z is a tuple, all its
-    # elements are trimed according to the triming for the first
-    def trim(self, z, token="#"):
-        n = self.token2id[token]
-        if type(z) == tuple:
-            x = z[0]
-            i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
-            a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
-            return tuple([t[:, a:b] for t in z])
-        else:
-            i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
-            a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
-            return z[:, a:b]
-
-    ######################
-
-    def __init__(
-        self,
-        nb_train_samples,
-        nb_test_samples,
-        batch_size,
-        size,
-        fraction_play=0.0,
-        logger=None,
-        device=torch.device("cpu"),
-    ):
-        super().__init__()
-
-        self.device = device
-        self.batch_size = batch_size
-        self.grid_factory = grid.GridFactory(size=size)
-        self.fraction_play = fraction_play
-
-        if logger is not None:
-            logger(
-                f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
-            )
-
-        self.train_descr = self.grid_factory.generate_samples(
-            nb=nb_train_samples,
-            fraction_play=fraction_play,
-            progress_bar=lambda r: tqdm.tqdm(r),
-        )
-
-        self.test_descr = self.grid_factory.generate_samples(
-            nb=nb_test_samples, fraction_play=0.0, progress_bar=lambda r: tqdm.tqdm(r)
-        )
-
-        if fraction_play > 0:
-            self.play_descr = self.grid_factory.generate_samples(
-                nb=25, fraction_play=1.0, progress_bar=lambda r: tqdm.tqdm(r)
-            )
-        else:
-            self.play_descr = []
-
-        # Build the tokenizer
-        tokens = set()
-        for d in [self.train_descr, self.test_descr, self.play_descr]:
-            for s in d:
-                for t in s.strip().split(" "):
-                    tokens.add(t)
-        # make this set a sorted list to get the same tensors given
-        # the same descr
-        tokens = list(tokens)
-        tokens.sort()
-        tokens = ["#"] + tokens
-        self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
-        self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
-        self.t_nul = self.token2id["#"]
-        self.t_true = self.token2id["true"]
-        self.t_false = self.token2id["false"]
-        # self.t_pipe = self.token2id["|"]
-
-        # Tokenize the train and test sets
-        self.train_input = self.str2tensor(self.train_descr)
-        self.test_input = self.str2tensor(self.test_descr)
-        self.play_input = (
-            None if len(self.play_descr) == 0 else self.str2tensor(self.play_descr)
-        )
-
-    def batches(self, split="train", nb_to_use=-1, desc=None):
-        assert split in {"train", "test"}
-        input = self.train_input if split == "train" else self.test_input
-        for batch in tqdm.tqdm(
-            input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
-        ):
-            yield self.trim(batch)
-
-    def vocabulary_size(self):
-        return len(self.token2id)
-
-    def produce_results(
-        self, n_epoch, model, result_dir, logger, deterministic_synthesis
-    ):
-        correct = self.test_input[:1000]
-        result = correct.clone()
-        ar_mask = torch.logical_or(result == self.t_true, result == self.t_false).long()
-        result *= 1 - ar_mask  # paraaaaanoiaaaaaaa
-
-        logger(f"----------------------------------------------------------")
-
-        for e in self.tensor2str(result[:10]):
-            logger(f"test_before {e}")
-
-        masked_inplace_autoregression(
-            model,
-            self.batch_size,
-            result,
-            ar_mask,
-            deterministic_synthesis,
-            device=self.device,
-        )
-
-        logger(f"----------------------------------------------------------")
-
-        for e in self.tensor2str(result[:10]):
-            logger(f"test_after  {e}")
-
-        logger(f"----------------------------------------------------------")
-
-        nb_total = ar_mask.sum().item()
-        nb_correct = ((correct == result).long() * ar_mask).sum().item()
-
-        logger(f"test_performance {n_epoch} {nb_total=} {nb_correct=}")
-        logger(f"main_test_accuracy {n_epoch} {nb_correct / nb_total}")
-
-        if self.play_input is not None:
-            result = self.play_input.clone()
-            ar_mask = (result == self.t_pipe).long().cumsum(dim=1).clamp(max=1)
-            result *= 1 - ar_mask  # paraaaaanoiaaaaaaa
-
-            logger(f"----------------------------------------------------------")
-
-            for e in self.tensor2str(result[:10]):
-                logger(f"play_before {e}")
-
-            masked_inplace_autoregression(
-                model,
-                self.batch_size,
-                result,
-                ar_mask,
-                deterministic_synthesis,
-                device=self.device,
-            )
-
-            logger(f"----------------------------------------------------------")
-
-            for e in self.tensor2str(result[:10]):
-                logger(f"play_after  {e}")
-
-            logger(f"----------------------------------------------------------")
-
-
-######################################################################
-
-import qmlp
-
-
-class QMLP(Task):
-    ######################
-
-    def __init__(
-        self,
-        nb_train_samples,
-        nb_test_samples,
-        batch_size,
-        result_dir,
-        logger=None,
-        device=torch.device("cpu"),
-    ):
-        super().__init__()
-
-        self.device = device
-        self.batch_size = batch_size
-        self.nb_samples_per_mlp = 256
-
-        if logger is not None:
-            logger(
-                f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
-            )
-
-        seq, q_test_set, test_error = qmlp.generate_sequence_and_test_set(
-            nb_mlps=nb_train_samples + nb_test_samples,
-            nb_samples=self.nb_samples_per_mlp,
-            device=self.device,
-            batch_size=64,
-            nb_epochs=250,
-            nb_mlps_per_batch=1024,
-        )
-
-        self.train_input = seq[:nb_train_samples]
-        self.train_q_test_set = q_test_set[:nb_train_samples]
-        self.train_ref_test_errors = test_error[:nb_train_samples]
-        self.test_input = seq[nb_train_samples:]
-        self.test_q_test_set = q_test_set[nb_train_samples:]
-        self.test_ref_test_errors = test_error[nb_train_samples:]
-
-        filename = os.path.join(result_dir, f"train_errors_ref.dat")
-        with open(filename, "w") as f:
-            for e in self.train_ref_test_errors:
-                f.write(f"{e}\n")
-
-        filename = os.path.join(result_dir, f"test_errors_ref.dat")
-        with open(filename, "w") as f:
-            for e in self.test_ref_test_errors:
-                f.write(f"{e}\n")
-
-        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
-
-    def batches(self, split="train", nb_to_use=-1, desc=None):
-        assert split in {"train", "test"}
-        input = self.train_input if split == "train" else self.test_input
-        for batch in tqdm.tqdm(
-            input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
-        ):
-            yield batch
-
-    def vocabulary_size(self):
-        return self.nb_codes
-
-    def produce_results(
-        self, n_epoch, model, result_dir, logger, deterministic_synthesis
-    ):
-        correct = self.test_input[:1000]
-        result = correct.clone()
-        ar_mask = (
-            torch.arange(result.size(1), device=result.device)
-            > self.nb_samples_per_mlp * 3 + 1
-        ).long()[None, :]
-        ar_mask = ar_mask.expand_as(result)
-        result *= 1 - ar_mask  # paraaaaanoiaaaaaaa
-
-        masked_inplace_autoregression(
-            model,
-            self.batch_size,
-            result,
-            ar_mask,
-            deterministic_synthesis,
-            device=self.device,
-        )
-
-        q_train_set = result[:, : self.nb_samples_per_mlp * 3]
-        q_params = result[:, self.nb_samples_per_mlp * 3 + 1 :]
-        error_test = qmlp.evaluate_q_params(q_params, self.test_q_test_set)
-
-        filename = os.path.join(result_dir, f"test_errors_{n_epoch:04d}.dat")
-        with open(filename, "w") as f:
-            for e in error_test:
-                f.write(f"{e}\n")
-
-
-######################################################################
-
-import greed
-
-
-class Greed(Task):
-    def __init__(
-        self,
-        nb_train_samples,
-        nb_test_samples,
-        batch_size,
-        height,
-        width,
-        T,
-        nb_walls,
-        nb_coins,
-        logger=None,
-        device=torch.device("cpu"),
-    ):
-        super().__init__()
-
-        self.batch_size = batch_size
-        self.device = device
-
-        self.world = greed.GreedWorld(height, width, T, nb_walls, nb_coins)
-
-        states, actions, rewards = self.world.generate_episodes(
-            nb_train_samples + nb_test_samples
-        )
-        seq = self.world.episodes2seq(states, actions, rewards)
-        self.train_input = seq[:nb_train_samples].to(self.device)
-        self.test_input = seq[nb_train_samples:].to(self.device)
-
-    def wipe_lookahead_rewards(self, batch):
-        t = torch.arange(batch.size(1), device=batch.device)[None, :]
-        u = torch.randint(batch.size(1), (batch.size(0), 1), device=batch.device)
-        lr_mask = (t <= u).long() * (
-            t % self.world.it_len == self.world.index_lookahead_reward
-        ).long()
-
-        return (
-            lr_mask * self.world.lookahead_reward2code(greed.REWARD_UNKNOWN)
-            + (1 - lr_mask) * batch
-        )
-
-    def batches(self, split="train", nb_to_use=-1, desc=None):
-        assert split in {"train", "test"}
-        input = self.train_input if split == "train" else self.test_input
-        if nb_to_use > 0:
-            input = input[:nb_to_use]
-        if desc is None:
-            desc = f"epoch-{split}"
-        for batch in tqdm.tqdm(
-            input.split(self.batch_size), dynamic_ncols=True, desc=desc
-        ):
-            yield self.wipe_lookahead_rewards(batch)
-
-    def vocabulary_size(self):
-        return self.world.nb_codes
-
-    def thinking_autoregression(
-        self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000
-    ):
-        snapshots = []
-
-        def ar(result, ar_mask, logit_biases=None):
-            ar_mask = ar_mask.expand_as(result)
-            result *= 1 - ar_mask
-            masked_inplace_autoregression(
-                model,
-                self.batch_size,
-                result,
-                ar_mask,
-                deterministic_synthesis=deterministic_synthesis,
-                logit_biases=logit_biases,
-                device=self.device,
-                progress_bar_desc=None,
-            )
-            warnings.warn("keeping thinking snapshots", RuntimeWarning)
-            snapshots.append(result[:100].detach().clone())
-
-        # Generate iteration after iteration
-
-        result = self.test_input[:250].clone()
-        # Erase all the content but that of the first iteration
-        result[:, self.world.it_len :] = -1
-        # Set the lookahead_reward of the firs to UNKNOWN
-        result[:, self.world.index_lookahead_reward] = self.world.lookahead_reward2code(
-            greed.REWARD_UNKNOWN
-        )
-
-        t = torch.arange(result.size(1), device=result.device)[None, :]
-
-        for u in tqdm.tqdm(
-            range(0, result.size(1), self.world.it_len),
-            desc="thinking",
-        ):
-            # Generate the next state but keep the initial one, the
-            # lookahead_reward of previous iterations are set to
-            # UNKNOWN
-            if u > 0:
-                result[
-                    :, u + self.world.index_lookahead_reward
-                ] = self.world.lookahead_reward2code(greed.REWARD_UNKNOWN)
-                ar_mask = (t >= u + self.world.index_states).long() * (
-                    t < u + self.world.index_states + self.world.state_len
-                ).long()
-                ar(result, ar_mask)
-
-            # Generate the action and reward with lookahead_reward to +1
-            result[
-                :, u + self.world.index_lookahead_reward
-            ] = self.world.lookahead_reward2code(greed.REWARD_PLUS)
-            ar_mask = (t >= u + self.world.index_reward).long() * (
-                t <= u + self.world.index_action
-            ).long()
-            ar(result, ar_mask)
-
-            # Set the lookahead_reward to UNKNOWN for the next iterations
-            result[
-                :, u + self.world.index_lookahead_reward
-            ] = self.world.lookahead_reward2code(greed.REWARD_UNKNOWN)
-
-        filename = os.path.join(result_dir, f"test_thinking_compute_{n_epoch:04d}.txt")
-        with open(filename, "w") as f:
-            for n in range(snapshots[0].size(0)):
-                for s in snapshots:
-                    lr, s, a, r = self.world.seq2episodes(
-                        s[n : n + 1],
-                    )
-                    str = self.world.episodes2str(
-                        lr, s, a, r, unicode=True, ansi_colors=True
-                    )
-                    f.write(str)
-                f.write("\n\n")
-
-        # Saving the generated sequences
-
-        lr, s, a, r = self.world.seq2episodes(result)
-        str = self.world.episodes2str(lr, s, a, r, unicode=True, ansi_colors=True)
-
-        filename = os.path.join(result_dir, f"test_thinking_seq_{n_epoch:04d}.txt")
-        with open(filename, "w") as f:
-            f.write(str)
-            logger(f"wrote {filename}")
-
-    def produce_results(
-        self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000
-    ):
-        result = self.wipe_lookahead_rewards(self.test_input[:250].clone())
-
-        # Saving the ground truth
-
-        lr, s, a, r = self.world.seq2episodes(
-            result,
-        )
-        str = self.world.episodes2str(lr, s, a, r, unicode=True, ansi_colors=True)
-
-        filename = os.path.join(result_dir, f"test_true_seq_{n_epoch:04d}.txt")
-        with open(filename, "w") as f:
-            f.write(str)
-            logger(f"wrote {filename}")
-
-        # Re-generating from the first frame
-
-        ar_mask = (
-            torch.arange(result.size(1), device=result.device) >= self.world.it_len
-        ).long()[None, :]
-        ar_mask = ar_mask.expand_as(result)
-        result *= 1 - ar_mask  # paraaaaanoiaaaaaaa
-
-        masked_inplace_autoregression(
-            model,
-            self.batch_size,
-            result,
-            ar_mask,
-            deterministic_synthesis,
-            device=self.device,
-        )
-
-        # Saving the generated sequences
-
-        lr, s, a, r = self.world.seq2episodes(
-            result,
-        )
-        str = self.world.episodes2str(lr, s, a, r, unicode=True, ansi_colors=True)
-
-        filename = os.path.join(result_dir, f"test_seq_{n_epoch:04d}.txt")
-        with open(filename, "w") as f:
-            f.write(str)
-            logger(f"wrote {filename}")
-
-        self.thinking_autoregression(
-            n_epoch, model, result_dir, logger, deterministic_synthesis, nmax
-        )
-
-
-######################################################################
 ######################################################################
 
 import world
diff --git a/turing.py b/turing.py
deleted file mode 100755 (executable)
index 2bcdeeb..0000000
--- a/turing.py
+++ /dev/null
@@ -1,46 +0,0 @@
-#!/usr/bin/env python
-
-import torch
-
-
-def generate_turing_sequences(N, nb_iter=5, nb_states=3, nb_symbols=4, tape_size=5):
-    next_state = torch.randint(nb_states, (N, nb_states, nb_symbols))
-    next_symbol = torch.randint(nb_symbols, (N, nb_states, nb_symbols))
-    next_move = torch.randint(3, (N, nb_states, nb_symbols))
-
-    all_n = torch.arange(N)
-
-    tape = torch.randint(nb_symbols, (N, tape_size))
-    # position = torch.randint(tape_size, (N,))
-    # state = torch.randint(nb_states, (N,))
-    position = torch.zeros(N, dtype=torch.int64)
-    state = torch.zeros(N, dtype=torch.int64)
-
-    result = []
-
-    for _ in range(nb_iter):
-        result.append(tape.clone())
-        current_symbol = tape[all_n, position]
-        tape[all_n, position] = next_symbol[all_n, state, current_symbol]
-        position = (position + next_move[all_n, state, current_symbol] - 1) % tape_size
-        state = next_state[all_n, state, current_symbol]
-
-    result = torch.cat([x[:, None, :] for x in result], dim=1)
-
-    return result
-
-
-######################################################################
-
-if __name__ == "__main__":
-    print("Basic check.")
-
-    tapes = generate_turing_sequences(1, nb_iter=10)
-
-    for i in range(tapes.size(1)):
-        # print(f"- {i:03d} ------------------------")
-        # for s, h, r in zip(state, position, tape):
-        # print("".join([f"{x}" for x in r]))
-        # print(" " * h + f"^[{s}]")
-        for r in tapes:
-            print("".join([f"{x}" for x in r[i]]))