From 48e4bc80f388ea2932757fdca0688de7bf52ee06 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sun, 4 Aug 2024 18:47:44 +0200 Subject: [PATCH] Update. --- quiz_machine.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/quiz_machine.py b/quiz_machine.py index 386969a..332cd86 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -94,14 +94,14 @@ class QuizMachine: ###################################################################### - def sigma_for_grids(self, input): + def sigma_for_grids(self, input, block_order=(0, 1, 2, 3)): l = input.size(1) // 4 sigma = input.new(input.size()) r = sigma.view(sigma.size(0), 4, l) - r[:, 0] = 0 * l - r[:, 1] = 1 * l - r[:, 2] = 2 * l - r[:, 3] = 3 * l + r[:, 0, :] = block_order[0] * l + r[:, 1, :] = block_order[1] * l + r[:, 2, :] = block_order[2] * l + r[:, 3, :] = block_order[3] * l r[:, :, 1:] += ( torch.rand(input.size(0), 4, l - 1, device=input.device).sort(dim=2).indices ) + 1 -- 2.39.5