From: Francois Fleuret Date: Thu, 10 Sep 2020 16:57:09 +0000 (+0200) Subject: Niiice. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=5db6798d929c15e1517ec10c1a9211f870ec977e;p=pytorch.git Niiice. --- diff --git a/flatparam.py b/flatparam.py index fbede34..e0627b2 100755 --- a/flatparam.py +++ b/flatparam.py @@ -24,8 +24,8 @@ def flatparam(model): n = sum(p.numel() for p in model.parameters()) whole = next(model.parameters()).new(n) # Get same device and dtype whole.requires_grad_() - _flatparam(model, whole, [], 0) - return whole + _flatparam(model, whole) + model.parameters = lambda: iter([ whole ]) ###################################################################### @@ -44,7 +44,7 @@ print('Before:') for p in model.parameters(): print(p.size(), p.storage().size()) -whole = flatparam(model) +flatparam(model) print('After:') for p in model.parameters(): @@ -56,7 +56,7 @@ print('Check:') input = torch.rand(100, 2) targets = torch.rand(100, 2) -optimizer = torch.optim.SGD([ whole ], lr = 1e-2) +optimizer = torch.optim.SGD(model.parameters(), lr = 1e-2) mse = nn.MSELoss() for e in range(10):