From 3b9ba21fd3d06a20703216cc0a77fe9dc78b079f Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 22 Jul 2023 22:52:54 +0200 Subject: [PATCH] Update. --- graph.py | 52 +++++++++++++++++++++++++++++----------------------- tasks.py | 19 +++++++++---------- 2 files changed, 38 insertions(+), 33 deletions(-) diff --git a/graph.py b/graph.py index a2554d2..a819283 100755 --- a/graph.py +++ b/graph.py @@ -18,9 +18,7 @@ def save_attention_image( tokens_input, tokens_output, # An iterable set of BxHxTxT attention matrices - attention_arrays, - n_sample=0, - n_head=0, + attention_matrices, pixel_scale=8, token_gap=15, layer_gap=25, @@ -35,20 +33,19 @@ def save_attention_image( k_top=None, curved=True, ): - attention = torch.cat( - [x[n_sample : n_sample + 1, n_head] for x in attention_arrays], dim=0 - ) - if k_top is not None: - attention = attention * ( - attention.sort(dim=-1, descending=True).indices < k_top - ) + am = [] + for m in attention_matrices: + am.append(m * (m.sort(dim=-1, descending=True).indices < k_top)) + attention_matrices = am if min_total_attention is not None: - s = attention.sort(dim=-1) - m = 1 - (s.values.cumsum(-1) < 1 - min_total_attention).long() - b = m.new(attention.size()).scatter_(dim=-1, index=s.indices, src=m) - attention = attention * b + am = [] + for m in attention_matrices: + s = m.sort(dim=-1) + m = 1 - (s.values.cumsum(-1) < 1 - min_total_attention).long() + b = m.new(m.size()).scatter_(dim=-1, index=s.indices, src=m) + am.append(m * b) surface = cairo.RecordingSurface(cairo.CONTENT_COLOR_ALPHA, None) @@ -61,8 +58,9 @@ def save_attention_image( x, y = 0, 0 - for d in range(attention.size(0)): - at = attention[d] + ctx.set_line_width(0.25) + for d in range(len(attention_matrices)): + at = attention_matrices[d] ni = torch.arange(at.size(0))[:, None].expand_as(at) nj = torch.arange(at.size(1))[None, :].expand_as(at) at = at.flatten() @@ -74,7 +72,6 @@ def save_attention_image( if a > 0 and a >= min_link_attention: c = 1 - a.item() ctx.set_source_rgb(c, c, c) - ctx.set_line_width(0.5) ax, ay = j * token_gap, y - y_eps ctx.move_to(ax, ay) dx, dy = i * token_gap, y - layer_gap + y_eps @@ -87,8 +84,13 @@ def save_attention_image( ctx.stroke() y -= layer_gap - for d in range(0, attention.size(0) + 1): - for n in range(attention.size(-1)): + for d in range(0, len(attention_matrices) + 1): + n = ( + attention_matrices[0].size(-1) + if d == 0 + else attention_matrices[d - 1].size(-2) + ) + for n in range(n): xc, yc = n * token_gap, -d * layer_gap ctx.set_source_rgb(1.0, 1.0, 1.0) ctx.arc(xc, yc, token_gap / 10, 0, 2 * math.pi) @@ -123,7 +125,8 @@ def save_attention_image( y_advance, ) = ctx.text_extents(s) ctx.move_to( - k * token_gap - width_t / 2, -token_gap / 5 - attention.size(0) * layer_gap + k * token_gap - width_t / 2, + -token_gap / 5 - len(attention_matrices) * layer_gap, ) ctx.show_text(s) @@ -156,7 +159,7 @@ if __name__ == "__main__": dim_keys=2, dim_hidden=2, nb_heads=2, - nb_blocks=3, + nb_blocks=5, dropout=0.1, causal=True, ) @@ -166,13 +169,16 @@ if __name__ == "__main__": y1 = model(mygpt.BracketedSequence(x)).x - attention = model.retrieve_attention() + attention_matrices = [m[0, 0] for m in model.retrieve_attention()] + + # attention_matrices = [ torch.rand(3,5), torch.rand(8,3), torch.rand(5,8) ] + # for a in attention_matrices: a=a/a.sum(-1,keepdim=True) save_attention_image( "attention.pdf", tokens_input, tokens_output, - attention, + attention_matrices, # k_top=2, min_total_attention=0.9, ) diff --git a/tasks.py b/tasks.py index 0eed2aa..0c92af9 100755 --- a/tasks.py +++ b/tasks.py @@ -1284,7 +1284,7 @@ class RPL(Task): ) if save_attention_image is not None: - input = self.test_input[:10] + input = self.test_input[:1] result = input.clone() s = (result == self.t_prog).long() ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1) @@ -1305,24 +1305,23 @@ class RPL(Task): model.record_attention(True) model(BracketedSequence(result)) model.train(t) - attention = model.retrieve_attention() + ram = model.retrieve_attention() model.record_attention(False) - n_sample = 0 - tokens_output = [self.id2token[i.item()] for i in result[n_sample]] + tokens_output = [self.id2token[i.item()] for i in result[0]] tokens_input = ["n/a"] + tokens_output[:-1] - for n_head in range(attention[0].size(1)): + for n_head in range(ram[0].size(1)): filename = f"rpl_attention_{n_epoch}_h{n_head}.pdf" + attention_matrices = [m[0, n_head] for m in ram] save_attention_image( filename, tokens_input, tokens_output, - attention, - n_sample=n_sample, - n_head=n_head, + attention_matrices, token_gap=12, - layer_gap=40, - # k_top=2, + layer_gap=50, + k_top=10, + # min_total_attention=0.9, ) logger(f"wrote {filename}") -- 2.39.5