Update.
authorFrançois Fleuret <francois@fleuret.org>
Tue, 25 Jun 2024 11:53:31 +0000 (13:53 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 25 Jun 2024 11:53:31 +0000 (13:53 +0200)
quizz_machine.py [moved from tasks.py with 96% similarity]
sku.py [moved from world.py with 100% similarity]

similarity index 96%
rename from tasks.py
rename to quizz_machine.py
index 50ded2c..d8ebad8 100755 (executable)
--- a/tasks.py
@@ -82,12 +82,12 @@ class Task:
 
 ######################################################################
 
-import world
+import sky
 
 
 class QuizzMachine(Task):
     def save_image(self, input, result_dir, filename, logger):
-        img = world.seq2img(input.to("cpu"), self.height, self.width)
+        img = sky.seq2img(input.to("cpu"), self.height, self.width)
         image_name = os.path.join(result_dir, filename)
         torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)
         logger(f"wrote {image_name}")
@@ -115,11 +115,11 @@ class QuizzMachine(Task):
         self.height = 6
         self.width = 8
 
-        self.train_w_quizzes = world.generate_seq(
+        self.train_w_quizzes = sky.generate_seq(
             nb_train_samples, height=self.height, width=self.width
         ).to(device)
 
-        self.test_w_quizzes = world.generate_seq(
+        self.test_w_quizzes = sky.generate_seq(
             nb_test_samples, height=self.height, width=self.width
         ).to(device)
 
@@ -250,7 +250,7 @@ class QuizzMachine(Task):
         input = self.train_w_quizzes if for_train else self.test_w_quizzes
         nb = min(nb, input.size(0))
         input[:-nb] = input[nb:].clone()
-        input[-nb:] = world.generate_seq(nb, height=self.height, width=self.width).to(
+        input[-nb:] = sky.generate_seq(nb, height=self.height, width=self.width).to(
             self.device
         )
 
@@ -324,9 +324,9 @@ class QuizzMachine(Task):
 
         l = self.height * self.width
         direction = c_quizzes[:, l : l + 1]
-        direction = world.token_forward * (
-            direction == world.token_backward
-        ) + world.token_backward * (direction == world.token_forward)
+        direction = sky.token_forward * (
+            direction == sky.token_backward
+        ) + sky.token_backward * (direction == sky.token_forward)
         reverse_c_quizzes = torch.cat(
             [c_quizzes[:, l + 1 :], direction, c_quizzes[:, :l]], dim=1
         )
diff --git a/world.py b/sku.py
similarity index 100%
rename from world.py
rename to sku.py