From: François Fleuret Date: Sat, 1 Jul 2023 16:53:37 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=16bf88f88bbab138c0dc33b4fbd2d88cf9db3ae5;p=picoclvr.git Update. --- diff --git a/stack.py b/stack.py index d3be4f8..dc494bb 100755 --- a/stack.py +++ b/stack.py @@ -13,45 +13,44 @@ import torch, torchvision # CODE_VAL=val + 2 * nb_stacks -def generate(nb, seq_len, nb_stacks, nb_values): - stack = torch.empty(nb, nb_stacks, seq_len, dtype=torch.int64) +def generate(nb, nb_steps, nb_stacks, nb_values): + stack = torch.empty(nb, nb_stacks, nb_steps, dtype=torch.int64) stack_pointers = torch.zeros(nb, nb_stacks, dtype=torch.int64) k = torch.arange(nb) - result = torch.empty(nb, 2 * seq_len, dtype=torch.int64) + result = torch.empty(nb, 2 * nb_steps, dtype=torch.int64) + depth_counts = torch.zeros(nb, 2 * nb_steps, dtype=torch.int64) - for t in range(seq_len): + for t in range(nb_steps): op = torch.randint(2, (nb,)) st = torch.randint(nb_stacks, (nb,)) op = op * (stack_pointers[k, st] > 0) val_push = torch.randint(nb_values, (nb,)) - # top_val[n,s]=stack[n,stack_pointers[n,s]] - top_values = stack[ + val_pop = stack[ k, st, (stack_pointers[k, st] - 1).clamp(min=0), ] - stack[ - k[:, None].expand_as(stack_pointers), - st[:, None].expand_as(stack_pointers), - stack_pointers, - ] = val_push[:, None].expand_as(stack_pointers) + stack[k, st, stack_pointers[k, st]] = val_push + depth_counts[:, 2 * t + 1] = stack_pointers[k, st] stack_pointers[k[op == 0], st[op == 0]] += 1 stack_pointers[k[op == 1], st[op == 1]] -= 1 result[:, 2 * t] = st * 2 + op - result[:, 2 * t + 1] = (op * top_values + (1 - op) * val_push) + 2 * nb_stacks + result[:, 2 * t + 1] = (op * val_pop + (1 - op) * val_push) + 2 * nb_stacks - return result + return result, depth_counts -def seq_to_str(seq): +def seq_to_str(seq, depth_counts=None): assert seq.size(0) % 2 == 0 s = "" - for t in range(0, seq.size(0), 2): - op = seq[t] + for t in range(seq.size(0) // 2): + op = seq[2 * t] op = f"POP_{op//2}" if op % 2 == 1 else f"PUSH_{op//2}" - val = seq[t + 1] - 2 * nb_stacks + val = seq[2 * t + 1] - 2 * nb_stacks if t > 0: s += " " + if depth_counts is not None: + s += f"[{depth_counts[2*t+1]}] " s += f"{op} {val}" return s @@ -59,7 +58,10 @@ def seq_to_str(seq): ###################################################################### if __name__ == "__main__": - nb, seq_len, nb_stacks, nb_values = 3, 10, 1, 5 - result = generate(nb=nb, seq_len=seq_len, nb_stacks=nb_stacks, nb_values=nb_values) - for n in range(result.size(0)): - print(seq_to_str(result[n])) + nb, nb_steps, nb_stacks, nb_values = 150000, 10, 1, 5 + seq, depth_counts = generate( + nb=nb, nb_steps=nb_steps, nb_stacks=nb_stacks, nb_values=nb_values + ) + + for n in range(min(10, seq.size(0))): + print(seq_to_str(seq[n], depth_counts[n]))