x, y = 0, 0
 
     for d in range(attention.size(0)):
-        if d > 0:
-            for n in range(attention.size(-1)):
-                xc, yc = n * token_gap, -d * layer_gap
-                ctx.arc(xc, yc, token_gap / 10, 0, 2 * math.pi)
-                ctx.fill()
-
         at = attention[d]
         ni = torch.arange(at.size(0))[:, None].expand_as(at)
         nj = torch.arange(at.size(1))[None, :].expand_as(at)
                 ctx.stroke()
         y -= layer_gap
 
-    for d in range(1, attention.size(0)):
+    for d in range(0, attention.size(0) + 1):
         for n in range(attention.size(-1)):
             xc, yc = n * token_gap, -d * layer_gap
             ctx.set_source_rgb(1.0, 1.0, 1.0)
-            ctx.arc(xc, yc, token_gap / 10 + 0.5, 0, 2 * math.pi)
+            ctx.arc(xc, yc, token_gap / 10, 0, 2 * math.pi)
             ctx.fill()
             ctx.set_source_rgb(0.0, 0.0, 0.0)
-            ctx.arc(xc, yc, token_gap / 10, 0, 2 * math.pi)
+            ctx.arc(xc, yc, token_gap / 20, 0, 2 * math.pi)
             ctx.fill()
 
     ctx.set_source_rgb(0.0, 0.0, 0.0)
             x_advance,
             y_advance,
         ) = ctx.text_extents(s)
-        ctx.move_to(k * token_gap - width_t / 2, -y_bearing)
+        ctx.move_to(k * token_gap - width_t / 2, token_gap / 5 - y_bearing)
         ctx.show_text(s)
 
     for k, t in enumerate(tokens_output):
             x_advance,
             y_advance,
         ) = ctx.text_extents(s)
-        ctx.move_to(k * token_gap - width_t / 2, -attention.size(0) * layer_gap)
+        ctx.move_to(
+            k * token_gap - width_t / 2, -token_gap / 5 - attention.size(0) * layer_gap
+        )
         ctx.show_text(s)
 
     x, y, width, height = surface.ink_extents()