From: Francois Fleuret Date: Wed, 11 Jan 2017 07:02:04 +0000 (+0100) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=be03a73e411d18082a2dd99bff5df45c085017ca;p=dagnn.git Update. --- diff --git a/dagnn.lua b/dagnn.lua index 1ec9b4e..52913ad 100755 --- a/dagnn.lua +++ b/dagnn.lua @@ -1,21 +1,16 @@ -#!/usr/bin/env luajit require 'torch' require 'nn' -require 'image' -require 'optim' ----------------------------------------------------------------------- +local DAG, parent = torch.class('nn.DAG', 'nn.Container') -local Graph, parent = torch.class('nn.Graph', 'nn.Container') - -function Graph:__init() +function DAG:__init() parent.__init(self) self.pred = {} self.succ = {} end -function Graph:addEdge(a, b) +function DAG:addEdge(a, b) local pred, succ = self.pred, self.succ if not pred[a] and not succ[a] then self:add(a) @@ -29,7 +24,7 @@ function Graph:addEdge(a, b) succ[a][#succ[a] + 1] = b end -function Graph:setInput(i) +function DAG:setInput(i) if torch.type(i) == 'table' then self.inputModules = i for _, m in ipairs(i) do @@ -42,7 +37,7 @@ function Graph:setInput(i) end end -function Graph:setOutput(o) +function DAG:setOutput(o) if torch.type(o) == 'table' then self.outputModules = o for _, m in ipairs(o) do @@ -55,7 +50,7 @@ function Graph:setOutput(o) end end -function Graph:order() +function DAG:order() local distance = {} for _, a in pairs(self.inputModules) do @@ -85,13 +80,13 @@ function Graph:order() for i, a in ipairs(self.sorted) do self.sorted[i] = a[2] end end -function Graph:print() +function DAG:print() for i, d in ipairs(self.sorted) do print('#' .. i .. ' -> ' .. torch.type(d)) end end -function Graph:updateOutput(input) +function DAG:updateOutput(input) if #self.inputModules == 1 then self.inputModules[1]:updateOutput(input) else @@ -125,43 +120,3 @@ function Graph:updateOutput(input) return self.output end - ----------------------------------------------------------------------- - -a = nn.Linear(10, 10) -b = nn.ReLU() -c = nn.Linear(10, 3) -d = nn.Linear(10, 3) -e = nn.CMulTable() -f = nn.Linear(3, 2) - ---[[ - - a -----> b ---> c ----> e --- - \ / - \--> d ---/ - \ - \---> f --- -]]-- - -g = Graph:new() - -g:setInput(a) -g:setOutput({ e, f }) -g:addEdge(c, e) -g:addEdge(a, b) -g:addEdge(d, e) -g:addEdge(b, c) -g:addEdge(b, d) -g:addEdge(d, f) - -g:order() - -g:print(graph) - -input = torch.Tensor(3, 10):uniform() - -output = g:updateOutput(input) - -print(output[1]) -print(output[2]) diff --git a/test-dagnn.lua b/test-dagnn.lua new file mode 100755 index 0000000..a0a81ab --- /dev/null +++ b/test-dagnn.lua @@ -0,0 +1,44 @@ +#!/usr/bin/env luajit + +require 'torch' +require 'nn' + +require 'dagnn' + +a = nn.Linear(10, 10) +b = nn.ReLU() +c = nn.Linear(10, 3) +d = nn.Linear(10, 3) +e = nn.CMulTable() +f = nn.Linear(3, 2) + +--[[ + + a -----> b ---> c ----> e --- + \ / + \--> d ---/ + \ + \---> f --- +]]-- + +g = DAG:new() + +g:setInput(a) +g:setOutput({ e, f }) +g:addEdge(c, e) +g:addEdge(a, b) +g:addEdge(d, e) +g:addEdge(b, c) +g:addEdge(b, d) +g:addEdge(d, f) + +g:order() + +g:print(graph) + +input = torch.Tensor(3, 10):uniform() + +output = g:updateOutput(input) + +print(output[1]) +print(output[2])