Update.
authorFrançois Fleuret <francois@fleuret.org>
Mon, 12 Aug 2024 05:08:22 +0000 (07:08 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Mon, 12 Aug 2024 05:08:22 +0000 (07:08 +0200)
main.py
quiz_machine.py

diff --git a/main.py b/main.py
index cd6e3a9..f51ab38 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -390,7 +390,9 @@ def run_tests(model, quiz_machine, local_device=main_device):
         nb_test_samples, acc_test_loss = 0, 0.0
         nb_samples_accumulated = 0
 
-        full_input, full_mask_loss = quiz_machine.data_input(model, split="test")
+        full_input, full_mask_loss = quiz_machine.data_input(
+            model, args.nb_test_samples
+        )
         src = zip(
             full_input.split(args.batch_size), full_mask_loss.split(args.batch_size)
         )
@@ -439,7 +441,7 @@ def one_epoch(model, quiz_machine, local_device=main_device):
 
     hard_w_quizzes = []
 
-    full_input, full_mask_loss = quiz_machine.data_input(model, split="train")
+    full_input, full_mask_loss = quiz_machine.data_input(model, args.nb_train_samples)
     src = zip(full_input.split(args.batch_size), full_mask_loss.split(args.batch_size))
 
     for input, mask_loss in tqdm.tqdm(
@@ -626,13 +628,6 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
     nb_validated_per_model = torch.zeros(len(models), dtype=torch.int64)
 
     while nb_validated_per_model.sum() < nb_to_validate:
-        # We use the model that has generated the fewest quizzes to
-        # balance the number of quizzes per model overall
-
-        # model_for_generation = sorted(
-        # models, key=lambda m: nb_validated_per_model[m.id]
-        # )[0]
-
         model_for_generation = models[torch.randint(len(models), (1,)).item()]
 
         # We generate quizzes with a procedure that injects some
@@ -653,6 +648,18 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
 
         # This is nb_quizzes x nb_models
 
+        solved_c_quizzes = c_quizzes[:, None, :].expand(-1, len(models), -1).clone()
+
+        for m in models:
+            solved_c_quizzes[:, m.id] = quiz_machine.predict(
+                m,
+                solved_c_quizzes[:, m.id],
+                struct=("A", "f_A", "B", "f_B"),
+                mask=(0, 0, 0, 1),
+            )
+
+        # FINISH
+
         seq_logproba = quiz_machine.models_logprobas(
             models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
         ) + quiz_machine.models_logprobas(
@@ -1043,6 +1050,7 @@ for k in range(args.nb_gpts):
     ).to(main_device)
 
     model.id = k
+    model.c_quiz_bags = []
 
     if args.schedule_free:
         model.optimizer = schedulefree.AdamWScheduleFree(
@@ -1053,12 +1061,6 @@ for k in range(args.nb_gpts):
 
     model.main_test_accuracy = 0.0
 
-    model.train_w_quizzes = quiz_machine.problem.generate_w_quizzes(
-        args.nb_train_samples
-    )
-
-    model.test_w_quizzes = quiz_machine.problem.generate_w_quizzes(args.nb_test_samples)
-
     models.append(model)
 
 ######################################################################
@@ -1312,7 +1314,7 @@ for n_epoch in range(current_epoch, args.nb_epochs):
         record_new_c_quizzes(
             models,
             quiz_machine,
-            nb_for_train=args.nb_new_c_quizzes_for_train,
+            nb_errorsfor_train=args.nb_new_c_quizzes_for_train,
             nb_for_test=args.nb_new_c_quizzes_for_test,
         )
 
@@ -1366,11 +1368,6 @@ for n_epoch in range(current_epoch, args.nb_epochs):
 
     ######################################################################
 
-    # Renew the training samples
-
-    for model in weakest_models:
-        quiz_machine.renew_train_w_quizzes(model=model)
-
     if args.log_command is not None:
         s = args.log_command.split()
         s.insert(1, args.result_dir)
index 92da03d..1d89cf4 100755 (executable)
@@ -97,10 +97,6 @@ class QuizMachine:
 
         self.test_structures = self.train_structures
 
-        self.LOCK_C_QUIZZES = threading.Lock()
-        self.train_c_quizzes = []
-        self.test_c_quizzes = []
-
     def vocabulary_size(self):
         return self.problem.nb_token_values
 
@@ -150,40 +146,21 @@ class QuizMachine:
 
     ######################################################################
 
-    def data_input(self, model, split="train"):
-        assert split in {"train", "test"}
-
-        with self.LOCK_C_QUIZZES:
-            if split == "train":
-                w_quizzes = model.train_w_quizzes
-                c_quizzes = self.train_c_quizzes
-            else:
-                w_quizzes = model.test_w_quizzes
-                c_quizzes = self.test_c_quizzes
-
-            if len(c_quizzes) > 0:
-                c_quizzes = torch.cat(c_quizzes, dim=0)
+    def data_input(self, model, nb_samples):
+        if len(model.c_quiz_bags) > 0:
+            c_quizzes = torch.cat(model.c_quiz_bags, dim=0)
 
-                if c_quizzes.size(0) > w_quizzes.size(0) // 2:
-                    i = torch.randperm(c_quizzes.size(0))[: w_quizzes.size(0) // 2]
-                    c_quizzes = c_quizzes[i]
+            if c_quizzes.size(0) > nb_samples // 2:
+                i = torch.randperm(c_quizzes.size(0))[: nb_samples // 2]
+                c_quizzes = c_quizzes[i]
 
-                i = torch.randperm(w_quizzes.size(0))[
-                    : w_quizzes.size(0) - c_quizzes.size(0)
-                ]
-                w_quizzes = w_quizzes[i]
-
-                quizzes = torch.cat([w_quizzes, c_quizzes], dim=0)
-                from_w = torch.arange(
-                    quizzes.size(0), device=quizzes.device
-                ) < w_quizzes.size(0)
-
-            else:
-                quizzes = w_quizzes.clone()
-                from_w = torch.full((quizzes.size(0),), True, device=quizzes.device)
+            w_quizzes = self.problem.generate_w_quizzes(nb_samples - c_quizzes.size(0))
+            quizzes = torch.cat([w_quizzes, c_quizzes], dim=0)
+        else:
+            quizzes = self.problem.generate_w_quizzes(nb_samples)
 
         i = torch.randperm(quizzes.size(0), device=quizzes.device)
-        quizzes, from_w = quizzes[i], from_w[i]
+        quizzes = quizzes[i]
 
         self.randomize_configuations_inplace(
             quizzes, structs=[s for s, _, _, _ in self.train_structures]
@@ -292,38 +269,6 @@ class QuizMachine:
 
     ######################################################################
 
-    def renew_train_w_quizzes(self, model):
-        if hasattr(model, "hard_w_quizzes"):
-            hard_w_quizzes = self.problem.reconfigure(
-                model.hard_w_quizzes, struct=("A", "f_A", "B", "f_B")
-            )
-            self.logger(
-                f"re-using {hard_w_quizzes.size(0)} hard world quizzes from model {model.id}"
-            )
-            if hard_w_quizzes.size(0) >= model.train_w_quizzes.size(0):
-                nb_to_generate = 0
-                model.train_w_quizzes[...] = hard_w_quizzes[
-                    torch.randperm(hard_w_quizzes.size(0))[
-                        model.train_w_quizzes.size(0)
-                    ]
-                ]
-            else:
-                nb_to_generate = model.train_w_quizzes.size(0) - hard_w_quizzes.size(0)
-                model.train_w_quizzes[...] = torch.cat(
-                    [
-                        hard_w_quizzes,
-                        self.problem.generate_w_quizzes(nb_to_generate),
-                    ],
-                    dim=0,
-                )
-        else:
-            nb_to_generate = 0
-            model.train_w_quizzes[...] = self.problem.generate_w_quizzes(
-                model.train_w_quizzes.size(0)
-            )
-
-    ######################################################################
-
     def store_c_quizzes(self, new_c_quizzes, for_train=True):
         with self.LOCK_C_QUIZZES:
             if for_train: