From 95717a8bf88159051f9c4b8862b0b643187826e9 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 22 Jul 2023 23:13:29 +0200 Subject: [PATCH] Update. --- graph.py | 19 ++++++++----------- tasks.py | 9 +++++---- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/graph.py b/graph.py index c286388..08f1170 100755 --- a/graph.py +++ b/graph.py @@ -14,16 +14,10 @@ import cairo def save_attention_image( - filename, + filename, # image to save tokens_input, tokens_output, - # An iterable set of BxHxTxT attention matrices - attention_matrices, - pixel_scale=8, - token_gap=15, - layer_gap=25, - y_eps=0.5, - padding=10, + attention_matrices, # list of 2d tensors T1xT2, T2xT3, ..., Tk-1xTk # do not draw links with a lesser attention min_link_attention=0, # draw only the strongest links necessary to reache @@ -32,6 +26,11 @@ def save_attention_image( # draw only the top k links k_top=None, curved=True, + pixel_scale=8, + token_gap=15, + layer_gap=25, + y_eps=0.5, + padding=10, ): if k_top is not None: am = [] @@ -161,7 +160,7 @@ if __name__ == "__main__": nb_heads=2, nb_blocks=5, dropout=0.1, - #causal=True, + causal=True, ) model.eval() @@ -171,8 +170,6 @@ if __name__ == "__main__": attention_matrices = [m[0, 0] for m in model.retrieve_attention()] - - # attention_matrices = [ torch.rand(3,5), torch.rand(8,3), torch.rand(5,8) ] # for a in attention_matrices: a=a/a.sum(-1,keepdim=True) diff --git a/tasks.py b/tasks.py index 0c92af9..a27b836 100755 --- a/tasks.py +++ b/tasks.py @@ -140,7 +140,6 @@ class ProblemLevel2(Problem): num_classes=self.len_source, ) source1 = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source] - # source1 = torch.randint(10, (nb, self.len_source)) marker1 = torch.full((nb, 1), 10) result1 = operators.bmm(source1[:, :, None]).squeeze(-1) marker2 = torch.full((nb, 1), 11) @@ -1311,17 +1310,19 @@ class RPL(Task): tokens_output = [self.id2token[i.item()] for i in result[0]] tokens_input = ["n/a"] + tokens_output[:-1] for n_head in range(ram[0].size(1)): - filename = f"rpl_attention_{n_epoch}_h{n_head}.pdf" + filename = os.path.join( + result_dir, f"rpl_attention_{n_epoch}_h{n_head}.pdf" + ) attention_matrices = [m[0, n_head] for m in ram] save_attention_image( filename, tokens_input, tokens_output, attention_matrices, - token_gap=12, - layer_gap=50, k_top=10, # min_total_attention=0.9, + token_gap=12, + layer_gap=50, ) logger(f"wrote {filename}") -- 2.39.5