From: Francois Fleuret Date: Tue, 6 Dec 2016 08:07:06 +0000 (+0100) Subject: So, back to decorating the classes and not the objects so that torch.save() does... X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=b8c7166b9123735e8226d34b717d3cbc2dc1fa02;p=profiler-torch.git So, back to decorating the classes and not the objects so that torch.save() does not complain with SpatialConvolution. Added the possibility to pass the total time to profiler.print() so that the fraction of time used by the different functions can be displayed. --- diff --git a/profiler.lua b/profiler.lua index 34e180b..4e45787 100644 --- a/profiler.lua +++ b/profiler.lua @@ -50,9 +50,18 @@ function profiler.decorate(model, functionsToDecorate) local nameOrig = name .. '__orig' - if model[name] and not model[nameOrig] then - model[nameOrig] = model[name] - model[name] = function(self, ...) + -- We decorate the class and not the object, otherwise we cannot + -- save models anymore. + + if rawget(model, name) then + error('We decorate the class, not the objects, and there is a ' .. name .. ' in ' .. model) + end + + local toDecorate = getmetatable(model) + + if toDecorate[name] and not toDecorate[nameOrig] then + toDecorate[nameOrig] = toDecorate[name] + toDecorate[name] = function(self, ...) local startTime = sys.clock() local result = { self[nameOrig](self, unpack({...})) } local endTime = sys.clock() @@ -71,24 +80,27 @@ function profiler.decorate(model, functionsToDecorate) end -function profiler.print(model, nbSamples, indent) +function profiler.print(model, nbSamples, totalTime, indent) local indent = indent or '' print(string.format('%s* %s', indent, model.__typename)) for l, t in pairs(model.accTime) do - local s + local s = string.format('%s %s %.02fs', indent, l, t) + if totalTime then + s = s .. string.format(' [%.02f%%]', 100 * t / totalTime) + end if nbSamples then - s = string.format(' (%.01fmus/sample)', 1e6 * t / nbSamples) - else - s = '' + s = s .. string.format(' (%.01fmus/sample)', 1e6 * t / nbSamples) end - print(string.format('%s %s %.02fs%s', indent, l, t, s)) + print(s) end + print() + if torch.isTypeOf(model, nn.Container) then for _, m in ipairs(model.modules) do - profiler.print(m, nbSamples, indent .. ' ') + profiler.print(m, nbSamples, totalTime, indent .. ' ') end end end diff --git a/test-profiler.lua b/test-profiler.lua index a78c944..18677ec 100755 --- a/test-profiler.lua +++ b/test-profiler.lua @@ -39,9 +39,14 @@ require 'profiler' -- Create a model +local w, h, fs = 50, 50, 3 +local nhu = (w - fs + 1) * (h - fs + 1) + local model = nn.Sequential() :add(nn.Sequential() - :add(nn.Linear(1000, 1000)) + :add(nn.SpatialConvolution(1, 1, fs, fs)) + :add(nn.Reshape(nhu)) + :add(nn.Linear(nhu, 1000)) :add(nn.ReLU()) ) :add(nn.Linear(1000, 100)) @@ -55,7 +60,7 @@ torch.save('model.t7', model) -- Create the data and criterion -local input = torch.Tensor(1000, 1000) +local input = torch.Tensor(1000, 1, h, w) local target = torch.Tensor(input:size(1), 100) local criterion = nn.MSECriterion() @@ -88,7 +93,7 @@ end -- Print the accumulated timings -profiler.print(model, nbSamples) +profiler.print(model, nbSamples, modelTime) -- profiler.print(model) print()