Update.
authorFrançois Fleuret <francois@fleuret.org>
Tue, 25 Jun 2024 17:41:08 +0000 (19:41 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 25 Jun 2024 17:41:08 +0000 (19:41 +0200)
main.py
problem.py
quizz_machine.py
sky.py

diff --git a/main.py b/main.py
index 524715a..402e6e5 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -396,7 +396,6 @@ def create_c_quizzes(
         new_c_quizzes[:72],
         args.result_dir,
         f"culture_c_quiz_{n_epoch:04d}_{model.id:02d}",
-        log_string,
     )
 
     return sum_logits / sum_nb_c_quizzes
index 8d973eb..95a9c41 100755 (executable)
@@ -13,7 +13,7 @@ class Problem:
         pass
 
     # save a file to vizualize quizzes, you can save a txt or png file
-    def save_quizzes(self, input, result_dir, filename_prefix, logger):
+    def save_quizzes(self, input, result_dir, filename_prefix):
         pass
 
     # returns a pair (forward_tokens, backward_token)
index d63855c..be34847 100755 (executable)
@@ -98,7 +98,7 @@ class QuizzMachine:
 
         if result_dir is not None:
             self.problem.save_quizzes(
-                self.train_w_quizzes[:72], result_dir, f"culture_w_quizzes", logger
+                self.train_w_quizzes[:72], result_dir, f"culture_w_quizzes"
             )
 
     def batches(self, split="train", desc=None):
@@ -206,10 +206,7 @@ class QuizzMachine:
         )
 
         self.problem.save_quizzes(
-            result[:72],
-            result_dir,
-            f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
-            logger,
+            result[:72], result_dir, f"culture_prediction_{n_epoch:04d}_{model.id:02d}"
         )
 
         return main_test_accuracy
diff --git a/sky.py b/sky.py
index ec476a6..1e6ed4d 100755 (executable)
--- a/sky.py
+++ b/sky.py
@@ -343,14 +343,13 @@ class Sky(problem.Problem):
             result.append("".join([self.token2char[v] for v in s]))
         return result
 
-    def save_image(self, input, result_dir, filename, logger):
+    def save_image(self, input, result_dir, filename):
         img = self.seq2img(input.to("cpu"))
         image_name = os.path.join(result_dir, filename)
         torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)
-        logger(f"wrote {image_name}")
 
-    def save_quizzes(self, input, result_dir, filename_prefix, logger):
-        self.save_image(input, result_dir, filename_prefix + ".png", logger)
+    def save_quizzes(self, input, result_dir, filename_prefix):
+        self.save_image(input, result_dir, filename_prefix + ".png")
 
 
 ######################################################################