From: Francois Fleuret Date: Sat, 14 Jan 2017 16:04:06 +0000 (+0100) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=c1de441737656e7268f65116218cec6cdb304757;p=dagnn.git Update. --- diff --git a/dagnn.lua b/dagnn.lua index 0073e39..5921c05 100755 --- a/dagnn.lua +++ b/dagnn.lua @@ -86,14 +86,16 @@ function DAG:putInOrder() for i, a in ipairs(self.sorted) do self.sorted[i] = a.nnm end end --- This accumulate x in a where they are both nested tables of --- tensors. If first is true, set a = x. +-- This accumulates x in a where they are both nested tables of +-- tensors. If first is true, set a = x. Behavior is undefined if a +-- and x do not have the exact same structure. function DAG:nestedAccTensor(a, x, first) if torch.type(x) == 'table' then - a = a or {} + local b = {} for i in pairs(x) do - a[i] = self:nestedAccTensor(a[i], x[i], first) + b[i] = self:nestedAccTensor(a[i], x[i], first) end + a = b else if first then if a then @@ -222,8 +224,9 @@ function DAG:updateOutput(input) self:nestedApply( function(nnm, i) - self.node[nnm].input = i - self:rethrowErrors(nnm, self.node[nnm].index, 'updateOutput', i) + local node = self.node[nnm] + node.input = i + self:rethrowErrors(nnm, node.index, 'updateOutput', i) end, self.inputModules, input @@ -242,7 +245,7 @@ function DAG:updateOutput(input) end end node.input = i - self:rethrowErrors(nnm, self.node[nnm].index, 'updateOutput', i) + self:rethrowErrors(nnm, node.index, 'updateOutput', i) end end @@ -261,7 +264,7 @@ function DAG:updateGradInput(input, gradOutput) function(nnm, go) local node = self.node[nnm] node.gradOutput = go - self:rethrowErrors(nnm, node.index, 'updateGradInput', self.node[nnm].input, go) + self:rethrowErrors(nnm, node.index, 'updateGradInput', node.input, go) end, self.outputModules, gradOutput ) @@ -282,7 +285,7 @@ function DAG:updateGradInput(input, gradOutput) if #node.gradInputSucc > 0 then self:updateGradOutput(node) - self:rethrowErrors(nnm, self.node[nnm].index, 'updateGradInput', node.input, node.gradOutput) + self:rethrowErrors(nnm, node.index, 'updateGradInput', node.input, node.gradOutput) end -- We fill the gradInputSucc of our predecessors @@ -304,8 +307,6 @@ function DAG:updateGradInput(input, gradOutput) end function DAG:accGradParameters(input, gradOutput, scale) - scale = scale or 1 - assert(self.sorted, 'There has been a DAG structure change before a DAG:accGradParameters') self:nestedApply(