Update.
authorFrançois Fleuret <francois@fleuret.org>
Mon, 12 Aug 2024 22:03:42 +0000 (00:03 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Mon, 12 Aug 2024 22:03:42 +0000 (00:03 +0200)
main.py
quiz_machine.py

diff --git a/main.py b/main.py
index e516a77..0a79323 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -391,7 +391,7 @@ def run_tests(model, quiz_machine, local_device=main_device):
         nb_samples_accumulated = 0
 
         full_input, full_mask_loss = quiz_machine.data_input(
-            model, args.nb_test_samples
+            args.nb_test_samples, model.test_c_quiz_bags
         )
         src = zip(
             full_input.split(args.batch_size), full_mask_loss.split(args.batch_size)
@@ -441,7 +441,9 @@ def one_epoch(model, quiz_machine, local_device=main_device):
 
     hard_w_quizzes = []
 
-    full_input, full_mask_loss = quiz_machine.data_input(model, args.nb_train_samples)
+    full_input, full_mask_loss = quiz_machine.data_input(
+        args.nb_train_samples, model.train_c_quiz_bags
+    )
     src = zip(full_input.split(args.batch_size), full_mask_loss.split(args.batch_size))
 
     for input, mask_loss in tqdm.tqdm(
@@ -528,17 +530,23 @@ def save_additional_results(model, models, science_w_quizzes):
 
     # This is nb_quizzes x nb_models
 
-    seq_logproba = quiz_machine.models_logprobas(
-        models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
-    ) + quiz_machine.models_logprobas(
-        models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
-    )
+    l = [
+        quiz_machine.models_logprobas(
+            model, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
+        )
+        + quiz_machine.models_logprobas(
+            model, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
+        )
+        for model in models
+    ]
+
+    seq_logprobas = torch.cat([x[None, :] for x in l])
 
-    probas = seq_logproba.exp()
+    probas = seq_logprobas.exp()
 
     comments = []
 
-    for l in seq_logproba:
+    for l in seq_logprobas:
         comments.append("proba " + " ".join([f"{x.exp().item():.02f}" for x in l]))
 
     ##
@@ -616,18 +624,26 @@ def save_additional_results(model, models, science_w_quizzes):
 ######################################################################
 
 
-def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100):
-    nb_to_validate = nb_for_train + nb_for_test
-    nb_to_generate_per_iteration = max(args.physical_batch_size, nb_to_validate)
-    nb_validated = 0
+def model_proba_solutions(m, quizzes):
+    l = quiz_machine.models_logprobas(
+        m, quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
+    ) + quiz_machine.models_logprobas(
+        m, quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
+    )
+
+    return l.exp()
+
 
-    recorded_validated = []
+def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test):
+    nb_validated, nb_to_validate = 0, nb_for_train + nb_for_test
+    nb_to_generate_per_iteration = nb_to_validate
 
     start_time = time.perf_counter()
 
-    nb_validated_per_model = torch.zeros(len(models), dtype=torch.int64)
+    for model in models:
+        model.recorded_c_quizzes = []
 
-    while nb_validated_per_model.sum() < nb_to_validate:
+    while nb_validated < nb_to_validate:
         model_for_generation = models[torch.randint(len(models), (1,)).item()]
 
         # We generate quizzes with a procedure that injects some
@@ -646,80 +662,48 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
 
         c_quizzes = c_quizzes[to_keep]
 
-        # This is nb_quizzes x nb_models
+        # Compute the responses of all the models on the c_quizzes,
+        # and their proba estimates of their responses
 
         solved_c_quizzes = c_quizzes[:, None, :].expand(-1, len(models), -1).clone()
 
-        seq_logproba = torch.zeros(
+        proba_own_solution = torch.zeros(
             c_quizzes.size(0), len(models), device=solved_c_quizzes.device
         )
 
-        for m in models:
-            (
-                solved_c_quizzes[:, m.id],
-                _,
-                seq_logproba[:, m.id],
-            ) = quiz_machine.predict(
-                m,
-                solved_c_quizzes[:, m.id],
+        for model in models:
+            (solved_c_quizzes[:, model.id], _, _) = quiz_machine.predict(
+                model,
+                solved_c_quizzes[:, model.id],
                 struct=("A", "f_A", "B", "f_B"),
                 mask=(0, 0, 0, 1),
             )
 
-        #!!!!!!!!!!!!!!!!!!!!
-        for m in range(seq_logproba.size(1)):
-            l = quiz_machine.models_logprobas(
-                [models[m]],
-                solved_c_quizzes[:, m, :],
-                ("A", "f_A", "B", "f_B"),
-                (0, 0, 0, 1),
-                (0, 0, 0, 0),
-            )
-            for s in range(seq_logproba.size(0)):
-                print("DEBUG", seq_logproba[s, m].item(), l[s, 0].item())
-        exit(0)
-        #!!!!!!!!!!!!!!!!!!!!!!!!!
+            u = model_proba_solutions(model, solved_c_quizzes[:, model.id])
 
-        # FINISH
+            proba_own_solution[:, model.id] = u
 
-        seq_logproba = quiz_machine.models_logprobas(
-            models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
-        ) + quiz_machine.models_logprobas(
-            models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
-        )
-
-        probas = seq_logproba.exp()
-
-        nb_succeed = (probas >= args.proba_understands).long().sum(dim=1)
-        nb_fail = (probas <= args.proba_not_understands).long().sum(dim=1)
-
-        to_keep = (
-            # (nb_succeed + nb_fail == probas.size(1))
-            (nb_succeed >= args.min_succeed_to_validate)
-            & (nb_fail >= 1)
-            & (nb_fail <= args.max_fail_to_validate)
-        )
-
-        c_quizzes = c_quizzes[to_keep]
-
-        if c_quizzes.size(0) > 0:
-            nb_validated_per_model[model_for_generation.id] += c_quizzes.size(0)
-            recorded_validated.append(c_quizzes)
-            nb_validated = c_quizzes.size(0)
-        else:
-            nb_validated = 0
+        # Now for every model not confident of its response, we pick
+        # the most consistent from a model which is confident
 
-        total_nb_validated = nb_validated_per_model.sum().item()
+        for s in range(proba_own_solution.size(0)):
+            dont_get_it = proba_own_solution[s, :] < args.proba_understands
+            if not dont_get_it.all():
+                for model in models:
+                    if dont_get_it[model.id]:
+                        proba_other_solutions = model_proba_solutions(
+                            model, solved_c_quizzes[s]
+                        )
+                        proba_other_solutions[dont_get_it] = -1
+                        i = proba_other_solutions.argmax()
+                        model.recorded_c_quizzes.append(solved_c_quizzes[s, i])
+                        nb_validated += 1
 
         duration = time.perf_counter() - start_time
 
-        if total_nb_validated > 0:
-            if total_nb_validated < nb_to_validate:
-                d = (
-                    (nb_to_validate - total_nb_validated)
-                    * duration
-                    / total_nb_validated
-                )
+        if nb_validated > 0:
+            if nb_validated < nb_to_validate:
+                d = (nb_to_validate - nb_validated) * duration / nb_validated
                 e = (datetime.datetime.now() + datetime.timedelta(seconds=d)).strftime(
                     "%a %H:%M"
                 )
@@ -729,320 +713,44 @@ def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=10
             e = "???"
 
         log_string(
-            f"keep c_quizzes model {model_for_generation.id} validated {nb_validated} / {nb_to_generate_per_iteration} ({100*nb_validated/nb_to_generate_per_iteration:.02f}%) nb_accumulated {total_nb_validated} / {nb_to_validate} (finishes {e} -- {int((total_nb_validated * 3600)/duration)}/h)"
-        )
-
-    validated_quizzes = torch.cat(recorded_validated, dim=0)
-
-    ######################################################################
-    # store the new c_quizzes which have been validated
-
-    v_train = validated_quizzes[:nb_for_train]
-    quiz_machine.store_c_quizzes(v_train, for_train=True)
-
-    v_test = validated_quizzes[nb_for_train:nb_to_validate]
-    quiz_machine.store_c_quizzes(v_test, for_train=False)
-
-    ######################################################################
-    # save images
-
-    vq = validated_quizzes[torch.randperm(validated_quizzes.size(0))[:128]]
-
-    if vq.size(0) > 0:
-        seq_logproba = quiz_machine.models_logprobas(
-            models, vq, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
-        ) + quiz_machine.models_logprobas(
-            models, vq, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
-        )
-
-        probas = seq_logproba.exp()
-
-        comments = []
-
-        for l in seq_logproba:
-            comments.append("proba " + " ".join([f"{x.exp().item():.02f}" for x in l]))
-
-        filename = f"culture_c_quiz_{n_epoch:04d}.png"
-        quiz_machine.problem.save_quizzes_as_image(
-            args.result_dir, filename, vq, comments=comments
-        )
-
-
-######################################################################
-
-# The generator is very similar to a "solving GPT" except that it
-# deals with quizzes prologued with one token per solving GPT that
-# indicates if the said model solves it or not.
-#
-# There are three levels of solving 0->proba<=proba_not_understands,
-# 2->proba>=proba_understands and 1 otherwise.
-
-
-def generate_c_quizzes_with_generator(generator, quiz_machine, nb):
-    generator.to(main_device)
-
-    struct = ("A", "f_A", "B", "f_B")
-
-    c_quizzes = quiz_machine.problem.create_empty_quizzes(nb, struct=struct)
-    ar_mask = quiz_machine.make_quiz_mask(c_quizzes, struct, (1, 1, 1, 1))
-
-    i = F.one_hot(
-        torch.randint(args.nb_gpts, (c_quizzes.size(0),)),
-        num_classes=args.nb_gpts,
-    )
-
-    prologs_c_quizzes = token_prolog_0 * i + token_prolog_2 * (1 - i)
-    prologs_ar_mask = ar_mask.new_zeros(ar_mask.size(0), prologs_c_quizzes.size(1))
-
-    prologued_c_quizzes = torch.cat([prologs_c_quizzes, c_quizzes], dim=1).to(
-        main_device
-    )
-    prologued_ar_mask = torch.cat([prologs_ar_mask, ar_mask], dim=1).to(main_device)
-
-    seq_logproba = torch.zeros(
-        prologued_c_quizzes.size(0), device=prologued_c_quizzes.device
-    )
-
-    generator.temperature = args.temperature_hot
-
-    with torch.autograd.no_grad():
-        t = generator.training
-        generator.eval()
-
-        one_batch_masked_inplace_autoregression(
-            generator,
-            prologued_c_quizzes,
-            prologued_ar_mask,
-            seq_logproba,
-            deterministic_synthesis=False,
+            f"keep c_quizzes model {model_for_generation.id} validated {nb_validated} / {nb_to_generate_per_iteration} ({100*nb_validated/nb_to_generate_per_iteration:.02f}%) nb_accumulated {nb_validated} / {nb_to_validate} (finishes {e} -- {int((nb_validated * 3600)/duration)}/h)"
         )
 
-        generator.train(t)
-
-    generator.reset_transformations()
-
-    prologued_c_quizzes = (
-        prologued_c_quizzes * (prologued_c_quizzes < vocabulary_size).long()
-    )
-
-    c_quizzes = prologued_c_quizzes[:, prologs_c_quizzes.size(1) :]
-
-    return c_quizzes.to("cpu"), prologs_c_quizzes.to("cpu")
-
-
-def batches_for_generator(generator, quiz_machine, models, fraction_w_quizzes=1.0):
-    samples = []
-
-    for _ in range(args.nb_train_samples // args.batch_size):
-        while sum([x.size(0) for x in samples]) < args.batch_size:
-            # Generate a bunch of quizzes
-
-            if torch.rand(1).item() <= fraction_w_quizzes:
-                # Either we start with the world quizzes
-                c_quizzes = quiz_machine.problem.generate_w_quizzes(
-                    args.batch_size, progress_bar=False
-                )
-            else:
-                # Or we use the generator itself to generate them
-                c_quizzes, _ = generate_c_quizzes_with_generator(
-                    generator, quiz_machine, args.batch_size
-                )
-
-            # We remove the trivial ones
-            to_keep = quiz_machine.problem.trivial(c_quizzes) == False
-            c_quizzes = c_quizzes[to_keep]
-
-            # If there are remaining ones, we compute the true prolog
-            # that indicates how the GPTs solve it
-
-            if c_quizzes.size(0) > 0:
-                seq_logproba = quiz_machine.models_logprobas(
-                    models,
-                    c_quizzes,
-                    ("A", "f_A", "B", "f_B"),
-                    (0, 0, 0, 1),
-                    (0, 0, 1, 0),
-                ) + quiz_machine.models_logprobas(
-                    models,
-                    c_quizzes,
-                    ("f_A", "A", "f_B", "B"),
-                    (0, 0, 0, 1),
-                    (0, 0, 1, 0),
-                )
-
-                probas = seq_logproba.exp()
-
-                u0 = probas <= args.proba_not_understands
-                u2 = probas >= args.proba_understands
-                u1 = (u0 | u2) == False
-
-                prologs = (
-                    (u0.long() * token_prolog_0)
-                    + (u1.long() * token_prolog_1)
-                    + (u2.long() * token_prolog_2)
-                )
-
-                prologued_c_quizzes = torch.cat([prologs, c_quizzes], dim=1)
-
-                # nb_u2 = u2.long().sum(dim=1)
-                # nb_u0 = u0.long().sum(dim=1)
-                # prologued_c_quizzes = prologued_c_quizzes[(nb_u2 >= 1) & (nb_u0 >= 1)]
-
-                if prologued_c_quizzes.size(0) > 0:
-                    samples.append(prologued_c_quizzes)
-
-        # Now we yield a batch
-
-        x = torch.cat(samples, dim=0)
-        samples = [x[args.batch_size :]]
-
-        yield x[: args.batch_size]
-
-
-def one_generator_epoch(
-    generator, quiz_machine, models, fraction_w_quizzes, local_device=main_device
-):
-    model.to(local_device).train()
-
-    optimizer = torch.optim.Adam(generator.parameters(), lr=args.learning_rate)
-
-    nb_train_samples, acc_train_loss = 0, 0.0
-
-    src = batches_for_generator(
-        generator=generator,
-        quiz_machine=quiz_machine,
-        models=models,
-        fraction_w_quizzes=fraction_w_quizzes,
-    )
-
-    for input in tqdm.tqdm(
-        src,
-        dynamic_ncols=True,
-        desc="training",
-        total=args.nb_train_samples // args.batch_size,
-    ):
-        input = input.to(local_device)
-
-        if nb_train_samples % args.batch_size == 0:
-            optimizer.zero_grad()
-
-        targets = input
-
-        output = generator(mygpt.BracketedSequence(input)).x
-        loss = F.cross_entropy(output.transpose(1, 2), targets)
-        acc_train_loss += loss.item() * input.size(0)
-        nb_train_samples += input.size(0)
-
-        loss.backward()
-
-        if nb_train_samples % args.batch_size == 0:
-            optimizer.step()
-
-    train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
-
-    log_string(f"train_perplexity {n_epoch} generator - {train_perplexity}")
-
-    generator.to(main_device)
-
-
-######################################################################
-
-
-def train_complexifier(model_gen, model_pred1, model_pred2):
-    samples = []
-    perf = []
+    for model in models:
+        new_bag = torch.cat([q[None, :] for q in model.recorded_c_quizzes], dim=0)
 
-    optimizer = torch.optim.Adam(model_gen.parameters(), lr=args.learning_rate)
+        if new_bag.size(0) > 0:
+            n = (new_bag.size(0) * nb_for_train) // (nb_for_train + nb_for_test)
+            if n > 0:
+                model.train_c_quiz_bags.append(new_bag[:n])
+            if n < new_bag.size(0):
+                model.test_c_quiz_bags.append(new_bag[:n])
 
-    nb_train_samples, acc_train_loss = 0, 0.0
+            vq = new_bag[:128]
 
-    for n_epoch in range(args.nb_epochs):
-        for b in range(args.nb_train_samples // args.batch_size):
-            while sum([x.size(0) for x in samples]) < args.batch_size:
-                c_quizzes = quiz_machine.generate_c_quizzes(
-                    args.inference_batch_size,
-                    model_for_generation=model_gen,
-                    procedure=c_quizzes_procedure,
-                )
-                to_keep = quiz_machine.problem.trivial(c_quizzes) == False
-                c_quizzes = c_quizzes[to_keep]
-                if c_quizzes.size(0) > 0:
-                    seq_logproba = quiz_machine.models_logprobas(
-                        [model_pred1, model_pred2],
-                        c_quizzes,
-                        ("A", "f_A", "B", "f_B"),
-                        (0, 0, 0, 1),
-                    ) + quiz_machine.models_logprobas(
-                        [model_pred1, model_pred2],
-                        c_quizzes,
-                        ("f_A", "A", "f_B", "B"),
-                        (0, 0, 0, 1),
-                    )
-                    probas = seq_logproba.exp()
-                    to_keep = (probas[:, model_pred1.id] >= args.proba_understands) & (
-                        probas[:, model_pred2.id] <= args.proba_not_understands
-                    )
-                    log_string(
-                        f"generating {to_keep.long().sum()} / {c_quizzes.size(0)}"
-                    )
-                    c_quizzes = c_quizzes[to_keep]
-                    if c_quizzes.size(0):
-                        samples.append(c_quizzes)
-
-            log_string(f"full batch {sum([x.size(0) for x in samples])}")
-
-            x = torch.cat(samples, dim=0)
-
-            input = x[: args.batch_size]
-            samples = [x[args.batch_size :]]
-
-            # -------------------
-
-            seq_logproba = quiz_machine.models_logprobas(
-                [model_pred1, model_pred2],
-                input,
-                ("A", "f_A", "B", "f_B"),
-                (0, 0, 0, 1),
+            seq_logprobas = quiz_machine.models_logprobas(
+                models, vq, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
             ) + quiz_machine.models_logprobas(
-                [model_pred1, model_pred2],
-                input,
-                ("f_A", "A", "f_B", "B"),
-                (0, 0, 0, 1),
+                models, vq, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
             )
 
+            probas = seq_logprobas.exp()
+
             comments = []
 
-            for l in seq_logproba:
+            for l in seq_logprobas:
                 comments.append(
-                    f"proba {l[model_pred1.id].exp().item():.02f} {l[model_pred2.id].exp().item():.02f}"
+                    "proba " + " ".join([f"{x.exp().item():.02f}" for x in l])
                 )
 
-            filename = f"batch_{n_epoch:04d}_{b:04d}.png"
+            filename = f"culture_c_quiz_{n_epoch:04d}.png"
             quiz_machine.problem.save_quizzes_as_image(
-                args.result_dir, filename, input, comments=comments
+                args.result_dir, filename, vq, comments=comments
             )
-            log_string(f"wrote {filename}")
-
-            # ------------------------
-
-            input = input.to(main_device)
-
-            if nb_train_samples % args.batch_size == 0:
-                optimizer.zero_grad()
-
-            output = model_gen(mygpt.BracketedSequence(input)).x
-            loss = F.cross_entropy(output.transpose(1, 2), input)
-            acc_train_loss += loss.item() * input.size(0)
-            nb_train_samples += input.size(0)
-
-            loss.backward()
 
-            if nb_train_samples % args.batch_size == 0:
-                optimizer.step()
-
-        train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
-
-        log_string(f"train_perplexity {n_epoch} model ae {train_perplexity}")
+        log_string(
+            f"nb_c_quizzes model {model.id} train {sum([q.size(0) for q in model.train_c_quiz_bags ])} test {sum([q.size(0) for q in model.test_c_quiz_bags ])}"
+        )
 
 
 ######################################################################
@@ -1072,7 +780,8 @@ for k in range(args.nb_gpts):
     ).to(main_device)
 
     model.id = k
-    model.c_quiz_bags = []
+    model.train_c_quiz_bags = []
+    model.test_c_quiz_bags = []
 
     if args.schedule_free:
         model.optimizer = schedulefree.AdamWScheduleFree(
@@ -1087,29 +796,6 @@ for k in range(args.nb_gpts):
 
 ######################################################################
 
-if args.test == "quant":
-    nb_bits = 8
-    for model in models:
-        model.trunk.insert(
-            12,
-            mygpt.CacheWrapper(
-                mygpt.RandomBypass(
-                    nn.Sequential(
-                        nn.Linear(args.dim_model, nb_bits),
-                        mygpt.BSQ(nb_bits),
-                        nn.Linear(nb_bits, args.dim_model),
-                    ),
-                    0.1,
-                )
-            ),
-        )
-
-        print(model)
-        exit(0)
-
-
-######################################################################
-
 current_epoch = 0
 
 if args.resume:
@@ -1170,153 +856,6 @@ if args.dirty_debug:
 
 ######################################################################
 
-if args.test == "tsne":
-    model = models[0]
-
-    quizzes = []
-    labels = []
-    nb_samples_per_task = 1000
-
-    for n, t in enumerate(args.grids_world_tasks.split(",")):
-        quizzes.append(
-            quiz_machine.problem.generate_w_quizzes(nb_samples_per_task, [t])
-        )
-        labels.append(torch.full((quizzes[-1].size(0),), n))
-
-    quizzes = torch.cat(quizzes, dim=0)
-    labels = torch.cat(labels, dim=0)
-
-    with torch.autograd.no_grad():
-        model.eval().to(main_device)
-        record = []
-        for input, targets in zip(
-            quizzes.split(args.batch_size), labels.split(args.batch_size)
-        ):
-            input = input.to(main_device)
-            bs = mygpt.BracketedSequence(input)
-            bs = mygpt.BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
-            bs = model.embedding(bs)
-            bs = model.trunk[args.nb_blocks // 2](bs)
-            record.append((bs.x.to("cpu"), targets))
-
-    x = torch.cat([x for x, y in record], dim=0).flatten(1)
-    y = torch.cat([y for x, y in record], dim=0)
-
-    print(f"{x.size()=} {y.size()=}")
-    # torch.save((x,y), "/tmp/embed.pth")
-    # exit(0)
-
-    from sklearn.manifold import TSNE
-
-    x_np = x.numpy()
-    z_np = TSNE(n_components=2, perplexity=50).fit_transform(x_np)
-    z = torch.from_numpy(z_np)
-
-    print(f"{z.size()=}")
-
-    with open("/tmp/result.dat", "w") as f:
-        for k in range(z.size(0)):
-            f.write(f"{y[k]} {z[k,0]} {z[k,1]}\n")
-
-    exit(0)
-
-######################################################################
-
-if args.test == "generator":
-    token_prolog_0 = vocabulary_size + 0
-    token_prolog_1 = vocabulary_size + 1
-    token_prolog_2 = vocabulary_size + 2
-    generator_vocabulary_size = vocabulary_size + 3
-
-    generator = mygpt.MyGPT(
-        vocabulary_size=generator_vocabulary_size,
-        dim_model=args.dim_model,
-        dim_keys=args.dim_keys,
-        dim_hidden=args.dim_hidden,
-        nb_heads=args.nb_heads,
-        nb_blocks=args.nb_blocks,
-        compute_attzero=compute_causal_attzero,
-        dropout=args.dropout,
-    ).to(main_device)
-
-    generator.main_test_accuracy = 0.0
-
-    filename = f"generator.pth"
-
-    try:
-        d = torch.load(os.path.join(args.result_dir, filename))
-        generator.load_state_dict(d[0])
-        generator.main_test_accuracy = d[1]
-        log_string(f"successfully loaded {filename}")
-    except FileNotFoundError:
-        log_string(f"cannot find {filename}")
-        pass
-
-    for n_epoch in range(args.nb_epochs):
-        one_generator_epoch(
-            generator,
-            quiz_machine=quiz_machine,
-            models=models,
-            fraction_w_quizzes=1 if n_epoch < 25 else 0.5,
-            local_device=main_device,
-        )
-
-        filename = f"generator.pth"
-        torch.save(
-            (generator.state_dict(), generator.main_test_accuracy),
-            os.path.join(args.result_dir, filename),
-        )
-        log_string(f"wrote {filename}")
-
-        c_quizzes, prologs = generate_c_quizzes_with_generator(
-            generator, quiz_machine, args.batch_size
-        )
-
-        seq_logproba = quiz_machine.models_logprobas(
-            models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
-        ) + quiz_machine.models_logprobas(
-            models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
-        )
-
-        probas = seq_logproba.exp()
-
-        u0 = probas <= args.proba_not_understands
-        u2 = probas >= args.proba_understands
-        u1 = (u0 | u2) == False
-
-        predicted_prologs = (
-            (u0.long() * token_prolog_0)
-            + (u1.long() * token_prolog_1)
-            + (u2.long() * token_prolog_2)
-        )
-
-        comments = []
-
-        nb_errors = (predicted_prologs != prologs).long().sum()
-        nb_total = prologs.numel()
-
-        log_string(f"generator_error {nb_errors} / {nb_total}")
-
-        def readable(prologs):
-            return (prologs == token_prolog_1) + 2 * (prologs == token_prolog_2)
-
-        for aa, ee, ff in zip(probas, readable(predicted_prologs), readable(prologs)):
-            sa = "prolog " + " ".join(
-                [f"{e.item()}/{f.item()}" for e, f in zip(ee, ff)]
-            )
-            sp = "proba " + " ".join([f"{p.item():.02f}" for p in aa])
-            comments.append(sa + "\n" + sp)
-
-        filename = f"generator_batch_{n_epoch:04d}.png"
-        quiz_machine.problem.save_quizzes_as_image(
-            args.result_dir, filename, c_quizzes, comments=comments
-        )
-        log_string(f"wrote {filename}")
-
-    exit(0)
-
-######################################################################
-
 for n_epoch in range(current_epoch, args.nb_epochs):
     state = {"current_epoch": n_epoch}
     filename = "state.pth"
@@ -1336,8 +875,8 @@ for n_epoch in range(current_epoch, args.nb_epochs):
         record_new_c_quizzes(
             models,
             quiz_machine,
-            nb_for_train=args.nb_new_c_quizzes_for_train,
-            nb_for_test=args.nb_new_c_quizzes_for_test,
+            args.nb_new_c_quizzes_for_train,
+            args.nb_new_c_quizzes_for_test,
         )
 
         filename = "c_quizzes.pth"
index b2287b8..1fe2e94 100755 (executable)
@@ -28,7 +28,7 @@ def one_batch_masked_inplace_autoregression(
     model,
     input,
     ar_mask,
-    acc_seq_logproba,
+    acc_seq_logprobas,
     deterministic_synthesis=False,
 ):
     if input.size(0) == 0:
@@ -53,7 +53,7 @@ def one_batch_masked_inplace_autoregression(
 
         all_n = torch.arange(t_next.size(0))
 
-        acc_seq_logproba += ar_mask[:, s] * logits.log_softmax(dim=1)[all_n, t_next]
+        acc_seq_logprobas += ar_mask[:, s] * logits.log_softmax(dim=1)[all_n, t_next]
 
         input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
 
@@ -107,7 +107,7 @@ class QuizMachine:
         model,
         input,
         ar_mask,
-        seq_logproba,
+        seq_logprobas,
         progress_bar_desc=None,
     ):
         assert input.size() == ar_mask.size()
@@ -115,7 +115,7 @@ class QuizMachine:
         batches = zip(
             input.split(self.batch_size),
             ar_mask.split(self.batch_size),
-            seq_logproba.split(self.batch_size),
+            seq_logprobas.split(self.batch_size),
         )
 
         if progress_bar_desc is not None:
@@ -130,12 +130,12 @@ class QuizMachine:
             t = model.training
             model.eval()
 
-            for input, ar_mask, seq_logproba in batches:
+            for input, ar_mask, seq_logprobas in batches:
                 one_batch_masked_inplace_autoregression(
                     model=model,
                     input=input,
                     ar_mask=ar_mask,
-                    acc_seq_logproba=seq_logproba,
+                    acc_seq_logprobas=seq_logprobas,
                     deterministic_synthesis=False,
                 )
 
@@ -143,9 +143,9 @@ class QuizMachine:
 
     ######################################################################
 
-    def data_input(self, model, nb_samples):
-        if len(model.c_quiz_bags) > 0:
-            c_quizzes = torch.cat(model.c_quiz_bags, dim=0)
+    def data_input(self, nb_samples, c_quiz_bags):
+        if len(c_quiz_bags) > 0:
+            c_quizzes = torch.cat(c_quiz_bags, dim=0)
 
             if c_quizzes.size(0) > nb_samples // 2:
                 i = torch.randperm(c_quizzes.size(0))[: nb_samples // 2]
@@ -191,23 +191,23 @@ class QuizMachine:
         ar_mask = self.make_quiz_mask(quizzes=quizzes, struct=struct, mask=mask)
         result = quizzes * (1 - ar_mask)
 
-        seq_logproba = torch.zeros(quizzes.size(0), device=self.device)
+        seq_logprobas = torch.zeros(quizzes.size(0), device=self.device)
 
         self.autoregression(
             model=model,
             input=result,
             ar_mask=ar_mask,
-            seq_logproba=seq_logproba,
+            seq_logprobas=seq_logprobas,
             progress_bar_desc="accuracy",
         )
 
         correct = (result == quizzes).min(dim=1).values.long()
 
-        result = result.to("cpu")
-        correct = correct.to("cpu")
-        seq_logproba = seq_logproba.to("cpu")
+        result = result.to("cpu")
+        correct = correct.to("cpu")
+        # seq_logprobas = seq_logprobas.to("cpu")
 
-        return result, correct, seq_logproba
+        return result, correct, seq_logprobas
 
     ######################################################################
 
@@ -226,6 +226,7 @@ class QuizMachine:
             result[i], correct[i], _ = self.predict(
                 model=model, quizzes=input[i], struct=struct, mask=mask_generate
             )
+
             predicted_parts[i] = torch.tensor(mask_generate, device=self.device)[
                 None, :
             ]
@@ -288,7 +289,7 @@ class QuizMachine:
 
     def models_logprobas(
         self,
-        models_for_validation,
+        model,
         c_quizzes,
         struct,
         mask_loss,
@@ -300,9 +301,8 @@ class QuizMachine:
 
         c_quizzes = self.problem.reconfigure(c_quizzes, struct)
 
-        seq_logproba = torch.zeros(
+        seq_logprobas = torch.zeros(
             c_quizzes.size(0),
-            max([m.id for m in models_for_validation]) + 1,
             device=device,
         )
 
@@ -311,35 +311,32 @@ class QuizMachine:
         # c_quizzes, self.prompt_noise, struct=struct, mask=mask_noise
         # )
 
-        for model in models_for_validation:
-            with torch.autograd.no_grad():
-                t = model.training
-                model.eval()
-
-                for input, l in zip(
-                    c_quizzes.split(self.batch_size),
-                    seq_logproba.split(self.batch_size),
-                ):
-                    input = input.to(device)
-                    quiz_mask_loss = self.make_quiz_mask(
-                        input, struct=struct, mask=mask_loss
-                    )
-                    output = model(mygpt.BracketedSequence(input)).x
-                    l[:, model.id] = (
-                        -F.cross_entropy(
-                            output.transpose(1, 2), input, reduction="none"
-                        )
-                        * quiz_mask_loss
-                    ).sum(dim=1)
+        with torch.autograd.no_grad():
+            t = model.training
+            model.eval()
+
+            for input, l in zip(
+                c_quizzes.split(self.batch_size),
+                seq_logprobas.split(self.batch_size),
+            ):
+                input = input.to(device)
+                quiz_mask_loss = self.make_quiz_mask(
+                    input, struct=struct, mask=mask_loss
+                )
+                output = model(mygpt.BracketedSequence(input)).x
+                l[...] = (
+                    -F.cross_entropy(output.transpose(1, 2), input, reduction="none")
+                    * quiz_mask_loss
+                ).sum(dim=1)
 
-                model.train(t)
+            model.train(t)
 
-        return seq_logproba.to("cpu")
+        return seq_logprobas.to("cpu")
 
     ######################################################################
 
     def generate_c_quizzes(self, nb, model_for_generation, procedure, recorder=None):
-        seq_logproba = torch.zeros(nb, device=self.device)
+        seq_logprobas = torch.zeros(nb, device=self.device)
 
         c_quizzes = None
 
@@ -358,7 +355,7 @@ class QuizMachine:
                 model=model_for_generation,
                 input=c_quizzes,
                 ar_mask=self.make_quiz_mask(c_quizzes, s, m),
-                seq_logproba=seq_logproba,
+                seq_logprobas=seq_logprobas,
             )
 
             model_for_generation.reset_transformations()