From 682b76200f755f5f16477e086056a86cafdea1cd Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Wed, 11 Jan 2017 08:54:07 +0100 Subject: [PATCH] Update. The input/output can now be nested tables. --- dagnn.lua | 53 +++++++++++++++----------------------------------- test-dagnn.lua | 21 ++++++++++++++++---- 2 files changed, 33 insertions(+), 41 deletions(-) diff --git a/dagnn.lua b/dagnn.lua index 4841843..65a30e2 100755 --- a/dagnn.lua +++ b/dagnn.lua @@ -25,32 +25,26 @@ function DAG:addEdge(a, b) succ[a][#succ[a] + 1] = b end -function DAG:setInput(i) - self.sorted = nil - if torch.type(i) == 'table' then - self.inputModules = i - for _, m in ipairs(i) do - if not self.pred[m] and not self.succ[m] then - self:add(m) - end +function DAG:applyOnModules(f, t1, t2) + if torch.type(t1) == 'table' then + local result = {} + for k, s in pairs(t1) do + result[k] = self:applyOnModules(f, s, t2 and t2[k]) end + return result else - self:setInput({ i }) + return f(t1, t2) end end +function DAG:setInput(i) + self.sorted = nil + self.inputModules = i +end + function DAG:setOutput(o) self.sorted = nil - if torch.type(o) == 'table' then - self.outputModules = o - for _, m in ipairs(o) do - if not self.pred[m] and not self.succ[m] then - self:add(m) - end - end - else - self:setOutput({ o }) - end + self.outputModules = o end function DAG:sort() @@ -60,9 +54,7 @@ function DAG:sort() local distance = {} - for _, a in pairs(self.inputModules) do - distance[a] = 1 - end + self:applyOnModules(function(m) distance[m] = 1 end, self.inputModules) local nc @@ -98,13 +90,7 @@ end function DAG:updateOutput(input) self:sort() - if #self.inputModules == 1 then - self.inputModules[1]:updateOutput(input) - else - for i, d in ipairs(self.inputModules) do - d:updateOutput(input[i]) - end - end + self:applyOnModules(function(m, i) m:updateOutput(i) end, self.inputModules, input) for _, d in ipairs(self.sorted) do if self.pred[d] then @@ -120,14 +106,7 @@ function DAG:updateOutput(input) end end - if #self.outputModules == 1 then - self.output = self.outputModules[1].output - else - self.output = { } - for i, d in ipairs(self.outputModules) do - self.output[i] = d.output - end - end + self.output = self:applyOnModules(function(m) return m.output end, self.outputModules) return self.output end diff --git a/test-dagnn.lua b/test-dagnn.lua index 262ea6f..a45d636 100755 --- a/test-dagnn.lua +++ b/test-dagnn.lua @@ -5,6 +5,10 @@ require 'nn' require 'dagnn' +-- torch.setnumthreads(params.nbThreads) +torch.setdefaulttensortype('torch.DoubleTensor') +torch.manualSeed(2) + a = nn.Linear(10, 10) b = nn.ReLU() c = nn.Linear(10, 3) @@ -24,14 +28,16 @@ f = nn.Linear(3, 2) g = nn.DAG:new() g:setInput(a) -g:setOutput({ e, f }) +g:setOutput({ e }) g:addEdge(c, e) g:addEdge(a, b) g:addEdge(d, e) g:addEdge(b, c) g:addEdge(b, d) -g:addEdge(d, f) +-- g:addEdge(d, f) + +-- g = torch.load('dag.t7') g:print() @@ -39,5 +45,12 @@ input = torch.Tensor(3, 10):uniform() output = g:updateOutput(input) -print(output[1]) -print(output[2]) +if torch.type(output) == 'table' then + for i, t in pairs(output) do + print(tostring(i) .. ' -> ' .. tostring(t)) + end +else + print(tostring(output)) +end + +torch.save('dag.t7', g) -- 2.39.5