From 9b9d7bc878171bd65b0c8a803494a2e4ef00c5fe Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 22 Jul 2023 19:05:38 +0200 Subject: [PATCH] Update. --- graph.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/graph.py b/graph.py index 5bab861..5195cc9 100755 --- a/graph.py +++ b/graph.py @@ -47,12 +47,6 @@ def save_attention_image( x, y = 0, 0 for d in range(attention.size(0)): - if d > 0: - for n in range(attention.size(-1)): - xc, yc = n * token_gap, -d * layer_gap - ctx.arc(xc, yc, token_gap / 10, 0, 2 * math.pi) - ctx.fill() - at = attention[d] ni = torch.arange(at.size(0))[:, None].expand_as(at) nj = torch.arange(at.size(1))[None, :].expand_as(at) @@ -71,14 +65,14 @@ def save_attention_image( ctx.stroke() y -= layer_gap - for d in range(1, attention.size(0)): + for d in range(0, attention.size(0) + 1): for n in range(attention.size(-1)): xc, yc = n * token_gap, -d * layer_gap ctx.set_source_rgb(1.0, 1.0, 1.0) - ctx.arc(xc, yc, token_gap / 10 + 0.5, 0, 2 * math.pi) + ctx.arc(xc, yc, token_gap / 10, 0, 2 * math.pi) ctx.fill() ctx.set_source_rgb(0.0, 0.0, 0.0) - ctx.arc(xc, yc, token_gap / 10, 0, 2 * math.pi) + ctx.arc(xc, yc, token_gap / 20, 0, 2 * math.pi) ctx.fill() ctx.set_source_rgb(0.0, 0.0, 0.0) @@ -93,7 +87,7 @@ def save_attention_image( x_advance, y_advance, ) = ctx.text_extents(s) - ctx.move_to(k * token_gap - width_t / 2, -y_bearing) + ctx.move_to(k * token_gap - width_t / 2, token_gap / 5 - y_bearing) ctx.show_text(s) for k, t in enumerate(tokens_output): @@ -106,7 +100,9 @@ def save_attention_image( x_advance, y_advance, ) = ctx.text_extents(s) - ctx.move_to(k * token_gap - width_t / 2, -attention.size(0) * layer_gap) + ctx.move_to( + k * token_gap - width_t / 2, -token_gap / 5 - attention.size(0) * layer_gap + ) ctx.show_text(s) x, y, width, height = surface.ink_extents() -- 2.39.5