From: Francois Fleuret Date: Wed, 11 Jan 2017 08:30:33 +0000 (+0100) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=da3a60ffa7e1a39e4d01b405c2d80d84c3722c2c;p=dagnn.git Update. --- diff --git a/dagnn.lua b/dagnn.lua index 65a30e2..a6414b3 100755 --- a/dagnn.lua +++ b/dagnn.lua @@ -25,26 +25,46 @@ function DAG:addEdge(a, b) succ[a][#succ[a] + 1] = b end -function DAG:applyOnModules(f, t1, t2) - if torch.type(t1) == 'table' then +-- Apply f on t recursively; use the corresponding a1 and a2 elements +-- (i.e. same keys) as second and third parameters to f when +-- available; return the results from f, organized in a similarly +-- nested table. +function DAG:applyOnModules(f, t, a1, a2) + if torch.type(t) == 'table' then local result = {} - for k, s in pairs(t1) do - result[k] = self:applyOnModules(f, s, t2 and t2[k]) + for k, s in pairs(t) do + result[k] = self:applyOnModules(f, s, a1 and a1[k], a2 and a2[k]) end return result else - return f(t1, t2) + return f(t, a1, a2) end end function DAG:setInput(i) self.sorted = nil self.inputModules = i + self:applyOnModules( + function(m) + if (not self.succ[m] or #self.succ[m] == 0) or (self.pred[m] and #self.pred[m] > 0) then + error('Invalid input edges.') + end + end, + self.inputModules + ) end function DAG:setOutput(o) self.sorted = nil self.outputModules = o + self:applyOnModules( + function(m) + if (not self.pred[m] or #self.pred[m] == 0) or (self.succ[m] and #self.succ[m] > 0) then + error('Invalid output edges.') + end + end, + self.outputModules + ) end function DAG:sort() @@ -113,6 +133,31 @@ end function DAG:updateGradInput(input, gradOutput) self:sort() + + self:applyOnModules(function(m, i, go) m:updateGradInput(i, go) end, self.outputModules, input, gradOutput) + + for k = self.sorted, 1, -1 do + local m = sorted[k] + if self.succ[d] then + if #self.succ[d] == 1 then + d:updateGradInput(self.succ[d][1].gradInput) + elseif #self.succ[d] > 1 then + local sum + for k = 1, #self.succ[d] do + if sum then + sum:add(self.succ[d][k].gradInput) + else + sum = self.succ[d][k].gradInput:clone() + end + end + d:updateGradInput(sum) + end + end + end + + self.gradInput = self:applyOnModules(function(m) return m.gradInput end, self.inputModules) + + return self.gradInput end return DAG diff --git a/test-dagnn.lua b/test-dagnn.lua index a45d636..6c09f95 100755 --- a/test-dagnn.lua +++ b/test-dagnn.lua @@ -5,6 +5,17 @@ require 'nn' require 'dagnn' +function printTensorTable(t) + if torch.type(t) == 'table' then + for i, t in pairs(t) do + print('-- ELEMENT [' .. i .. '] --') + printTensorTable(t) + end + else + print(tostring(t)) + end +end + -- torch.setnumthreads(params.nbThreads) torch.setdefaulttensortype('torch.DoubleTensor') torch.manualSeed(2) @@ -16,41 +27,32 @@ d = nn.Linear(10, 3) e = nn.CMulTable() f = nn.Linear(3, 2) ---[[ - - a -----> b ---> c ----> e --- - \ / - \--> d ---/ - \ - \---> f --- -]]-- +-- a -----> b ---> c ----> e --- +-- \ / +-- \--> d ---/ +-- \ +-- \---> f --- -g = nn.DAG:new() - -g:setInput(a) -g:setOutput({ e }) +g = nn.DAG() g:addEdge(c, e) g:addEdge(a, b) g:addEdge(d, e) g:addEdge(b, c) g:addEdge(b, d) --- g:addEdge(d, f) +g:addEdge(d, f) --- g = torch.load('dag.t7') +g:setInput({a}) +g:setOutput({e,f}) g:print() input = torch.Tensor(3, 10):uniform() -output = g:updateOutput(input) +output = g:updateOutput({input}) -if torch.type(output) == 'table' then - for i, t in pairs(output) do - print(tostring(i) .. ' -> ' .. tostring(t)) - end -else - print(tostring(output)) -end +printTensorTable(output) + +---------------------------------------------------------------------- -torch.save('dag.t7', g) +-- gradInput = g:updateGradInput({ input }, output)