From a2346746c9b417eaf97aad87ed31dea92c3bb887 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Tue, 25 Jun 2024 20:38:31 +0200 Subject: [PATCH] Update. --- sky.py | 83 ++++++++++++++++++++++++++++++++-------------------------- 1 file changed, 46 insertions(+), 37 deletions(-) diff --git a/sky.py b/sky.py index 1e6ed4d..3584beb 100755 --- a/sky.py +++ b/sky.py @@ -258,31 +258,34 @@ class Sky(problem.Problem): return torch.cat(result, dim=0) - def frame2img(self, x, upscale=15): + def frame2img(self, x, scale=15): x = x.reshape(-1, self.height, self.width) m = torch.logical_and( x >= 0, x < self.first_bird_token + self.nb_bird_tokens ).long() x = self.colors[x * m].permute(0, 3, 1, 2) s = x.shape - x = x[:, :, :, None, :, None].expand(-1, -1, -1, upscale, -1, upscale) - x = x.reshape(s[0], s[1], s[2] * upscale, s[3] * upscale) + x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale) + x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale) - x[:, :, :, torch.arange(0, x.size(3), upscale)] = 0 - x[:, :, torch.arange(0, x.size(2), upscale), :] = 0 + x[:, :, :, torch.arange(0, x.size(3), scale)] = 0 + x[:, :, torch.arange(0, x.size(2), scale), :] = 0 x = x[:, :, 1:, 1:] for n in range(m.size(0)): for i in range(m.size(1)): for j in range(m.size(2)): if m[n, i, j] == 0: - for k in range(2, upscale - 2): - x[n, :, i * upscale + k, j * upscale + k] = 0 - x[n, :, i * upscale + upscale - 1 - k, j * upscale + k] = 0 + for k in range(2, scale - 2): + for l in [0, 1]: + x[n, :, i * scale + k, j * scale + k - l] = 0 + x[ + n, :, i * scale + scale - 1 - k, j * scale + k - l + ] = 0 return x - def seq2img(self, seq, upscale=15): + def seq2img(self, seq, scale=15): f_first = seq[:, : self.height * self.width].reshape( -1, self.height, self.width ) @@ -292,47 +295,53 @@ class Sky(problem.Problem): direction = seq[:, self.height * self.width] direction_symbol = torch.full( - (direction.size(0), self.height * upscale - 1, upscale), 0 + (direction.size(0), self.height * scale - 1, scale), 0 ) direction_symbol = self.colors[direction_symbol].permute(0, 3, 1, 2) - separator = torch.full((direction.size(0), 3, self.height * upscale - 1, 1), 0) + separator = torch.full((direction.size(0), 3, self.height * scale - 1, 1), 0) for n in range(direction_symbol.size(0)): if direction[n] == self.token_forward: - for k in range(upscale): - direction_symbol[ - n, - :, - (self.height * upscale) // 2 - upscale // 2 + k, - 3 + upscale // 2 - abs(k - upscale // 2), - ] = 0 + for k in range(scale): + for l in [0, 1]: + direction_symbol[ + n, + :, + (self.height * scale) // 2 - scale // 2 + k - l, + 3 + scale // 2 - abs(k - scale // 2), + ] = 0 elif direction[n] == self.token_backward: - for k in range(upscale): - direction_symbol[ - n, - :, - (self.height * upscale) // 2 - upscale // 2 + k, - 3 + abs(k - upscale // 2), - ] = 0 + for k in range(scale): + for l in [0, 1]: + direction_symbol[ + n, + :, + (self.height * scale) // 2 - scale // 2 + k - l, + 3 + abs(k - scale // 2), + ] = 0 else: - for k in range(2, upscale - 2): - direction_symbol[ - n, :, (self.height * upscale) // 2 - upscale // 2 + k, k - ] = 0 - direction_symbol[ - n, - :, - (self.height * upscale) // 2 - upscale // 2 + k, - upscale - 1 - k, - ] = 0 + for k in range(2, scale - 2): + for l in [0, 1]: + direction_symbol[ + n, + :, + (self.height * scale) // 2 - scale // 2 + k - l, + k, + ] = 0 + direction_symbol[ + n, + :, + (self.height * scale) // 2 - scale // 2 + k - l, + scale - 1 - k, + ] = 0 return torch.cat( [ - self.frame2img(f_first, upscale), + self.frame2img(f_first, scale), separator, direction_symbol, separator, - self.frame2img(f_second, upscale), + self.frame2img(f_second, scale), ], dim=3, ) -- 2.39.5