projects
/
pytorch.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
150a02f
)
Update.
master
author
François Fleuret
<francois@fleuret.org>
Thu, 13 Jun 2024 18:05:29 +0000
(20:05 +0200)
committer
François Fleuret
<francois@fleuret.org>
Thu, 13 Jun 2024 18:05:29 +0000
(20:05 +0200)
redshift.py
patch
|
blob
|
history
diff --git
a/redshift.py
b/redshift.py
index
b3507ed
..
2ed1e52
100755
(executable)
--- a/
redshift.py
+++ b/
redshift.py
@@
-9,8
+9,10
@@
from torch.nn import functional as F
torch.set_default_dtype(torch.float64)
+nb_hidden = 5
+hidden_dim = 100
+
res = 256
-nh = 100
input = torch.cat(
[
@@
-28,11
+30,10
@@
class Angles(nn.Module):
for activation in [nn.ReLU, nn.Tanh, nn.Softplus, Angles]:
    for s in [1.0, 10.0]:
- layers = [nn.Linear(2, nh), activation()]
- nb_hidden = 4
- for k in range(nb_hidden):
- layers += [nn.Linear(nh, nh), activation()]
- layers += [nn.Linear(nh, 2)]
+ layers = [nn.Linear(2, hidden_dim), activation()]
+ for k in range(nb_hidden - 1):
+ layers += [nn.Linear(hidden_dim, hidden_dim), activation()]
+ layers += [nn.Linear(hidden_dim, 2)]
model = nn.Sequential(*layers)
with torch.no_grad():