Update.
authorFrançois Fleuret <francois@fleuret.org>
Sun, 4 Aug 2024 08:55:09 +0000 (10:55 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Sun, 4 Aug 2024 08:55:09 +0000 (10:55 +0200)
main.py

diff --git a/main.py b/main.py
index 4dca41f..7361ae8 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -364,14 +364,13 @@ def sigma_for_grids(input):
     l = input.size(1) // 4 - 1
     sigma = input.new(input.size())
     r = sigma.view(sigma.size(0), 4, sigma.size(1) // 4)
-    r[:, :, 0] = 0
-    r[:, :, 1:] = (
+    r[:, 0] = 0 * l
+    r[:, 1] = 1 * l
+    r[:, 2] = 2 * l
+    r[:, 3] = 3 * l
+    r[:, :, 1:] += (
         torch.rand(input.size(0), 4, l, device=input.device).sort(dim=2).indices
     ) + 1
-    r[:, 0] += 0 * l
-    r[:, 1] += 1 * l
-    r[:, 2] += 2 * l
-    r[:, 3] += 3 * l
     return sigma