class QuizzMachine:
- def save_image(self, input, result_dir, filename, logger):
- img = self.sky.seq2img(input.to("cpu"))
- image_name = os.path.join(result_dir, filename)
- torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)
- logger(f"wrote {image_name}")
-
- def save_quizzes(self, input, result_dir, filename_prefix, logger):
- self.save_image(input, result_dir, filename_prefix + ".png", logger)
-
def make_ar_mask(self, input):
b = torch.arange(input.size(1), device=input.device) > input.size(1) // 2
return b.long()[None, :].expand_as(input)
):
super().__init__()
- self.sky = sky.Sky(height=6, width=8, nb_birds=3, nb_iterations=2)
+ self.problem = sky.Sky(height=6, width=8, nb_birds=3, nb_iterations=2)
self.batch_size = batch_size
self.device = device
- self.train_w_quizzes = self.sky.generate_seq(nb_train_samples).to(device)
- self.test_w_quizzes = self.sky.generate_seq(nb_test_samples).to(device)
+ self.train_w_quizzes = self.problem.generate_seq(nb_train_samples).to(device)
+ self.test_w_quizzes = self.problem.generate_seq(nb_test_samples).to(device)
self.nb_codes = max(self.train_w_quizzes.max(), self.test_w_quizzes.max()) + 1
self.test_c_quizzes = []
if result_dir is not None:
- self.save_quizzes(
+ self.problem.save_quizzes(
self.train_w_quizzes[:72], result_dir, f"culture_w_quizzes", logger
)
device=self.device,
)
- self.save_quizzes(
+ self.problem.save_quizzes(
result[:72],
result_dir,
f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
input = self.train_w_quizzes if for_train else self.test_w_quizzes
nb = min(nb, input.size(0))
input[:-nb] = input[nb:].clone()
- input[-nb:] = self.sky.generate_seq(nb).to(self.device)
+ input[-nb:] = self.problem.generate_seq(nb).to(self.device)
def store_c_quizzes(self, new_c_quizzes, for_train=True):
if for_train:
###############################################################
# Create the reverse quizzes
+ token_forward, token_backward = self.problem.direction_tokens()
+
l = (c_quizzes.size(1) - 1) // 2
direction = c_quizzes[:, l : l + 1]
- direction = self.sky.token_forward * (
- direction == self.sky.token_backward
- ) + self.sky.token_backward * (direction == self.sky.token_forward)
+ direction = self.problem.token_forward * (
+ direction == self.problem.token_backward
+ ) + self.problem.token_backward * (direction == self.problem.token_forward)
reverse_c_quizzes = torch.cat(
[c_quizzes[:, l + 1 :], direction, c_quizzes[:, :l]], dim=1
)
# Written by Francois Fleuret <francois@fleuret.org>
-import math, sys, tqdm
+import math, sys, tqdm, os
import torch, torchvision
######################################################################
+class Problem:
+ def generate_seq(self, nb_train_samples):
+ pass
+
+ def save_quizzes(self, input, result_dir, filename_prefix, logger):
+ pass
+
+ def direction_tokens(self):
+ pass
+
+
class Sky:
colors = torch.tensor(
[
self.nb_birds = nb_birds
self.nb_iterations = nb_iterations
+ def direction_tokens(self):
+ return self.token_forward, self.token_backward
+
def generate_seq(self, nb, return_iterations=False):
pairs = []
kept_iterations = []
result.append("".join([self.token2char[v] for v in s]))
return result
+ def save_image(self, input, result_dir, filename, logger):
+ img = self.seq2img(input.to("cpu"))
+ image_name = os.path.join(result_dir, filename)
+ torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)
+ logger(f"wrote {image_name}")
+
+ def save_quizzes(self, input, result_dir, filename_prefix, logger):
+ self.save_image(input, result_dir, filename_prefix + ".png", logger)
+
######################################################################