From: Francois Fleuret Date: Wed, 8 Mar 2017 10:39:09 +0000 (+0100) Subject: Initial commit. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=63f04303f0320d25d36e6a4f9f535e62cdb139e1;p=agtree2dot.git Initial commit. --- 63f04303f0320d25d36e6a4f9f535e62cdb139e1 diff --git a/README.md b/README.md new file mode 100644 index 0000000..4c219b7 --- /dev/null +++ b/README.md @@ -0,0 +1,59 @@ +# Introduction # + +This package provides a function that generates a dot file from the +auto-grad graph. + +# Usage # + +## Functions ## + +### agtree2dot.save_dot(variable, variable_labels, result_file) ### + +Saves into `result_file` a dot file corresponding to the auto-grad graph for `variable`, which can be either a single `Variable` or a set of `Variable`s. The dictionary `variable_labels` associates strings to some variables, which will be used in the resulting graph. + +## Example ## + +A typical use would be: + +```python +import torch + +from torch import nn +from torch.nn import functional as fn +from torch import Tensor +from torch.autograd import Variable +from torch.nn import Module + +import agtree2dot + +class MLP(Module): + def __init__(self, input_dim, hidden_dim, output_dim): + super(MLP, self).__init__() + self.fc1 = nn.Linear(input_dim, hidden_dim) + self.fc2 = nn.Linear(hidden_dim, output_dim) + + def forward(self, x): + x = self.fc1(x) + x = fn.tanh(x) + x = self.fc2(x) + return x + +mlp = MLP(10, 20, 1) +input = Variable(Tensor(100, 10).normal_()) +target = Variable(Tensor(100).normal_()) +output = mlp(input) +criterion = nn.MSELoss() +loss = criterion(output, target) + +agtree2dot.save_dot(loss, + { input: 'input', loss: 'loss' }, + open('./mlp.dot', 'w')) +``` + +which would generate a file mlp.dot, which can then be translated to pdf with + +``` +dot mlp.dot -Lg -T pdf -o mlp.pdf +``` + +to produce [mlp.pdf](https://fleuret.org/git-extract/agtree2dot/mlp.pdf). diff --git a/agtree2dot.py b/agtree2dot.py new file mode 100755 index 0000000..f215f94 --- /dev/null +++ b/agtree2dot.py @@ -0,0 +1,108 @@ + +######################################################################### +# This program is free software: you can redistribute it and/or modify # +# it under the terms of the version 3 of the GNU General Public License # +# as published by the Free Software Foundation. # +# # +# This program is distributed in the hope that it will be useful, but # +# WITHOUT ANY WARRANTY; without even the implied warranty of # +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU # +# General Public License for more details. # +# # +# You should have received a copy of the GNU General Public License # +# along with this program. If not, see . # +# # +# Written by and Copyright (C) Francois Fleuret # +# Contact for comments & bug reports # +######################################################################### + +import torch +import re +import sys + +import torch.autograd + +###################################################################### + +def save_dot_rec(x, node_labels = {}, out = sys.stdout, drawn_node_id = {}): + + if isinstance(x, set): + + for y in x: + save_dot_rec(y, node_labels, out, drawn_node_id) + + else: + + if not x in drawn_node_id: + drawn_node_id[x] = len(drawn_node_id) + 1 + + # Draw the node (Variable or Function) if not already + # drawn + + if isinstance(x, torch.autograd.Variable): + name = ((x in node_labels and node_labels[x]) or 'Variable') + # Add the tensor size + name = name + ' [' + for d in range(0, x.data.dim()): + if d > 0: name = name + ', ' + name = name + str(x.data.size(d)) + name = name + ']' + + out.write(' ' + str(drawn_node_id[x]) + + ' [shape=record,penwidth=1,style=rounded,label="' + name + '"]\n') + + if hasattr(x, 'creator') and x.creator: + y = x.creator + save_dot_rec(y, node_labels, out, drawn_node_id) + # Edge to the creator + out.write(' ' + str(drawn_node_id[y]) + ' -> ' + str(drawn_node_id[x]) + '\n') + + elif isinstance(x, torch.autograd.Function): + name = ((x in node_labels and (node_labels[x] + ': ')) or '') + \ + re.search('<.*\.([a-zA-Z0-9_]*)\'>', str(type(x))).group(1) + + prefix = '' + suffix = '' + + if hasattr(x, 'num_inputs') and x.num_inputs > 1: + prefix = '{ ' + for i in range(0, x.num_inputs): + if i > 0: prefix = prefix + ' | ' + prefix = prefix + ' ' + str(i) + prefix = prefix + ' } | ' + + if hasattr(x, 'num_outputs') and x.num_outputs > 1: + suffix = ' | { ' + for i in range(0, x.num_outputs): + if i > 0: suffix = suffix + ' | ' + suffix = suffix + ' ' + str(i) + suffix = suffix + ' }' + + out.write(' ' + str(drawn_node_id[x]) + \ + ' [shape=record,label="{ ' + prefix + name + suffix + ' }"]\n') + + else: + + print('Cannot handle ' + str(type(x)) + ' (only Variables and Functions).') + exit(1) + + if hasattr(x, 'num_inputs'): + for i in range(0, x.num_inputs): + y = x.previous_functions[i][0] + save_dot_rec(y, node_labels, out, drawn_node_id) + from_str = str(drawn_node_id[y]) + if hasattr(y, 'num_outputs') and y.num_outputs > 1: + from_str = from_str + ':output' + str(x.previous_functions[i][1]) + to_str = str(drawn_node_id[x]) + if x.num_inputs > 1: + to_str = to_str + ':input' + str(i) + out.write(' ' + from_str + ' -> ' + to_str + '\n') + +###################################################################### + +def save_dot(x, node_labels = {}, out = sys.stdout): + out.write('digraph {\n') + save_dot_rec(x, node_labels, out, {}) + out.write('}\n') + +###################################################################### diff --git a/mlp.py b/mlp.py new file mode 100755 index 0000000..8497848 --- /dev/null +++ b/mlp.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python + +######################################################################### +# This program is free software: you can redistribute it and/or modify # +# it under the terms of the version 3 of the GNU General Public License # +# as published by the Free Software Foundation. # +# # +# This program is distributed in the hope that it will be useful, but # +# WITHOUT ANY WARRANTY; without even the implied warranty of # +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU # +# General Public License for more details. # +# # +# You should have received a copy of the GNU General Public License # +# along with this program. If not, see . # +# # +# Written by and Copyright (C) Francois Fleuret # +# Contact for comments & bug reports # +######################################################################### + +from torch import nn +from torch.nn import functional as fn +from torch import Tensor +from torch.autograd import Variable +from torch.nn import Module + +import agtree2dot + +class MLP(Module): + def __init__(self, input_dim, hidden_dim, output_dim): + super(MLP, self).__init__() + self.fc1 = nn.Linear(input_dim, hidden_dim) + self.fc2 = nn.Linear(hidden_dim, output_dim) + + def forward(self, x): + x = self.fc1(x) + x = fn.tanh(x) + x = self.fc2(x) + return x + +mlp = MLP(10, 20, 1) +input = Variable(Tensor(100, 10).normal_()) +target = Variable(Tensor(100).normal_()) +output = mlp(input) +criterion = nn.MSELoss() +loss = criterion(output, target) + +agtree2dot.save_dot(loss, + { input: 'input', target: 'target', loss: 'loss' }, + open('./mlp.dot', 'w')) + +print('Generated mlp.dot. You can convert it to pdf with') +print('> dot mlp.dot -Lg -T pdf -o mlp.pdf')