Update.
authorFrançois Fleuret <francois@fleuret.org>
Tue, 27 Aug 2024 19:33:46 +0000 (21:33 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 27 Aug 2024 19:33:46 +0000 (21:33 +0200)
main.py

diff --git a/main.py b/main.py
index e4237d2..85213cb 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -354,7 +354,7 @@ def run_tests(model, quiz_machine, local_device=main_device):
             mask_loss = mask_loss.to(local_device)
             targets = input
 
-            output = model(mygpt.BracketedSequence(input)).x
+            output = model(input)
             loss_per_token = F.cross_entropy(
                 output.transpose(1, 2), targets, reduction="none"
             )
@@ -407,26 +407,11 @@ def one_epoch(model, quiz_machine, local_device=main_device):
             model.optimizer.zero_grad()
 
         targets = input
-
-        output = model(mygpt.BracketedSequence(input)).x
-
-        loss_per_token = F.cross_entropy(
-            output.transpose(1, 2), targets, reduction="none"
-        )
-
-        # warnings.warn("entropy masking", RuntimeWarning)
-        # l = output.transpose(1, 2).log_softmax(dim=1)
-        # H = -(l * l.exp()).sum(dim=1)
-        # M = (H >= -math.log(0.99) / H.size(1)).long()
-        # print(H, M)
-        # loss_per_token = loss_per_token * M
-
-        loss = (loss_per_token * mask_loss).mean() + model.loss
+        output = model(input)
+        loss = F.cross_entropy(output.transpose(1, 2), targets, reduction="none")
+        loss = (loss * mask_loss).mean() + model.loss
 
         acc_train_loss += loss.item() * input.size(0)
-
-        loss_per_samples = loss_per_token.detach().flatten(1).mean(dim=1)
-
         nb_train_samples += input.size(0)
 
         loss.backward()
@@ -477,28 +462,6 @@ data_structures = [
 ######################################################################
 
 
-def model_proba_solutions(model, quizzes):
-    l = (
-        quiz_machine.models_logprobas(
-            model, quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
-        )
-        + quiz_machine.models_logprobas(
-            model, quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
-        )
-        + quiz_machine.models_logprobas(
-            model, quizzes, ("B", "f_B", "A", "f_A"), (0, 0, 0, 1), (0, 0, 1, 0)
-        )
-        + quiz_machine.models_logprobas(
-            model, quizzes, ("f_B", "B", "f_A", "A"), (0, 0, 0, 1), (0, 0, 1, 0)
-        )
-    )
-
-    return l.exp()
-
-
-######################################################################
-
-
 def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test):
     nb_validated, nb_to_validate = 0, (nb_for_train + nb_for_test) * len(models)
     nb_generated, nb_to_generate_per_iteration = 0, nb_to_validate
@@ -853,6 +816,13 @@ class MyAttentionAE(nn.Module):
         return bs
 
 
+######################################################################
+
+nb_iterations = 25
+probs_iterations = 0.1 ** torch.linspace(0, 1, nb_iterations, device=main_device)
+probs_iterations = probs_iterations[None, :] / probs_iterations.sum()
+
+
 def ae_batches(
     quiz_machine,
     nb,
@@ -952,6 +922,27 @@ def ae_generate(model, input, mask_generate, noise_proba, nb_iterations_max=50):
     return input
 
 
+######################################################################
+
+
+def model_ae_proba_solutions(model, input):
+    loss = 0
+
+    for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
+        mask_generate = quiz_machine.make_quiz_mask(
+            quizzes=input, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
+        )
+        targets, logits = targets_and_prediction(
+            probs_iterations, model, input, mask_generate
+        )
+        loss_per_token = F.cross_entropy(
+            logits.transpose(1, 2), targets, reduction="none"
+        )
+        loss += (loss_per_token * mask_generate).sum(dim=1)
+
+    return (-loss).exp()
+
+
 def degrade_input(input, mask_generate, nb_iterations, noise_proba):
     noise = torch.randint(
         quiz_machine.problem.nb_colors, input.size(), device=input.device
@@ -970,13 +961,31 @@ def degrade_input(input, mask_generate, nb_iterations, noise_proba):
     return result
 
 
-######################################################################
+def targets_and_prediction(probs_iterations, model, input, mask_generate):
+    d = deterministic(mask_generate)
+    p = probs_iterations.expand(input.size(0), -1)
+    dist = torch.distributions.categorical.Categorical(probs=p)
+    N0 = dist.sample()
+    N1 = N0 + 1
+    N0 = (1 - d) * N0
+    N1 = (1 - d) * N1 + d * nb_iterations
+
+    targets, input = degrade_input(
+        input, mask_generate, (0 * N1, N1), noise_proba=noise_proba
+    )
 
+    input_with_mask = NTC_channel_cat(input, mask_generate)
+    logits = model(input_with_mask)
 
-def run_ae_test(model, quiz_machine, n_epoch, local_device=main_device):
+    return targets, logits
+
+
+def run_ae_test(model, other_models, quiz_machine, n_epoch, local_device=main_device):
     with torch.autograd.no_grad():
         model.eval().to(local_device)
 
+        # Compute the loss
+
         nb_test_samples, acc_test_loss = 0, 0.0
 
         for input, mask_generate, mask_loss in ae_batches(
@@ -986,18 +995,9 @@ def run_ae_test(model, quiz_machine, n_epoch, local_device=main_device):
             local_device,
             "test",
         ):
-            d = deterministic(mask_generate)
-            p = probs_iterations.expand(input.size(0), -1)
-            dist = torch.distributions.categorical.Categorical(probs=p)
-            N0 = dist.sample()
-            N1 = N0 + 1
-            N0 = (1 - d) * N0
-            N1 = (1 - d) * N1 + d * nb_iterations
-            targets, input = degrade_input(
-                input, mask_generate, (0 * N1, N1), noise_proba=noise_proba
+            targets, logits = targets_and_prediction(
+                probs_iterations, model, input, mask_generate
             )
-            input_with_mask = NTC_channel_cat(input, mask_generate)
-            logits = model(input_with_mask)
             loss = NTC_masked_cross_entropy(logits, targets, mask_loss)
             acc_test_loss += loss.item() * input.size(0)
             nb_test_samples += input.size(0)
@@ -1006,10 +1006,9 @@ def run_ae_test(model, quiz_machine, n_epoch, local_device=main_device):
             f"test_loss {n_epoch} model {model.id} {acc_test_loss/nb_test_samples}"
         )
 
-        # -------------------------------------------
-        # Test generation
+        # Compute the accuracy and save some images
 
-        nb_correct, nb_total, record = 0, 0, []
+        nb_correct, nb_total, record_d, record_nd = 0, 0, [], []
 
         for input, mask_generate, mask_loss in ae_batches(
             quiz_machine,
@@ -1026,12 +1025,14 @@ def run_ae_test(model, quiz_machine, n_epoch, local_device=main_device):
             predicted_parts = mask_generate.reshape(mask_generate.size(0), 4, -1)[
                 :, :, 1
             ]
-            solution_is_deterministic = predicted_parts.sum(dim=-1) == 1
-            correct = (2 * correct - 1) * (solution_is_deterministic).long()
+            d = predicted_parts.sum(dim=-1) == 1
+            correct = (2 * correct - 1) * d.long()
             nb_correct += (correct == 1).long().sum()
             nb_total += (correct != 0).long().sum()
             correct_parts = predicted_parts * correct[:, None]
-            record.append((result, predicted_parts, correct_parts))
+            record_d.append((result[d], predicted_parts[d], correct_parts[d]))
+            nd = d == False
+            record_nd.append((result[nd], predicted_parts[nd], correct_parts[nd]))
 
         log_string(
             f"test_accuracy {n_epoch} model {model.id} nb_correct {nb_correct} / {nb_total} ({(nb_correct*100)/nb_total:.02f}%)"
@@ -1039,27 +1040,34 @@ def run_ae_test(model, quiz_machine, n_epoch, local_device=main_device):
 
         model.test_accuracy = nb_correct / nb_total
 
-        filename = f"prediction_ae_{n_epoch:04d}.png"
+        for f, record in [("prediction", record_d), ("generative", record_nd)]:
+            filename = f"culture_{f}_{n_epoch:04d}_{model.id:02d}.png"
+            result, predicted_parts, correct_parts = (
+                torch.cat([x[i] for x in record])[:128] for i in [0, 1, 2]
+            )
 
-        result, predicted_parts, correct_parts = (
-            torch.cat([x[i] for x in record]) for i in [0, 1, 2]
-        )
+            l = [model_ae_proba_solutions(model, result) for model in other_models]
+            probas = torch.cat([x[:, None] for x in l], dim=1)
+            comments = []
 
-        quiz_machine.problem.save_quizzes_as_image(
-            args.result_dir,
-            filename,
-            quizzes=result,
-            predicted_parts=predicted_parts,
-            correct_parts=correct_parts,
-        )
+            for l in probas:
+                comments.append("proba " + " ".join([f"{x.item():.02f}" for x in l]))
 
-        log_string(f"wrote {filename}")
+            quiz_machine.problem.save_quizzes_as_image(
+                args.result_dir,
+                filename,
+                quizzes=result,
+                predicted_parts=predicted_parts,
+                correct_parts=correct_parts,
+                comments=comments,
+            )
+            log_string(f"wrote {filename}")
 
 
 ######################################################################
 
 
-def one_ae_epoch(model, quiz_machine, n_epoch, local_device=main_device):
+def one_ae_epoch(model, other_models, quiz_machine, n_epoch, local_device=main_device):
     model.train().to(local_device)
 
     nb_train_samples, acc_train_loss = 0, 0.0
@@ -1074,20 +1082,9 @@ def one_ae_epoch(model, quiz_machine, n_epoch, local_device=main_device):
         if nb_train_samples % args.batch_size == 0:
             model.optimizer.zero_grad()
 
-        d = deterministic(mask_generate)
-        p = probs_iterations.expand(input.size(0), -1)
-        dist = torch.distributions.categorical.Categorical(probs=p)
-        N0 = dist.sample()
-        N1 = N0 + 1
-        N0 = (1 - d) * N0
-        N1 = (1 - d) * N1 + d * nb_iterations
-
-        targets, input = degrade_input(
-            input, mask_generate, (0 * N1, N1), noise_proba=noise_proba
+        targets, logits = targets_and_prediction(
+            probs_iterations, model, input, mask_generate
         )
-
-        input_with_mask = NTC_channel_cat(input, mask_generate)
-        logits = model(input_with_mask)
         loss = NTC_masked_cross_entropy(logits, targets, mask_loss)
         acc_train_loss += loss.item() * input.size(0)
         nb_train_samples += input.size(0)
@@ -1101,17 +1098,13 @@ def one_ae_epoch(model, quiz_machine, n_epoch, local_device=main_device):
         f"train_loss {n_epoch} model {model.id} {acc_train_loss/nb_train_samples}"
     )
 
-    run_ae_test(model, quiz_machine, n_epoch, local_device=local_device)
+    run_ae_test(model, other_models, quiz_machine, n_epoch, local_device=local_device)
 
 
 ######################################################################
 
 noise_proba = 0.05
 
-nb_iterations = 25
-probs_iterations = 0.1 ** torch.linspace(0, 1, nb_iterations, device=main_device)
-probs_iterations = probs_iterations[None, :] / probs_iterations.sum()
-
 models = []
 
 for i in range(args.nb_models):
@@ -1194,6 +1187,9 @@ for n_epoch in range(args.nb_epochs):
 
     # --------------------------------------------------------------------
 
+    one_ae_epoch(models[0], models, quiz_machine, n_epoch, main_device)
+    exit(0)
+
     ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
     weakest_models = ranked_models[: len(gpus)]
 
@@ -1202,18 +1198,20 @@ for n_epoch in range(args.nb_epochs):
     start_time = time.perf_counter()
 
     for gpu, model in zip(gpus, weakest_models):
-        log_string(f"training model {model.id}")
+        log_string(f"training model {model.id} (accuracy {model.test_accuracy})")
 
         t = threading.Thread(
-            target=one_ae_epoch, daemon=True, args=(model, quiz_machine, n_epoch, gpu)
+            target=one_ae_epoch,
+            daemon=True,
+            args=(model, models, quiz_machine, n_epoch, gpu),
         )
 
         threads.append(t)
 
         t.start()
 
-        for t in threads:
-            t.join()
+    for t in threads:
+        t.join()
 
     # --------------------------------------------------------------------
 
@@ -1231,476 +1229,5 @@ for n_epoch in range(args.nb_epochs):
             },
             os.path.join(args.result_dir, filename),
         )
-        log_string(f"wrote {filename}")
-
-######################################################################
-
-
-def create_models():
-    models = []
-
-    def compute_causal_attzero(t_q, t_k):
-        return t_q < t_k
-
-    for k in range(args.nb_models):
-        log_string(f"creating model {k}")
-
-        model = mygpt.MyGPT(
-            vocabulary_size=vocabulary_size,
-            dim_model=args.dim_model,
-            dim_keys=args.dim_keys,
-            dim_hidden=args.dim_hidden,
-            nb_heads=args.nb_heads,
-            nb_blocks=args.nb_blocks,
-            compute_attzero=compute_causal_attzero,
-            dropout=args.dropout,
-        ).to(main_device)
-
-        class UpperBoundStd(nn.Module):
-            def __init__(self, std_max=1.0):
-                super().__init__()
-                self.std_max = std_max
-
-            def forward(self, x):
-                std = x.std(dim=-1, keepdim=True)
-                y = (x - x.mean(dim=-1, keepdim=True)) / std.clamp(max=self.std_max)
-                return y
-
-        if args.logit_std_max > 0:
-            model.readout.f = nn.Sequential(
-                model.readout.f, UpperBoundStd(std_max=args.logit_std_max)
-            )
-
-        model.id = k
-        model.train_c_quiz_bags = []
-        model.test_c_quiz_bags = []
-
-        model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
-
-        model.test_accuracy = 0.0
-        model.gen_test_accuracy = 0.0
-        model.gen_state_dict = copy.deepcopy(model.state_dict())
-        models.append(model)
-
-    return models
-
-
-common_c_quiz_bags = []
-
-models = create_models()
-
-######################################################################
-
-# We balance the computing time between training the models and
-# generating c_quizzes
-
-total_time_generating_c_quizzes = 0
-total_time_training_models = 0
-
-current_epoch = 0
-
-if args.resume:
-    for model in models:
-        filename = f"gpt_{model.id:03d}.pth"
-
-        try:
-            d = torch.load(os.path.join(args.result_dir, filename))
-            model.load_state_dict(d["state_dict"])
-            model.optimizer.load_state_dict(d["optimizer_state_dict"])
-            model.test_accuracy = d["test_accuracy"]
-            model.gen_test_accuracy = d["gen_test_accuracy"]
-            model.gen_state_dict = d["gen_state_dict"]
-            model.train_c_quiz_bags = d["train_c_quiz_bags"]
-            model.test_c_quiz_bags = d["test_c_quiz_bags"]
-            log_string(f"successfully loaded {filename}")
-        except FileNotFoundError:
-            log_string(f"cannot find {filename}")
-            pass
-
-    try:
-        filename = "state.pth"
-        state = torch.load(os.path.join(args.result_dir, filename))
-        log_string(f"successfully loaded {filename}")
-        current_epoch = state["current_epoch"]
-        total_time_generating_c_quizzes = state["total_time_generating_c_quizzes"]
-        total_time_training_models = state["total_time_training_models"]
-        common_c_quiz_bags = state["common_c_quiz_bags"]
-    except FileNotFoundError:
-        log_string(f"cannot find {filename}")
-        pass
-
-######################################################################
-
-nb_parameters = sum(p.numel() for p in models[0].parameters())
-log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
-
-######################################################################
-
-if args.nb_new_c_quizzes_for_train is None:
-    args.nb_new_c_quizzes_for_train = args.nb_train_samples // 50
-
-if args.nb_new_c_quizzes_for_test is None:
-    args.nb_new_c_quizzes_for_test = args.nb_test_samples // 50
-
-log_string(
-    f"nb_new_c_quizzes_for_train {args.nb_new_c_quizzes_for_train} nb_new_c_quizzes_for_test {args.nb_new_c_quizzes_for_test}"
-)
-
-######################################################################
-
-if args.dirty_debug:
-    args.accuracy_to_make_c_quizzes = 0.0
-    args.nb_models = 2
-    args.nb_new_c_quizzes_for_train = 100
-    args.nb_new_c_quizzes_for_test = 10
-
-######################################################################
-
-
-class Folder(nn.Module):
-    def forward(self, x):
-        return x.mean(dim=1)
-
-
-class Unfolder(nn.Module):
-    def __init__(self, T, dim):
-        super().__init__()
-        self.biases = nn.Parameter(torch.randn(T, dim))
-
-    def forward(self, x):
-        return x[:, None, :] + self.biases[None, :, :]
-
-
-class Recorder(nn.Module):
-    def __init__(self, tape):
-        super().__init__()
-        self.tape = tape
-
-    def forward(self, input):
-        self.tape.append(input)
-        return input
-
-
-if args.test == "mlp":
-    model = models[0]
-    tape_input, tape_output = [], []
-    L = len(model.trunk)
-    model.trunk.insert(L // 2 + 1, Recorder(tape_output))
-    model.trunk.insert(L // 2, Recorder(tape_input))
-
-    mlp = nn.Sequential(
-        nn.Linear(args.dim_model, args.dim_model),
-        nn.ReLU(),
-        nn.Linear(args.dim_model, args.dim_model),
-        nn.ReLU(),
-        nn.Linear(args.dim_model, 8 * args.dim_model),
-        Folder(),
-        Unfolder(404, 8 * args.dim_model),
-        nn.Linear(8 * args.dim_model, args.dim_model),
-        nn.ReLU(),
-        nn.Linear(args.dim_model, args.dim_model),
-        nn.ReLU(),
-        nn.Linear(args.dim_model, args.dim_model),
-    ).to(main_device)
-
-    mlp.optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-4)
-
-    for n_epoch in range(args.nb_epochs):
-        train_input = quiz_machine.problem.generate_w_quizzes(args.nb_train_samples)
-
-        tape_input.clear()
-        tape_output.clear()
-
-        with torch.autograd.no_grad():
-            model.to(main_device).eval()
-            for input in train_input.split(args.batch_size):
-                input = input.to(main_device)
-                output = model(mygpt.BracketedSequence(input)).x
-
-        train_input = torch.cat([bs.x for bs in tape_input], dim=0)
-        train_targets = torch.cat([bs.x for bs in tape_output], dim=0)
-
-        nb_train_samples, acc_train_loss = 0, 0.0
-        src = zip(
-            train_input.split(args.batch_size), train_targets.split(args.batch_size)
-        )
-        for input, targets in tqdm.tqdm(
-            src,
-            dynamic_ncols=True,
-            desc="train",
-            total=train_input.size(0) // args.batch_size,
-        ):
-            input = input.to(main_device)
-            output = mlp(input)
-            loss = F.mse_loss(output, targets) + output.abs().sum()
-            acc_train_loss += loss.item() * input.size(0)
-            nb_train_samples += input.size(0)
-
-            mlp.optimizer.zero_grad()
-            loss.backward()
-            mlp.optimizer.step()
-
-        log_string(f"mlp_loss {n_epoch} train {acc_train_loss/nb_train_samples}")
-
-    exit(0)
-
-######################################################################
-
-
-def save_generated_c_quizzes(model, filename, nb=64):
-    while sum([x.size(0) for x in record]) < nb:
-        model = models[torch.randint(len(models), (1,)).item()]
-        c_quizzes = quiz_machine.generate_c_quizzes(
-            64,
-            model_for_generation=model,
-            procedure=c_quizzes_procedure,
-        )
-
-        p = quiz_machine.models_logprobas(
-            model,
-            c_quizzes,
-            ("A", "f_A", "B", "f_B"),
-            (1, 1, 1, 1),
-            temperature=1,
-        ).exp()
-
-        p_hot = quiz_machine.models_logprobas(
-            model,
-            c_quizzes,
-            ("A", "f_A", "B", "f_B"),
-            (1, 1, 1, 1),
-            temperature=args.temperature_hot,
-        ).exp()
-
-        to_keep = p_hot * torch.rand(p_hot.size(), device=p_hot.device) >= p
-        record.append(c_quizzes[to_keep])
-
-        print("NB_KEPT", sum([x.size(0) for x in record]))
-
-    quiz_machine.problem.save_quizzes_as_image(
-        args.result_dir,
-        filename,
-        quizzes=c_quizzes,
-    )
-
-    log_string(f"wrote {filename}")
-
-
-######################################################################
-
-
-if args.test == "entropy":
-    model = models[0]
-    model.to(main_device)
-
-    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
-
-    log_string("starting testing entropy maximization")
-
-    for n_epoch in range(100):
-        input = quiz_machine.generate_c_quizzes(
-            128,
-            model_for_generation=model,
-            procedure=c_quizzes_procedure,
-        )
-
-        filename = f"test_{n_epoch:04d}.png"
-
-        quiz_machine.problem.save_quizzes_as_image(
-            args.result_dir,
-            filename,
-            quizzes=input,
-        )
-
-        log_string(f"wrote {filename}")
-
-        with torch.no_grad():
-            for p in model.parameters():
-                p += torch.randn(p.size(), device=p.device) * 1e-3
-
-        # nb_train_samples, acc_train_loss = 0, 0.0
-
-        # for k in range(1000 // args.batch_size):
-        # input = quiz_machine.generate_c_quizzes(
-        # args.batch_size,
-        # model_for_generation=model,
-        # procedure=[(("f_B", "f_A", "A", "B"), (1, 1, 1, 1), None)],
-        # )
-
-        # input = input.to(main_device)
-        # targets = input
-        # output = model(mygpt.BracketedSequence(input)).x
-        # loss = -F.cross_entropy(output.transpose(1, 2), targets)
-        # acc_train_loss += loss.item() * input.size(0)
-        # nb_train_samples += input.size(0)
-
-        # optimizer.zero_grad()
-        # loss.backward()
-        # optimizer.step()
-
-        # log_string(
-        # f"increase_entropy {n_epoch} entropy {acc_train_loss/nb_train_samples}"
-        # )
-
-    exit(0)
-
-######################################################################
-
-for n_epoch in range(current_epoch, args.nb_epochs):
-    state = {
-        "current_epoch": n_epoch,
-        "total_time_generating_c_quizzes": total_time_generating_c_quizzes,
-        "total_time_training_models": total_time_training_models,
-        "common_c_quiz_bags": common_c_quiz_bags,
-    }
-    filename = "state.pth"
-    torch.save(state, os.path.join(args.result_dir, filename))
-    log_string(f"wrote {filename}")
-
-    log_string(f"--- epoch {n_epoch} ----------------------------------------")
-
-    cta = " ".join([f"{float(m.test_accuracy):.04f}" for m in models])
-    log_string(f"current_test_accuracies {cta}")
-
-    cta = " ".join([f"{float(m.gen_test_accuracy):.04f}" for m in models])
-    log_string(f"current_gen_test_accuracies {cta}")
-
-    ##################################################
-
-    for model in models:
-        if model.test_accuracy >= args.accuracy_to_make_c_quizzes:
-            log_string(
-                f"storing_gen model {model.id} accuracy {model.gen_test_accuracy} -> {model.test_accuracy}"
-            )
-            model.gen_state_dict = copy.deepcopy(model.state_dict())
-            model.gen_test_accuracy = model.test_accuracy
-
-    # we restart
-    if total_time_generating_c_quizzes == 0:
-        total_time_training_models = 0
-
-    if min([m.gen_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes:
-        if args.reboot:
-            for model in models:
-                model.current_dict = copy.deepcopy(model.state_dict())
-                model.load_state_dict(model.gen_state_dict)
-
-            while True:
-                record_new_c_quizzes(
-                    models,
-                    quiz_machine,
-                    args.nb_new_c_quizzes_for_train,
-                    args.nb_new_c_quizzes_for_test,
-                )
-
-                nb_c_quizzes_per_model = [
-                    sum([x.size(0) for x in model.train_c_quiz_bags])
-                    for model in models
-                ]
-
-                p = tuple(
-                    f"{(x*100)/args.nb_train_samples:.02f}%"
-                    for x in nb_c_quizzes_per_model
-                )
-
-                log_string(f"nb_c_quizzes_per_model {p}")
-
-                m = max(nb_c_quizzes_per_model)
-
-                if m * args.c_quiz_multiplier >= args.nb_train_samples:
-                    break
-
-            model = models[nb_c_quizzes_per_model.index(m)]
-            common_c_quiz_bags.append(torch.cat(model.train_c_quiz_bags, dim=0))
-            nb_common_c_quizzes = sum([x.size(0) for x in common_c_quiz_bags])
-            log_string(
-                f"rebooting the models with {nb_common_c_quizzes} culture quizzes"
-            )
-
-            models = create_models()
-            total_time_generating_c_quizzes = 0
-            total_time_training_models = 0
-
-        elif total_time_training_models >= total_time_generating_c_quizzes:
-            for model in models:
-                model.current_dict = copy.deepcopy(model.state_dict())
-                model.load_state_dict(model.gen_state_dict)
-
-            start_time = time.perf_counter()
-
-            record_new_c_quizzes(
-                models,
-                quiz_machine,
-                args.nb_new_c_quizzes_for_train,
-                args.nb_new_c_quizzes_for_test,
-            )
-
-            total_time_generating_c_quizzes += time.perf_counter() - start_time
-
-            for model in models:
-                model.load_state_dict(model.current_dict)
-
-    ##################################################
-    # Select, improve, and eval the worst model(s)
-
-    if total_time_training_models <= total_time_generating_c_quizzes:
-        ranked_models = sorted(
-            models,
-            # This ugly recipe will pick the worst if there some below
-            # args.accuracy_to_make_c_quizzes or one at random if they
-            # are all above
-            key=lambda m: float(
-                m.test_accuracy
-                if m.test_accuracy < args.accuracy_to_make_c_quizzes
-                else args.accuracy_to_make_c_quizzes + torch.rand(1).item()
-            ),
-        )
-
-        weakest_models = ranked_models[: len(gpus)]
-
-        threads = []
-
-        start_time = time.perf_counter()
-
-        for gpu, model in zip(gpus, weakest_models):
-            log_string(f"training model {model.id}")
-
-            t = threading.Thread(
-                target=one_epoch, daemon=True, args=(model, quiz_machine, gpu)
-            )
-
-            threads.append(t)
-
-            t.start()
 
-        for t in threads:
-            t.join()
-
-        total_time_training_models += time.perf_counter() - start_time
-
-    # Save the models to disk
-
-    for model in models:
-        filename = f"gpt_{model.id:03d}.pth"
-        torch.save(
-            {
-                "state_dict": model.state_dict(),
-                "optimizer_state_dict": model.optimizer.state_dict(),
-                "test_accuracy": model.test_accuracy,
-                "gen_test_accuracy": model.gen_test_accuracy,
-                "gen_state_dict": model.gen_state_dict,
-                "train_c_quiz_bags": model.train_c_quiz_bags,
-                "test_c_quiz_bags": model.test_c_quiz_bags,
-            },
-            os.path.join(args.result_dir, filename),
-        )
         log_string(f"wrote {filename}")
-
-    ######################################################################
-
-    if args.log_command is not None:
-        s = args.log_command.split()
-        s.insert(1, args.result_dir)
-        subprocess.run(s)
-
-######################################################################