Update.
authorFrançois Fleuret <francois@fleuret.org>
Tue, 25 Jun 2024 07:57:35 +0000 (09:57 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 25 Jun 2024 07:57:35 +0000 (09:57 +0200)
main.py
tasks.py

diff --git a/main.py b/main.py
index 8033836..ebecad8 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -14,6 +14,14 @@ from torch.nn import functional as F
 import ffutils
 import mygpt, tasks
 
+# world quizzes vs. culture quizzes
+
+######################################################################
+
+accuracy_to_make_c_quizzes = 0.975
+nb_new_c_quizzes_for_train = 1000
+nb_new_c_quizzes_for_test = 100
+
 ######################################################################
 
 if torch.cuda.is_available():
@@ -84,6 +92,13 @@ if args.result_dir is None:
 
 ######################################################################
 
+if args.dirty_debug:
+    accuracy_to_make_c_quizzes = 0.0
+    nb_new_c_quizzes_for_train = 100
+    nb_new_c_quizzes_for_test = 10
+
+######################################################################
+
 default_args = {
     "model": "37M",
     "batch_size": 100,
@@ -329,7 +344,7 @@ def run_tests(model, task, deterministic_synthesis):
 ######################################################################
 
 
-def create_quizzes(
+def create_c_quizzes(
     model,
     other_models,
     task,
@@ -339,12 +354,12 @@ def create_quizzes(
 ):
     kept = []
 
-    sum_logits, sum_nb_quizzes = 0, 0
+    sum_logits, sum_nb_c_quizzes = 0, 0
 
     while sum([x.size(0) for x in kept]) < nb_for_train + nb_for_test:
         nb_to_generate = 4 * (nb_for_train + nb_for_test)
 
-        new_quizzes, nb_correct, average_logits = task.create_new_quizzes(
+        new_c_quizzes, nb_correct, average_logits = task.create_c_quizzes(
             n_epoch=n_epoch,
             result_dir=args.result_dir,
             logger=log_string,
@@ -354,33 +369,33 @@ def create_quizzes(
             desired_average_logits=desired_average_logits,
         )
 
-        sum_logits += new_quizzes.size(0) * average_logits
-        sum_nb_quizzes += new_quizzes.size(0)
+        sum_logits += new_c_quizzes.size(0) * average_logits
+        sum_nb_c_quizzes += new_c_quizzes.size(0)
 
-        to_keep = new_quizzes[nb_correct == len(other_models) - 1]
+        to_keep = new_c_quizzes[nb_correct == len(other_models) - 1]
 
         if args.dirty_debug:
-            to_keep = new_quizzes
+            to_keep = new_c_quizzes
 
         log_string(
-            f"keep {to_keep.size(0)}/{new_quizzes.size(0)} quizzes ({to_keep.size(0)*100/new_quizzes.size(0):.02f}%)"
+            f"keep {to_keep.size(0)}/{new_c_quizzes.size(0)} c_quizzes ({to_keep.size(0)*100/new_c_quizzes.size(0):.02f}%)"
         )
 
         kept.append(to_keep)
 
-    new_quizzes = torch.cat(kept, dim=0)[: nb_for_train + nb_for_test]
+    new_c_quizzes = torch.cat(kept, dim=0)[: nb_for_train + nb_for_test]
 
-    task.store_new_quizzes(new_quizzes[:nb_for_train], for_train=True)
-    task.store_new_quizzes(new_quizzes[nb_for_train:], for_train=False)
+    task.store_c_quizzes(new_c_quizzes[:nb_for_train], for_train=True)
+    task.store_c_quizzes(new_c_quizzes[nb_for_train:], for_train=False)
 
-    task.save_image(
-        new_quizzes[:72],
+    task.save_quizzes(
+        new_c_quizzes[:72],
         args.result_dir,
-        f"world_quiz_{n_epoch:04d}_{model.id:02d}.png",
+        f"culture_c_quiz_{n_epoch:04d}_{model.id:02d}",
         log_string,
     )
 
-    return sum_logits / sum_nb_quizzes
+    return sum_logits / sum_nb_c_quizzes
 
 
 ######################################################################
@@ -410,15 +425,6 @@ log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
 
 ######################################################################
 
-accuracy_to_make_quizzes = 0.975
-nb_new_quizzes_for_train = 1000
-nb_new_quizzes_for_test = 100
-
-if args.dirty_debug:
-    accuracy_to_make_quizzes = 0.0
-    nb_new_quizzes_for_train = 100
-    nb_new_quizzes_for_test = 10
-
 desired_average_logits = None
 
 for n_epoch in range(args.nb_epochs):
@@ -439,29 +445,29 @@ for n_epoch in range(args.nb_epochs):
     # improve it
     one_epoch(model, task)
 
-    task.renew_samples(args.nb_train_samples // args.nb_gpts)
+    task.renew_w_quizzes(args.nb_train_samples // args.nb_gpts)
 
     log_string(
-        f"train_set_composition world {task.nb_batch_samples_world} quizzes {task.nb_batch_samples_quizzes}"
+        f"train_set_composition w_quizzes {task.nb_batch_w_quizzes} c_quizzes {task.nb_batch_c_quizzes}"
     )
 
     # test it
     run_tests(model, task, deterministic_synthesis=False)
 
     log_string(
-        f"test_set_composition world {task.nb_batch_samples_world} quizzes {task.nb_batch_samples_quizzes}"
+        f"test_set_composition w_quizzes {task.nb_batch_w_quizzes} c_quizzes {task.nb_batch_c_quizzes}"
     )
 
-    if min([m.main_test_accuracy for m in models]) >= accuracy_to_make_quizzes:
+    if min([m.main_test_accuracy for m in models]) >= accuracy_to_make_c_quizzes:
         other_models = models.copy()
         other_models.remove(model)
 
-        average_logits = create_quizzes(
+        average_logits = create_c_quizzes(
             model,
             other_models,
             task,
-            nb_for_train=nb_new_quizzes_for_train,
-            nb_for_test=nb_new_quizzes_for_test,
+            nb_for_train=nb_new_c_quizzes_for_train,
+            nb_for_test=nb_new_c_quizzes_for_test,
             desired_average_logits=desired_average_logits,
         )
 
index ee06c25..43f7d53 100755 (executable)
--- a/tasks.py
+++ b/tasks.py
@@ -88,6 +88,9 @@ class World(Task):
         torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)
         logger(f"wrote {image_name}")
 
+    def save_quizzes(self, input, result_dir, filename_prefix, logger):
+        self.save_image(input, result_dir, filename_prefix + ".png", logger)
+
     def make_ar_mask(self, input):
         b = torch.arange(input.size(1), device=input.device) > input.size(1) // 2
         return b.long()[None, :].expand_as(input)
@@ -108,49 +111,52 @@ class World(Task):
         self.height = 6
         self.width = 8
 
-        self.train_input = world.generate_seq(
+        self.train_w_quizzes = world.generate_seq(
             nb_train_samples, height=self.height, width=self.width
         ).to(device)
 
-        self.test_input = world.generate_seq(
+        self.test_w_quizzes = world.generate_seq(
             nb_test_samples, height=self.height, width=self.width
         ).to(device)
 
-        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+        self.nb_codes = max(self.train_w_quizzes.max(), self.test_w_quizzes.max()) + 1
 
-        self.train_quizzes = []
-        self.test_quizzes = []
+        self.train_c_quizzes = []
+        self.test_c_quizzes = []
 
         if result_dir is not None:
-            self.save_image(
-                self.train_input[:72], result_dir, f"world_train.png", logger
+            self.save_quizzes(
+                self.train_w_quizzes[:72], result_dir, f"culture_w_quizzes", logger
             )
 
     def batches(self, split="train", desc=None):
         assert split in {"train", "test"}
         if split == "train":
-            input = self.train_input
-            quizzes = self.train_quizzes
+            w_quizzes = self.train_w_quizzes
+            c_quizzes = self.train_c_quizzes
         else:
-            input = self.test_input
-            quizzes = self.test_quizzes
+            w_quizzes = self.test_w_quizzes
+            c_quizzes = self.test_c_quizzes
 
-        if len(quizzes) > 0:
-            quizzes = torch.cat(quizzes, dim=0)
-            if quizzes.size(0) > input.size(0) // 2:
-                i = torch.randperm(input.size(0))[: input.size(0) // 2]
-                quizzes = quizzes[i]
+        if len(c_quizzes) > 0:
+            c_quizzes = torch.cat(c_quizzes, dim=0)
+            if c_quizzes.size(0) > w_quizzes.size(0) // 2:
+                i = torch.randperm(w_quizzes.size(0))[: w_quizzes.size(0) // 2]
+                c_quizzes = c_quizzes[i]
 
-            i = torch.randperm(input.size(0))[: input.size(0) - quizzes.size(0)]
-            input = input[i]
+            i = torch.randperm(w_quizzes.size(0))[
+                : w_quizzes.size(0) - c_quizzes.size(0)
+            ]
+            w_quizzes = w_quizzes[i]
 
-            self.nb_batch_samples_world = input.size(0)
-            self.nb_batch_samples_quizzes = quizzes.size(0)
+            self.nb_batch_w_quizzes = w_quizzes.size(0)
+            self.nb_batch_c_quizzes = c_quizzes.size(0)
 
-            input = torch.cat([input, quizzes], dim=0)
+            input = torch.cat([w_quizzes, c_quizzes], dim=0)
         else:
-            self.nb_batch_samples_world = input.size(0)
-            self.nb_batch_samples_quizzes = 0
+            input = w_quizzes
+            self.nb_batch_w_quizzes = w_quizzes.size(0)
+            self.nb_batch_c_quizzes = 0
 
         # Shuffle
         input = input[torch.randperm(input.size(0))]
@@ -192,13 +198,13 @@ class World(Task):
 
             return nb_total, nb_correct
 
-        train_nb_total, train_nb_correct = compute_accuracy(self.train_input)
+        train_nb_total, train_nb_correct = compute_accuracy(self.train_w_quizzes)
 
         logger(
             f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
         )
 
-        test_nb_total, test_nb_correct = compute_accuracy(self.test_input, logger)
+        test_nb_total, test_nb_correct = compute_accuracy(self.test_w_quizzes, logger)
 
         logger(
             f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
@@ -209,7 +215,7 @@ class World(Task):
 
         ##############################
 
-        input = self.test_input[:96]
+        input = self.test_w_quizzes[:96]
         ar_mask = self.make_ar_mask(input)
         result = input.clone() * (1 - ar_mask)
 
@@ -225,30 +231,30 @@ class World(Task):
             device=self.device,
         )
 
-        self.save_image(
+        self.save_quizzes(
             result[:72],
             result_dir,
-            f"world_prediction_{n_epoch:04d}_{model.id:02d}.png",
+            f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
             logger,
         )
 
         return main_test_accuracy
 
-    def renew_samples(self, nb, for_train=True):
-        input = self.train_input if for_train else self.test_input
+    def renew_w_quizzes(self, nb, for_train=True):
+        input = self.train_w_quizzes if for_train else self.test_w_quizzes
         nb = min(nb, input.size(0))
         input[:-nb] = input[nb:].clone()
         input[-nb:] = world.generate_seq(nb, height=self.height, width=self.width).to(
             self.device
         )
 
-    def store_new_quizzes(self, new_quizzes, for_train=True):
+    def store_c_quizzes(self, new_c_quizzes, for_train=True):
         if for_train:
-            self.train_quizzes.append(new_quizzes)
+            self.train_c_quizzes.append(new_c_quizzes)
         else:
-            self.test_quizzes.append(new_quizzes)
+            self.test_c_quizzes.append(new_c_quizzes)
 
-    def create_new_quizzes(
+    def create_c_quizzes(
         self,
         n_epoch,
         result_dir,
@@ -261,11 +267,11 @@ class World(Task):
         ###############################################################
         # Generate quizzes with model
 
-        quizzes = torch.empty(
+        c_quizzes = torch.empty(
             nb, self.height * self.width * 2 + 1, device=self.device, dtype=torch.int64
         )
 
-        ar_mask = torch.full(quizzes.size(), 1, device=self.device)
+        ar_mask = torch.full(c_quizzes.size(), 1, device=self.device)
         summed_logits = torch.empty(nb, device=self.device)
 
         temperature = 1
@@ -277,12 +283,12 @@ class World(Task):
             masked_inplace_autoregression(
                 model=model,
                 batch_size=self.batch_size,
-                input=quizzes,
+                input=c_quizzes,
                 ar_mask=ar_mask,
                 summed_logits=summed_logits,
                 temperature=temperature,
                 deterministic_synthesis=False,
-                progress_bar_desc="creating quizzes",
+                progress_bar_desc="sampling c_quizzes",
                 device=self.device,
             )
 
@@ -311,15 +317,15 @@ class World(Task):
         # Create the reverse quizzes
 
         l = self.height * self.width
-        direction = quizzes[:, l : l + 1]
+        direction = c_quizzes[:, l : l + 1]
         direction = world.token_forward * (
             direction == world.token_backward
         ) + world.token_backward * (direction == world.token_forward)
-        reverse_quizzes = torch.cat(
-            [quizzes[:, l + 1 :], direction, quizzes[:, :l]], dim=1
+        reverse_c_quizzes = torch.cat(
+            [c_quizzes[:, l + 1 :], direction, c_quizzes[:, :l]], dim=1
         )
 
-        ar_mask = self.make_ar_mask(quizzes)
+        ar_mask = self.make_ar_mask(c_quizzes)
 
         ###############################################################
         # Check how many of the other models can solve them in both
@@ -328,7 +334,7 @@ class World(Task):
         nb_correct = []
 
         for m in other_models:
-            result = quizzes.clone()
+            result = c_quizzes.clone()
 
             masked_inplace_autoregression(
                 model=m,
@@ -338,13 +344,13 @@ class World(Task):
                 summed_logits=None,
                 temperature=1.0,
                 deterministic_synthesis=True,
-                progress_bar_desc="solving quizzes",
+                progress_bar_desc="solving c_quizzes",
                 device=self.device,
             )
 
-            correct = (quizzes == result).long().min(dim=-1).values
+            correct = (c_quizzes == result).long().min(dim=-1).values
 
-            reverse_result = reverse_quizzes.clone()
+            reverse_result = reverse_c_quizzes.clone()
 
             masked_inplace_autoregression(
                 model=m,
@@ -354,21 +360,21 @@ class World(Task):
                 summed_logits=None,
                 temperature=1.0,
                 deterministic_synthesis=True,
-                progress_bar_desc="solving reversed quizzes",
+                progress_bar_desc="solving reversed c_quizzes",
                 device=self.device,
             )
 
             reverse_correct = (
-                (reverse_quizzes == reverse_result).long().min(dim=-1).values
+                (reverse_c_quizzes == reverse_result).long().min(dim=-1).values
             )
 
             nb_correct.append((correct * reverse_correct)[None, :])
 
-        nb_correct = torch.cat(nb_correct, dim=0)
+        nb_correct = torch.cat(nb_correct, dim=0).sum(dim=0)
 
         # filename = os.path.join(result_dir, "correct_{n_epoch:04d}.dat")
         # with open(filename, "w") as f:
         # for k in nb_correct:
         # f.write(f"{k}\n")
 
-        return quizzes, nb_correct.sum(dim=0), summed_logits.mean()
+        return c_quizzes, nb_correct, summed_logits.mean()