Update.
authorFrançois Fleuret <francois@fleuret.org>
Mon, 19 Aug 2024 17:50:13 +0000 (19:50 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Mon, 19 Aug 2024 17:50:13 +0000 (19:50 +0200)
main.py
quiz_machine.py

diff --git a/main.py b/main.py
index d98031e..1cbff39 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -61,8 +61,6 @@ parser.add_argument("--nb_new_c_quizzes_for_train", type=int, default=None)
 
 parser.add_argument("--nb_new_c_quizzes_for_test", type=int, default=None)
 
-parser.add_argument("--c_quiz_multiplier", type=int, default=1)
-
 parser.add_argument("--learning_rate", type=float, default=5e-4)
 
 parser.add_argument("--lambda_H", type=float, default=0.0)
@@ -342,7 +340,7 @@ def run_tests(model, quiz_machine, local_device=main_device):
         nb_samples_accumulated = 0
 
         full_input, full_mask_loss = quiz_machine.data_input(
-            args.nb_test_samples, model.test_c_quiz_bags, args.c_quiz_multiplier
+            args.nb_test_samples, model.test_c_quiz_bags
         )
         src = zip(
             full_input.split(args.batch_size), full_mask_loss.split(args.batch_size)
@@ -370,9 +368,7 @@ def run_tests(model, quiz_machine, local_device=main_device):
 
         log_string(f"test_perplexity {n_epoch} model {model.id} {test_perplexity}")
 
-        input, _ = quiz_machine.data_input(
-            2000, model.test_c_quiz_bags, args.c_quiz_multiplier
-        )
+        input, _ = quiz_machine.data_input(2000, model.test_c_quiz_bags)
 
         model.test_accuracy = quiz_machine.produce_results(
             n_epoch=n_epoch,
@@ -395,7 +391,7 @@ def one_epoch(model, quiz_machine, local_device=main_device):
     nb_train_samples, acc_train_loss = 0, 0.0
 
     full_input, full_mask_loss = quiz_machine.data_input(
-        args.nb_train_samples, model.train_c_quiz_bags, args.c_quiz_multiplier
+        args.nb_train_samples, model.train_c_quiz_bags
     )
     src = zip(full_input.split(args.batch_size), full_mask_loss.split(args.batch_size))
 
@@ -561,26 +557,21 @@ def model_proba_solutions(model, quizzes):
     return l.exp()
 
 
-def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test):
+def create_c_quizzes(main_model, other_models, quiz_machine, nb_for_train, nb_for_test):
     nb_validated, nb_to_validate = 0, (nb_for_train + nb_for_test) * len(models)
     nb_generated, nb_to_generate_per_iteration = 0, nb_to_validate
 
     start_time = time.perf_counter()
 
-    for model in models:
-        model.recorded_c_quizzes = []
-
-    teaching_count = torch.zeros(len(models), len(models), dtype=torch.int64)
+    recorded = []
 
     while nb_validated < nb_to_validate:
-        model_for_generation = models[torch.randint(len(models), (1,)).item()]
-
         # We generate quizzes with a procedure that injects some
         # structured noise
 
         c_quizzes = quiz_machine.generate_c_quizzes(
             nb_to_generate_per_iteration,
-            model_for_generation=model,
+            model_for_generation=main_model,
             procedure=c_quizzes_procedure,
         )
 
@@ -593,57 +584,48 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test):
 
         c_quizzes = c_quizzes[to_keep]
 
-        # Compute the responses of all the models on the c_quizzes,
-        # and their proba estimates of their responses
+        # Keep only the quizzes that the main model cannot solve
 
-        solved_c_quizzes = c_quizzes[:, None, :].expand(-1, len(models), -1).clone()
+        solved_c_quizzes = c_quizzes.clone()
 
-        proba_own_solution = torch.zeros(
-            c_quizzes.size(0), len(models), device=solved_c_quizzes.device
+        main_solution, _, _ = quiz_machine.predict(
+            main_model,
+            solved_c_quizzes,
+            struct=("A", "f_A", "B", "f_B"),
+            mask=(0, 0, 0, 1),
         )
 
-        for model in models:
-            (solved_c_quizzes[:, model.id], _, _) = quiz_machine.predict(
-                model,
-                solved_c_quizzes[:, model.id],
-                struct=("A", "f_A", "B", "f_B"),
-                mask=(0, 0, 0, 1),
-            )
+        keep = (
+            model_proba_solutions(main_model, main_solution)
+            < args.proba_not_understands
+        )
+        c_quizzes = c_quizzes[keep]
+
+        # If there are some quizzes that the main model cannot solve,
+        # pick the most confident solution
 
-            proba_own_solution[:, model.id] = model_proba_solutions(
-                model, solved_c_quizzes[:, model.id]
+        if c_quizzes.size(0) > 0:
+            solution = c_quizzes.clone()
+            c_quizzes_proba = torch.zeros(
+                solution.size(0), dtype=torch.float32, device=solution.device
             )
 
-        # Now for every model not confident of its response, we pick
-        # the most consistent from a model which is confident
-
-        for s in range(proba_own_solution.size(0)):
-            # At least one GPT does not understand at all
-            if proba_own_solution[s, :].min() < args.proba_not_understands:
-                dont_get_this_quiz = proba_own_solution[s, :] < args.proba_understands
-                nb_fails = dont_get_this_quiz.long().sum()
-                # At most max_fail_to_validate do not understand (default 3/5)
-                if nb_fails >= 1 and nb_fails <= args.max_fail_to_validate:
-                    for model in models:
-                        # If a GPT does not get that quiz
-                        if dont_get_this_quiz[model.id]:
-                            assert (
-                                proba_own_solution[s, model.id] < args.proba_understands
-                            )
-                            # Look at its estimate of the others'solutions
-                            proba_other_solutions = model_proba_solutions(
-                                model, solved_c_quizzes[s]
-                            )
-                            # Randomize a bit the orders for the frequent P=1
-                            proba_other_solutions += (
-                                torch.rand(proba_other_solutions.size()) * 1e-6
-                            )
-                            # Remove the under threshold confidence solutions
-                            proba_other_solutions[dont_get_this_quiz] = -1
-                            i = proba_other_solutions.argmax()
-                            model.recorded_c_quizzes.append(solved_c_quizzes[s, i])
-                            teaching_count[i, model.id] += 1
-                            nb_validated += 1
+            for model in other_models:
+                solution, _, _ = quiz_machine.predict(
+                    model,
+                    solution,
+                    struct=("A", "f_A", "B", "f_B"),
+                    mask=(0, 0, 0, 1),
+                )
+
+                probas = model_proba_solutions(model, solution)
+                keep = probas >= c_quizzes_proba
+                c_quizzes = solution[keep]
+                c_quizzes_proba[keep] = probas[keep]
+
+            keep = c_quizzes_proba >= args.proba_understands
+            recorded.append(c_quizzes_proba[keep])
+            nb_validated += keep.long().sum()
 
         duration = time.perf_counter() - start_time
 
@@ -662,146 +644,29 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test):
             f"keep c_quizzes model {model_for_generation.id} validated nb_validated {nb_validated} / {nb_to_validate} (finishes {e} -- {int((nb_validated * 3600)/duration)}/h) proportion_kept {nb_validated * 100 / nb_generated:.02f}%"
         )
 
-    for s in range(teaching_count.size(0)):
-        o = [x.item() for x in teaching_count[s]]
-        log_string(f"teacher model {s} to {o}")
+    # Save some images
 
-    for model in models:
-        new_bag = torch.cat([q[None, :] for q in model.recorded_c_quizzes], dim=0)
-
-        if new_bag.size(0) > 0:
-            n = (new_bag.size(0) * nb_for_train) // (nb_for_train + nb_for_test)
-            if n > 0:
-                model.train_c_quiz_bags.append(new_bag[:n])
-            if n < new_bag.size(0):
-                model.test_c_quiz_bags.append(new_bag[n:])
-
-            c_quizzes = new_bag[:128]
-
-            l = [model_proba_solutions(model, c_quizzes) for model in models]
-            probas = torch.cat([x[:, None] for x in l], dim=1)
-            comments = []
-
-            for l in probas:
-                comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l]))
-
-            filename = f"culture_c_quiz_{n_epoch:04d}_{model.id:02d}.png"
-            quiz_machine.problem.save_quizzes_as_image(
-                args.result_dir, filename, c_quizzes, comments=comments
-            )
+    c_quizzes = torch.cat(recorded, dim=0)
 
-        log_string(
-            f"nb_c_quizzes model {model.id} train {sum([q.size(0) for q in model.train_c_quiz_bags ])} test {sum([q.size(0) for q in model.test_c_quiz_bags ])}"
-        )
+    l = [
+        model_proba_solutions(model, c_quizzes) for model in [main_model] + other_models
+    ]
+    probas = torch.cat([x[:, None] for x in l], dim=1)
+    comments = []
+    for l in probas:
+        comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l]))
 
+    filename = f"culture_c_quiz_{n_epoch:04d}_{model.id:02d}.png"
+    quiz_machine.problem.save_quizzes_as_image(
+        args.result_dir, filename, c_quizzes[:128], comments=comments
+    )
 
-######################################################################
 
-from mygpt import (
-    WithResidual,
-    CacheWrapper,
-    AddPositionalEncoding,
-    QKVAttention,
-    BracketedSequence,
+log_string(
+    f"nb_c_quizzes model {model.id} train {sum([q.size(0) for q in model.train_c_quiz_bags ])} test {sum([q.size(0) for q in model.test_c_quiz_bags ])}"
 )
 
 
-class Thinker(nn.Module):
-    def __init__(
-        self,
-        vocabulary_size,
-        dim_model,
-        dim_keys,
-        dim_hidden,
-        nb_heads,
-        nb_blocks,
-        f_len,
-        dropout=0.0,
-        len_max=1e5,
-    ):
-        super().__init__()
-
-        assert dim_model % nb_heads == 0
-
-        self.embedding = nn.Sequential(
-            CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
-            AddPositionalEncoding(len_max),
-        )
-
-        def trunk(depth):
-            trunk_blocks = []
-
-            for b in range(nb_blocks):
-                trunk_blocks += [
-                    WithResidual(
-                        CacheWrapper(
-                            nn.LayerNorm((dim_model,)),
-                        ),
-                        QKVAttention(
-                            dim_in=dim_model,
-                            dim_qk=dim_keys,
-                            dim_v=dim_model // nb_heads,
-                            nb_heads=nb_heads,
-                            attention_dropout=dropout,
-                        ),
-                    ),
-                    WithResidual(
-                        CacheWrapper(
-                            nn.LayerNorm((dim_model,)),
-                            nn.Linear(in_features=dim_model, out_features=dim_hidden),
-                            nn.ReLU(),
-                            nn.Linear(in_features=dim_hidden, out_features=dim_model),
-                            nn.Dropout(dropout),
-                        ),
-                    ),
-                ]
-
-            return nn.Sequential(*trunk_blocks)
-
-        self.bottom_trunk = trunk(nb_blocks // 2)
-
-        self.top_trunk = trunk(nb_blocks // 2)
-
-        self.readout = CacheWrapper(
-            nn.Linear(in_features=dim_model, out_features=vocabulary_size)
-        )
-
-        self.fun_embedding = nn.Parameter(torch.randn(1, f_len, dim_model))
-
-        with torch.no_grad():
-            for m in self.modules():
-                if isinstance(m, nn.Embedding):
-                    m.weight.normal_(mean=0, std=2e-2)
-                elif isinstance(m, nn.LayerNorm):
-                    m.bias.zero_()
-                    m.weight.fill_(1.0)
-
-    def forward(self, bs):
-        for m in self.modules():
-            m.loss = 0
-
-        L = bs.x.size(1) // 3
-
-        bs = self.embedding(bs)
-        A_fA = BracketedSequence(bs.x[:, : 2 * L])
-        B = BracketedSequence(bs.x[:, -L:])
-
-        bs = BracketedSequence(
-            torch.cat([A_fA.x, self.fun_embedding.expand(bs.x.size(0), -1, -1)], dim=1)
-        )
-        bs = self.bottom_trunk(bs)
-        bs = BracketedSequence(torch.cat([bs.x[:, -f_len:, :], B.x], dim=1))
-        bs = self.top_trunk(bs)
-        bs = BracketedSequence(bs.x[:, f_len:, :])
-        bs = self.readout(bs)
-
-        for m in self.modules():
-            if m is not self:
-                self.loss += m.loss
-
-        return bs
-
-
 ######################################################################
 
 models = []
@@ -855,20 +720,12 @@ for k in range(args.nb_gpts):
         model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
 
     model.test_accuracy = 0.0
-    model.best_test_accuracy = 0.0
-    model.best_dict = copy.deepcopy(model.state_dict())
     models.append(model)
 
 ######################################################################
 
 current_epoch = 0
 
-# We balance the computing time between training the models and
-# generating c_quizzes
-
-total_time_generating_c_quizzes = 0
-total_time_training_models = 0
-
 if args.resume:
     for model in models:
         filename = f"gpt_{model.id:03d}.pth"
@@ -878,8 +735,6 @@ if args.resume:
             model.load_state_dict(d["state_dict"])
             model.optimizer.load_state_dict(d["optimizer_state_dict"])
             model.test_accuracy = d["test_accuracy"]
-            model.best_test_accuracy = d["best_test_accuracy"]
-            model.best_dict = d["best_dict"]
             model.train_c_quiz_bags = d["train_c_quiz_bags"]
             model.test_c_quiz_bags = d["test_c_quiz_bags"]
             log_string(f"successfully loaded {filename}")
@@ -892,8 +747,6 @@ if args.resume:
         state = torch.load(os.path.join(args.result_dir, filename))
         log_string(f"successfully loaded {filename}")
         current_epoch = state["current_epoch"]
-        total_time_generating_c_quizzes = state["total_time_generating_c_quizzes"]
-        total_time_training_models = state["total_time_training_models"]
     except FileNotFoundError:
         log_string(f"cannot find {filename}")
         pass
@@ -950,69 +803,6 @@ class Recorder(nn.Module):
         return input
 
 
-if args.test == "mlp":
-    model = models[0]
-    tape_input, tape_output = [], []
-    L = len(model.trunk)
-    model.trunk.insert(L // 2 + 1, Recorder(tape_output))
-    model.trunk.insert(L // 2, Recorder(tape_input))
-
-    mlp = nn.Sequential(
-        nn.Linear(args.dim_model, args.dim_model),
-        nn.ReLU(),
-        nn.Linear(args.dim_model, args.dim_model),
-        nn.ReLU(),
-        nn.Linear(args.dim_model, 8 * args.dim_model),
-        Folder(),
-        Unfolder(404, 8 * args.dim_model),
-        nn.Linear(8 * args.dim_model, args.dim_model),
-        nn.ReLU(),
-        nn.Linear(args.dim_model, args.dim_model),
-        nn.ReLU(),
-        nn.Linear(args.dim_model, args.dim_model),
-    ).to(main_device)
-
-    mlp.optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-4)
-
-    for n_epoch in range(args.nb_epochs):
-        train_input = quiz_machine.problem.generate_w_quizzes(args.nb_train_samples)
-
-        tape_input.clear()
-        tape_output.clear()
-
-        with torch.autograd.no_grad():
-            model.to(main_device).eval()
-            for input in train_input.split(args.batch_size):
-                input = input.to(main_device)
-                output = model(mygpt.BracketedSequence(input)).x
-
-        train_input = torch.cat([bs.x for bs in tape_input], dim=0)
-        train_targets = torch.cat([bs.x for bs in tape_output], dim=0)
-
-        nb_train_samples, acc_train_loss = 0, 0.0
-        src = zip(
-            train_input.split(args.batch_size), train_targets.split(args.batch_size)
-        )
-        for input, targets in tqdm.tqdm(
-            src,
-            dynamic_ncols=True,
-            desc="train",
-            total=train_input.size(0) // args.batch_size,
-        ):
-            input = input.to(main_device)
-            output = mlp(input)
-            loss = F.mse_loss(output, targets) + output.abs().sum()
-            acc_train_loss += loss.item() * input.size(0)
-            nb_train_samples += input.size(0)
-
-            mlp.optimizer.zero_grad()
-            loss.backward()
-            mlp.optimizer.step()
-
-        log_string(f"mlp_loss {n_epoch} train {acc_train_loss/nb_train_samples}")
-
-    exit(0)
-
 ######################################################################
 
 
@@ -1057,67 +847,9 @@ def save_generated_c_quizzes(model, filename, nb=64):
 
 ######################################################################
 
-
-if args.test == "entropy":
-    model = models[0]
-    model.to(main_device)
-
-    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
-
-    log_string("starting testing entropy maximization")
-
-    for n_epoch in range(100):
-        input = quiz_machine.generate_c_quizzes(
-            128,
-            model_for_generation=model,
-            procedure=c_quizzes_procedure,
-        )
-
-        quiz_machine.problem.save_quizzes_as_image(
-            args.result_dir,
-            f"test_{n_epoch:04d}.png",
-            quizzes=input,
-        )
-
-        log_string(f"wrote {filename}")
-
-        with torch.no_grad():
-            for p in model.parameters():
-                p += torch.randn(p.size(), device=p.device) * 1e-3
-
-        # nb_train_samples, acc_train_loss = 0, 0.0
-
-        # for k in range(1000 // args.batch_size):
-        # input = quiz_machine.generate_c_quizzes(
-        # args.batch_size,
-        # model_for_generation=model,
-        # procedure=[(("f_B", "f_A", "A", "B"), (1, 1, 1, 1), None)],
-        # )
-
-        # input = input.to(main_device)
-        # targets = input
-        # output = model(mygpt.BracketedSequence(input)).x
-        # loss = -F.cross_entropy(output.transpose(1, 2), targets)
-        # acc_train_loss += loss.item() * input.size(0)
-        # nb_train_samples += input.size(0)
-
-        # optimizer.zero_grad()
-        # loss.backward()
-        # optimizer.step()
-
-        # log_string(
-        # f"increase_entropy {n_epoch} entropy {acc_train_loss/nb_train_samples}"
-        # )
-
-    exit(0)
-
-######################################################################
-
 for n_epoch in range(current_epoch, args.nb_epochs):
     state = {
         "current_epoch": n_epoch,
-        "total_time_training_models": total_time_training_models,
-        "total_time_generating_c_quizzes": total_time_generating_c_quizzes,
     }
     filename = "state.pth"
     torch.save(state, os.path.join(args.result_dir, filename))
@@ -1128,84 +860,71 @@ for n_epoch in range(current_epoch, args.nb_epochs):
     cta = " ".join([f"{float(m.test_accuracy):.04f}" for m in models])
     log_string(f"current_test_accuracies {cta}")
 
-    cta = " ".join([f"{float(m.best_test_accuracy):.04f}" for m in models])
-    log_string(f"current_best_test_accuracies {cta}")
-
     ##################################################
 
-    for model in models:
-        if model.test_accuracy >= args.accuracy_to_make_c_quizzes:
-            log_string(
-                f"storing_best model {model.id} accuracy {model.best_test_accuracy} -> {model.test_accuracy}"
-            )
-            model.best_dict = copy.deepcopy(model.state_dict())
-            model.best_test_accuracy = model.test_accuracy
-
-    # we restart
-    if total_time_generating_c_quizzes == 0:
-        total_time_training_models = 0
-
-    if (
-        min([m.best_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes
-        and total_time_training_models >= total_time_generating_c_quizzes
-    ):
-        for model in models:
-            model.current_dict = copy.deepcopy(model.state_dict())
-            model.load_state_dict(model.best_dict)
-
-        start_time = time.perf_counter()
+    if min([m.test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes:
         record_new_c_quizzes(
             models,
             quiz_machine,
             args.nb_new_c_quizzes_for_train,
             args.nb_new_c_quizzes_for_test,
         )
-        total_time_generating_c_quizzes += time.perf_counter() - start_time
 
-        # Force one epoch of training
         for model in models:
-            model.load_state_dict(model.current_dict)
+            new_model = mygpt.MyGPT(
+                vocabulary_size=vocabulary_size,
+                dim_model=args.dim_model,
+                dim_keys=args.dim_keys,
+                dim_hidden=args.dim_hidden,
+                nb_heads=args.nb_heads,
+                nb_blocks=args.nb_blocks,
+                compute_attzero=compute_causal_attzero,
+                dropout=args.dropout,
+            ).to(main_device)
+            model.load_state_dict(new_model.state_dict())
+            model.test_accuracy = 0.0
+            model.best_test_accuracy = 0.0
+            model.best_dict = copy.deepcopy(model.state_dict())
 
     ##################################################
     # Select, improve, and eval the worst model(s)
 
-    if total_time_training_models <= total_time_generating_c_quizzes:
-        ranked_models = sorted(
-            models,
-            # This ugly recipe will pick the worst if there some below
-            # args.accuracy_to_make_c_quizzes or one at random if they
-            # are all above
-            key=lambda m: float(
-                m.test_accuracy
-                if m.test_accuracy < args.accuracy_to_make_c_quizzes
-                else args.accuracy_to_make_c_quizzes + torch.rand(1).item()
-            ),
-        )
+    ranked_models = sorted(
+        models,
+        # This ugly recipe will pick the worst if there some below
+        # args.accuracy_to_make_c_quizzes or one at random if they
+        # are all above
+        key=lambda m: float(
+            m.test_accuracy
+            if m.test_accuracy < args.accuracy_to_make_c_quizzes
+            else args.accuracy_to_make_c_quizzes + torch.rand(1).item()
+        ),
+    )
 
-        weakest_models = ranked_models[: len(gpus)]
+    weakest_models = ranked_models[: len(gpus)]
 
-        threads = []
+    threads = []
 
-        start_time = time.perf_counter()
+    start_time = time.perf_counter()
 
-        for gpu, model in zip(gpus, weakest_models):
-            log_string(f"training model {model.id}")
+    for gpu, model in zip(gpus, weakest_models):
+        log_string(f"training model {model.id}")
 
-            t = threading.Thread(
-                target=one_epoch, daemon=True, args=(model, quiz_machine, gpu)
-            )
+        t = threading.Thread(
+            target=one_epoch, daemon=True, args=(model, quiz_machine, gpu)
+        )
 
-            threads.append(t)
+        threads.append(t)
 
-            t.start()
+        t.start()
 
-        for t in threads:
-            t.join()
+    for t in threads:
+        t.join()
 
-        total_time_training_models += time.perf_counter() - start_time
+    total_time_training_models += time.perf_counter() - start_time
 
-        for model in weakest_models:
-            save_additional_results(n_epoch, model, models, c_quizzes_procedure)
+    for model in weakest_models:
+        save_additional_results(n_epoch, model, models, c_quizzes_procedure)
 
     # Save the models to disk
 
index a0b007a..1acd7ad 100755 (executable)
@@ -87,8 +87,6 @@ class QuizMachine:
             (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)),
             (("B", "f_B", "A", "f_A"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)),
             (("f_B", "B", "f_A", "A"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)),
-            # (("B", "f_B", "A", "f_A"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
-            # (("f_B", "B", "f_A", "A"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
             (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
         ]
 
@@ -140,23 +138,10 @@ class QuizMachine:
 
     ######################################################################
 
-    def data_input(self, nb_samples, c_quiz_bags, c_quiz_multiplier=1):
+    def data_input(self, nb_samples, c_quiz_bags):
         if len(c_quiz_bags) > 0:
             c_quizzes = torch.cat(c_quiz_bags, dim=0)
 
-            if c_quiz_multiplier > 1:
-                n = min(c_quiz_multiplier, (nb_samples // 2) // c_quizzes.size(0))
-                body = c_quizzes.repeat(n, 1)
-                if n < c_quiz_multiplier:
-                    tail = c_quizzes[
-                        torch.randperm(c_quizzes.size(0))[
-                            : nb_samples // 2 - body.size(0)
-                        ]
-                    ]
-                    c_quizzes = torch.cat([body, tail], dim=0)
-                else:
-                    c_quizzes = body
-
             if c_quizzes.size(0) > nb_samples // 2:
                 i = torch.randperm(c_quizzes.size(0))[: nb_samples // 2]
                 c_quizzes = c_quizzes[i]