From: Francois Fleuret Date: Thu, 12 Jan 2017 14:46:51 +0000 (+0100) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=31dc42fc93ed12491ceb10ef3bfc4296878380ee;p=dagnn.git Update. --- diff --git a/dagnn.lua b/dagnn.lua index 05672e9..8a02cc6 100755 --- a/dagnn.lua +++ b/dagnn.lua @@ -152,6 +152,22 @@ function DAG:updateOutput(input) return self.output end +function DAG:computeGradInput(gradInputSucc) + local gi + if #gradInputSucc == 1 then + gi = gradInputSucc[1] -- we avoid a clone() + elseif #gradInputSucc > 1 then + for k = 1, #gradInputSucc do + if gi then + gi:add(gradInputSucc[k]) + else + gi = gradInputSucc[k]:clone() + end + end + end + return gi +end + function DAG:updateGradInput(input, gradOutput) self:putInOrder() @@ -160,6 +176,11 @@ function DAG:updateGradInput(input, gradOutput) self.outputModules, gradOutput ) + self:nestApply( + function(nnm, i) self.node[nnm].input = i end, + self.inputModules, input + ) + for _, node in pairs(self.node) do node.gradInputSucc = {} end @@ -167,23 +188,10 @@ function DAG:updateGradInput(input, gradOutput) for k = #self.sorted, 1, -1 do local nnm = self.sorted[k] local node = self.node[nnm] - local pred, succ, gradInputSucc = node.pred, node.succ, node.gradInputSucc + local pred, gradInputSucc = node.pred, node.gradInputSucc if #gradInputSucc > 0 then - -- We update nnm:gradInput - local gi - if #gradInputSucc == 1 then - gi = gradInputSucc[1] -- we avoid a clone() - elseif #gradInputSucc > 1 then - for k = 1, #gradInputSucc do - if gi then - gi:add(gradInputSucc[k]) - else - gi = gradInputSucc[k]:clone() - end - end - end - nnm:updateGradInput(node.input, gi) + nnm:updateGradInput(node.input, self:computeGradInput(gradInputSucc)) end -- We fill the gradInputSucc of our predecessors diff --git a/test-dagnn.lua b/test-dagnn.lua index 32eed57..d7179cc 100755 --- a/test-dagnn.lua +++ b/test-dagnn.lua @@ -58,6 +58,6 @@ printTensorTable(output) print('******************************************************************') print('** updateGradInput ***********************************************') print('******************************************************************') -gradInput = g:updateGradInput({ input }, output) +gradInput = g:updateGradInput({{input}}, output) printTensorTable(gradInput)