Update.
author François Fleuret <francois@fleuret.org>
Thu, 19 Sep 2024 11:14:57 +0000 (13:14 +0200)
committer François Fleuret <francois@fleuret.org>
Thu, 19 Sep 2024 11:14:57 +0000 (13:14 +0200)
main.py
quiz_machine.py [deleted file]

diff --git a/main.py b/main.py
index 750d1b1..0c40f95 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -14,9 +14,7 @@ from torch.nn import functional as F
 import ffutils
 
 import mygpt
-import sky, grids, quiz_machine
-
-from quiz_machine import one_batch_masked_inplace_autoregression
+import sky, grids
 
 import threading, subprocess
 
@@ -254,26 +252,6 @@ assert args.nb_test_samples % args.batch_size == 0
 
 ######################################################################
 
-
-# ------------------------------------------------------
-alien_problem = grids.Grids(
-    max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
-    chunk_size=100,
-    nb_threads=args.nb_threads,
-    tasks="symmetry",
-)
-
-alien_quiz_machine = quiz_machine.QuizMachine(
-    problem=alien_problem,
-    batch_size=args.eval_batch_size,
-    result_dir=args.result_dir,
-    logger=log_string,
-    device=main_device,
-)
-# ------------------------------------------------------
-
-######################################################################
-
 problem = grids.Grids(
     max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
     chunk_size=100,
@@ -284,19 +262,58 @@ problem = grids.Grids(
 if not args.resume:
     problem.save_some_examples(args.result_dir)
 
-quiz_machine = quiz_machine.QuizMachine(
-    problem=problem,
-    batch_size=args.eval_batch_size,
-    result_dir=args.result_dir,
-    logger=log_string,
-    device=main_device,
-)
+
+def pure_noise(nb, device):
+    r = problem.pure_noise(nb, device)
+    r = r.view(r.size(0), 4, -1)[:, :, 1:].reshape(r.size(0), -1)
+    return r
+
+
+def quiz_set(nb_samples, c_quizzes, c_quiz_multiplier=1):
+    if c_quizzes is None:
+        quizzes = problem.generate_w_quizzes(nb_samples)
+        quizzes = quizzes.view(quizzes.size(0), 4, -1)[:, :, 1:].reshape(
+            quizzes.size(0), -1
+        )
+        nb_w_quizzes = quizzes.size(0)
+        nb_c_quizzes = 0
+    else:
+        if c_quiz_multiplier > 1:
+            n = min(c_quiz_multiplier, (nb_samples // 2) // c_quizzes.size(0))
+            body = c_quizzes.repeat(n, 1)
+            if n < c_quiz_multiplier:
+                tail = c_quizzes[
+                    torch.randperm(c_quizzes.size(0))[: nb_samples // 2 - body.size(0)]
+                ]
+                c_quizzes = torch.cat([body, tail], dim=0)
+            else:
+                c_quizzes = body
+
+        if c_quizzes.size(0) > nb_samples // 2:
+            i = torch.randperm(c_quizzes.size(0))[: nb_samples // 2]
+            c_quizzes = c_quizzes[i]
+
+        w_quizzes = problem.generate_w_quizzes(nb_samples - c_quizzes.size(0))
+        w_quizzes = w_quizzes.view(w_quizzes.size(0), 4, -1)[:, :, 1:].reshape(
+            w_quizzes.size(0), -1
+        )
+        quizzes = torch.cat([w_quizzes, c_quizzes], dim=0)
+        nb_w_quizzes = w_quizzes.size(0)
+        nb_c_quizzes = c_quizzes.size(0)
+
+    i = torch.randperm(quizzes.size(0), device=quizzes.device)
+    quizzes = quizzes[i].contiguous()
+
+    log_string(f"quiz_set nb_w_quizzes {nb_w_quizzes} nb_c_quizzes {nb_c_quizzes}")
+
+    return quizzes
+
 
 ######################################################################
 
 log_string(f"main_device {main_device} gpus {[ str(g) for g in gpus]}")
 
-vocabulary_size = quiz_machine.vocabulary_size()
+vocabulary_size = problem.nb_token_values
 
 log_string(f"vocabulary_size {vocabulary_size}")
 
@@ -344,7 +361,7 @@ def add_noise_imt(imt_set):
     """Replace every component of the input by a random value with
     probability args.proba_prompt_noise."""
     input, masks, targets = imt_set[:, 0], imt_set[:, 1], imt_set[:, 2]
-    noise = quiz_machine.pure_noise(input.size(0), input.device)
+    noise = pure_noise(input.size(0), input.device)
     change = (1 - masks) * (
         torch.rand(input.size(), device=input.device) < args.proba_prompt_noise
     ).long()
@@ -432,7 +449,7 @@ def samples_for_generation_imt(input):
     proba_erased = 1 - (1 - args.diffusion_proba_corruption) ** t
     mask_erased = (r <= proba_erased[:, None]).long()
 
-    noise = quiz_machine.pure_noise(nb, input.device)
+    noise = pure_noise(nb, input.device)
     targets = input
     input = (1 - mask_erased) * input + mask_erased * noise
     masks = input.new_full(input.size(), 1)
@@ -456,7 +473,7 @@ def ae_generate(model, nb, local_device=main_device):
     # mini-batches second so that we keep only the samples that have
     # not stabilized
 
-    all_input = quiz_machine.pure_noise(nb, local_device)
+    all_input = pure_noise(nb, local_device)
     all_masks = all_input.new_full(all_input.size(), 1)
     all_changed = torch.full((all_input.size(0),), True, device=all_input.device)
 
@@ -499,7 +516,7 @@ def ae_generate(model, nb, local_device=main_device):
 
 
 def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True):
-    quizzes = quiz_machine.quiz_set(
+    quizzes = quiz_set(
         args.nb_train_samples if train else args.nb_test_samples,
         c_quizzes,
         args.c_quiz_multiplier,
@@ -572,22 +589,18 @@ def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device):
 
     # Save some original world quizzes and the full prediction (the four grids)
 
-    quizzes = quiz_machine.quiz_set(25, c_quizzes, args.c_quiz_multiplier).to(
-        local_device
-    )
-    quiz_machine.problem.save_quizzes_as_image(
+    quizzes = quiz_set(25, c_quizzes, args.c_quiz_multiplier).to(local_device)
+    problem.save_quizzes_as_image(
         args.result_dir, f"test_{n_epoch}_{model.id}.png", quizzes=quizzes
     )
     result = predict_full(model=model, input=quizzes, local_device=local_device)
-    quiz_machine.problem.save_quizzes_as_image(
+    problem.save_quizzes_as_image(
         args.result_dir, f"test_{n_epoch}_{model.id}_predict_full.png", quizzes=result
     )
 
     # Save some images of the prediction results
 
-    quizzes = quiz_machine.quiz_set(
-        args.nb_test_samples, c_quizzes, args.c_quiz_multiplier
-    )
+    quizzes = quiz_set(args.nb_test_samples, c_quizzes, args.c_quiz_multiplier)
     imt_set = samples_for_prediction_imt(quizzes.to(local_device))
     result = ae_predict(model, imt_set, local_device=local_device).to("cpu")
     masks = imt_set[:, 1].to("cpu")
@@ -598,7 +611,7 @@ def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device):
     ]
     predicted_parts = correct_parts.abs()
 
-    quiz_machine.problem.save_quizzes_as_image(
+    problem.save_quizzes_as_image(
         args.result_dir,
         f"culture_prediction_{n_epoch}_{model.id}.png",
         quizzes=result[:128],
@@ -618,7 +631,7 @@ def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device):
     # Save some images of the ex nihilo generation of the four grids
 
     result = ae_generate(model, 150, local_device=local_device).to("cpu")
-    quiz_machine.problem.save_quizzes_as_image(
+    problem.save_quizzes_as_image(
         args.result_dir,
         f"culture_generation_{n_epoch}_{model.id}.png",
         quizzes=result[:128],
@@ -785,7 +798,7 @@ def save_quiz_image(models, c_quizzes, filename, local_device=main_device):
 
     comments = [f"nb_correct {c} nb_wrong {w}" for c, w in zip(nb_correct, nb_wrong)]
 
-    quiz_machine.problem.save_quizzes_as_image(
+    problem.save_quizzes_as_image(
         args.result_dir,
         filename,
         quizzes=c_quizzes,
@@ -837,43 +850,6 @@ nb_parameters = sum(p.numel() for p in models[0].parameters())
 log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
 
 
-######################################################################
-
-if args.quizzes is not None:
-    with open(args.quizzes, "r") as file:
-        txt = file.read()
-
-    quizzes = quiz_machine.problem.text2quiz(txt)
-
-    record = []
-
-    quizzes = quizzes.to(main_device)
-    for model in models:
-        log_string(f"processing {model.id} {args.quizzes}")
-        for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
-            mask_generate = quiz_machine.make_quiz_mask(
-                quizzes=quizzes, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
-            )
-            result = ae_generate(model, (1 - mask_generate) * quizzes, mask_generate)
-            record.append(result)
-
-    result = torch.cat(record, dim=0)
-
-    filename = "result.png"
-
-    quiz_machine.problem.save_quizzes_as_image(
-        args.result_dir,
-        filename,
-        quizzes=result,
-        delta=True,
-        nrow=8,
-    )
-
-    log_string(f"wrote {filename}")
-
-    exit(0)
-
-
 ######################################################################
 
 c_quizzes = None
diff --git a/quiz_machine.py b/quiz_machine.py
deleted file mode 100755 (executable)
index 72f1d16..0000000
+++ /dev/null
@@ -1,443 +0,0 @@
-#!/usr/bin/env python
-
-# Any copyright is dedicated to the Public Domain.
-# https://creativecommons.org/publicdomain/zero/1.0/
-
-# Written by Francois Fleuret <francois@fleuret.org>
-
-import math, os, tqdm, warnings, sys
-
-import torch, torchvision
-
-from torch import nn
-from torch.nn import functional as F
-
-import mygpt
-from mygpt import BracketedSequence
-
-import threading
-
-######################################################################
-
-# ar_mask is a tensor with 0s and 1s, of same shape as input, with
-# 1s where tokens should be generated. The others are kept
-# unchanged.
-
-
-def one_batch_masked_inplace_autoregression(
-    model,
-    input,
-    ar_mask,
-    acc_seq_logprobas,
-    deterministic_synthesis=False,
-):
-    if input.size(0) == 0:
-        return
-
-    to_generate = (ar_mask.sum(0) > 0).nonzero()
-
-    if to_generate.min() > 0:
-        model(
-            BracketedSequence(input, 0, to_generate.min())
-        )  # Needed to initialize the model's cache
-    for s in range(to_generate.min(), to_generate.max() + 1):
-        output = model(BracketedSequence(input, s, 1)).x
-
-        logits = output[:, s]
-
-        if deterministic_synthesis:
-            t_next = logits.argmax(-1)
-        else:
-            dist = torch.distributions.categorical.Categorical(logits=logits)
-            t_next = dist.sample()
-
-        all_n = torch.arange(t_next.size(0))
-
-        acc_seq_logprobas += ar_mask[:, s] * logits.log_softmax(dim=1)[all_n, t_next]
-
-        input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
-
-
-######################################################################
-
-
-class QuizMachine:
-    def __init__(
-        self,
-        problem,
-        batch_size,
-        result_dir,
-        logger,
-        device=torch.device("cpu"),
-    ):
-        super().__init__()
-
-        self.problem = problem
-        self.batch_size = batch_size
-        self.device = device
-        self.logger = logger
-        self.prompt_len = None
-        self.answer_len = None
-
-        # quad_order, quad_generate, quad_noise, quad_loss
-        self.train_structures = [
-            (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)),
-            (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)),
-            (("B", "f_B", "A", "f_A"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)),
-            (("f_B", "B", "f_A", "A"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)),
-            # (("B", "f_B", "A", "f_A"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
-            # (("f_B", "B", "f_A", "A"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
-            (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
-        ]
-
-        self.test_structures = self.train_structures
-
-    def vocabulary_size(self):
-        return self.problem.nb_token_values
-
-    ######################################################################
-
-    def autoregression(
-        self,
-        model,
-        input,
-        ar_mask,
-        seq_logprobas,
-        progress_bar_desc=None,
-    ):
-        assert input.size() == ar_mask.size()
-
-        batches = zip(
-            input.split(self.batch_size),
-            ar_mask.split(self.batch_size),
-            seq_logprobas.split(self.batch_size),
-        )
-
-        if progress_bar_desc is not None:
-            batches = tqdm.tqdm(
-                batches,
-                dynamic_ncols=True,
-                desc=progress_bar_desc,
-                total=(input.size(0) + self.batch_size - 1) // self.batch_size,
-            )
-
-        with torch.autograd.no_grad():
-            t = model.training
-            model.eval()
-
-            for input, ar_mask, seq_logprobas in batches:
-                one_batch_masked_inplace_autoregression(
-                    model=model,
-                    input=input,
-                    ar_mask=ar_mask,
-                    acc_seq_logprobas=seq_logprobas,
-                    deterministic_synthesis=False,
-                )
-
-            model.train(t)
-
-    ######################################################################
-
-    def data_input(
-        self, nb_samples, c_quiz_bags=[], c_quiz_multiplier=1, data_structures=None
-    ):
-        if data_structures is None:
-            data_structures = self.train_structures
-
-        if len(c_quiz_bags) > 0:
-            c_quizzes = torch.cat(c_quiz_bags, dim=0)
-
-            if c_quiz_multiplier > 1:
-                n = min(c_quiz_multiplier, (nb_samples // 2) // c_quizzes.size(0))
-                body = c_quizzes.repeat(n, 1)
-                if n < c_quiz_multiplier:
-                    tail = c_quizzes[
-                        torch.randperm(c_quizzes.size(0))[
-                            : nb_samples // 2 - body.size(0)
-                        ]
-                    ]
-                    c_quizzes = torch.cat([body, tail], dim=0)
-                else:
-                    c_quizzes = body
-
-            if c_quizzes.size(0) > nb_samples // 2:
-                i = torch.randperm(c_quizzes.size(0))[: nb_samples // 2]
-                c_quizzes = c_quizzes[i]
-
-            w_quizzes = self.problem.generate_w_quizzes(nb_samples - c_quizzes.size(0))
-            quizzes = torch.cat([w_quizzes, c_quizzes], dim=0)
-        else:
-            quizzes = self.problem.generate_w_quizzes(nb_samples)
-
-        # shuffle
-
-        i = torch.randperm(quizzes.size(0), device=quizzes.device)
-        quizzes = quizzes[i]
-
-        # Re-order and inject noise
-
-        quiz_mask_generate = quizzes.new_full(quizzes.size(), 1)
-        quiz_mask_loss = quizzes.new_full(quizzes.size(), 1)
-        order_ids = torch.randint(len(data_structures), (quizzes.size(0),))
-
-        for j, s in enumerate(data_structures):
-            quad_order, quad_generate, quad_noise, quad_loss = s
-            i = order_ids == j
-            quizzes[i] = self.problem.reconfigure(quizzes[i], quad_order=quad_order)
-            quiz_mask_generate[i] = self.make_quiz_mask(
-                quizzes=quizzes[i], quad_order=quad_order, quad_mask=quad_generate
-            )
-            quiz_mask_loss[i] = self.make_quiz_mask(
-                quizzes=quizzes[i], quad_order=quad_order, quad_mask=quad_loss
-            )
-
-        return quizzes, quiz_mask_generate, quiz_mask_loss
-
-    ######################################################################
-
-    def pure_noise(self, nb, device):
-        r = self.problem.pure_noise(nb, device)
-        r = r.view(r.size(0), 4, -1)[:, :, 1:].reshape(r.size(0), -1)
-        return r
-
-    def quiz_set(self, nb_samples, c_quizzes, c_quiz_multiplier=1):
-        if c_quizzes is None:
-            quizzes = self.problem.generate_w_quizzes(nb_samples)
-            quizzes = quizzes.view(quizzes.size(0), 4, -1)[:, :, 1:].reshape(
-                quizzes.size(0), -1
-            )
-            nb_w_quizzes = quizzes.size(0)
-            nb_c_quizzes = 0
-        else:
-            if c_quiz_multiplier > 1:
-                n = min(c_quiz_multiplier, (nb_samples // 2) // c_quizzes.size(0))
-                body = c_quizzes.repeat(n, 1)
-                if n < c_quiz_multiplier:
-                    tail = c_quizzes[
-                        torch.randperm(c_quizzes.size(0))[
-                            : nb_samples // 2 - body.size(0)
-                        ]
-                    ]
-                    c_quizzes = torch.cat([body, tail], dim=0)
-                else:
-                    c_quizzes = body
-
-            if c_quizzes.size(0) > nb_samples // 2:
-                i = torch.randperm(c_quizzes.size(0))[: nb_samples // 2]
-                c_quizzes = c_quizzes[i]
-
-            w_quizzes = self.problem.generate_w_quizzes(nb_samples - c_quizzes.size(0))
-            w_quizzes = w_quizzes.view(w_quizzes.size(0), 4, -1)[:, :, 1:].reshape(
-                w_quizzes.size(0), -1
-            )
-            quizzes = torch.cat([w_quizzes, c_quizzes], dim=0)
-            nb_w_quizzes = w_quizzes.size(0)
-            nb_c_quizzes = c_quizzes.size(0)
-
-        i = torch.randperm(quizzes.size(0), device=quizzes.device)
-        quizzes = quizzes[i].contiguous()
-
-        logger(f"quiz_set nb_w_quizzes {nb_w_quizzes} nb_c_quizzes {nb_c_quizzes}")
-
-        return quizzes
-
-    ######################################################################
-
-    def make_quiz_mask(self, quizzes, quad_order, quad_mask):
-        assert quad_order in [s for s, _, _, _ in self.train_structures]
-        return self.problem.make_quiz_mask(
-            quizzes, quad_order=quad_order, quad_mask=quad_mask
-        )
-
-    ######################################################################
-
-    def predict(self, model, quizzes, quad_order, quad_mask):
-        quizzes = quizzes.to(self.device)
-        ar_mask = self.make_quiz_mask(
-            quizzes=quizzes, quad_order=quad_order, quad_mask=quad_mask
-        )
-        result = quizzes * (1 - ar_mask)
-
-        seq_logprobas = torch.zeros(quizzes.size(0), device=self.device)
-
-        self.autoregression(
-            model=model,
-            input=result,
-            ar_mask=ar_mask,
-            seq_logprobas=seq_logprobas,
-            progress_bar_desc="autoregression",
-        )
-
-        correct = (result == quizzes).min(dim=1).values.long()
-
-        # result = result.to("cpu")
-        # correct = correct.to("cpu")
-        # seq_logprobas = seq_logprobas.to("cpu")
-
-        return result, correct, seq_logprobas
-
-    ######################################################################
-
-    def produce_results(self, n_epoch, model, input, result_dir):
-        input = input.to(self.device)
-        result = input.new(input.size())
-        correct = input.new(input.size(0))
-        predicted_parts = input.new(input.size(0), 4)
-
-        nb = 0
-
-        # We consider all the configurations that we train for
-        for quad_order, quad_generate, _, _ in self.test_structures:
-            i = self.problem.indices_select(quizzes=input, quad_order=quad_order)
-            nb += i.long().sum()
-            result[i], correct[i], _ = self.predict(
-                model=model, quizzes=input[i], quad_order=quad_order, quad=quad_generate
-            )
-
-            predicted_parts[i] = torch.tensor(quad_generate, device=self.device)[
-                None, :
-            ]
-            solution_is_deterministic = predicted_parts[i].sum(dim=-1) == 1
-            correct[i] = (2 * correct[i] - 1) * (solution_is_deterministic).long()
-
-        assert nb == input.size(0)
-
-        nb_correct = (correct == 1).long().sum()
-        nb_total = (correct != 0).long().sum()
-        self.logger(
-            f"test_accuracy {n_epoch} model {model.id} val {nb_correct} / {nb_total}"
-        )
-
-        test_accuracy = (nb_correct / nb_total).item()
-
-        ##############################
-
-        correct_parts = predicted_parts * correct[:, None]
-
-        result = result[:128]
-        predicted_parts = predicted_parts[:128]
-        correct_parts = correct_parts[:128]
-
-        self.problem.save_quizzes_as_image(
-            result_dir,
-            f"culture_prediction_{n_epoch:04d}_{model.id:02d}.png",
-            quizzes=result,
-            predicted_parts=predicted_parts,
-            correct_parts=correct_parts,
-        )
-
-        return test_accuracy
-
-    ######################################################################
-
-    def randomize_configuations_inplace(self, quizzes, quad_orders):
-        r = torch.randint(len(quad_orders), (quizzes.size(0),), device=quizzes.device)
-        for c in range(len(quad_orders)):
-            quizzes[r == c] = self.problem.reconfigure(
-                quizzes[r == c], quad_order=quad_orders[c]
-            )
-
-    ######################################################################
-
-    def store_c_quizzes(self, new_c_quizzes, for_train=True):
-        with self.LOCK_C_QUIZZES:
-            if for_train:
-                self.train_c_quizzes.append(new_c_quizzes.to("cpu"))
-            else:
-                self.test_c_quizzes.append(new_c_quizzes.to("cpu"))
-
-    def save_c_quizzes(self, filename):
-        torch.save((self.train_c_quizzes, self.test_c_quizzes), filename)
-
-    def load_c_quizzes(self, filename):
-        self.train_c_quizzes, self.test_c_quizzes = torch.load(filename)
-
-    ######################################################################
-
-    def models_logprobas(
-        self,
-        model,
-        c_quizzes,
-        quad_order,
-        quad_loss,
-        quad_noise=None,
-        temperature=1.0,
-        device=None,
-    ):
-        if device is None:
-            device = self.device
-
-        c_quizzes = self.problem.reconfigure(c_quizzes, quad_order)
-
-        seq_logprobas = torch.zeros(
-            c_quizzes.size(0),
-            device=device,
-        )
-
-        with torch.autograd.no_grad():
-            t = model.training
-            model.eval()
-
-            for input, l in zip(
-                c_quizzes.split(self.batch_size),
-                seq_logprobas.split(self.batch_size),
-            ):
-                input = input.to(device)
-                quiz_mask_loss = self.make_quiz_mask(
-                    input, quad_order=quad_order, quad_mask=quad_loss
-                )
-                output = model(mygpt.BracketedSequence(input)).x / temperature
-                l[...] = (
-                    -F.cross_entropy(output.transpose(1, 2), input, reduction="none")
-                    * quiz_mask_loss
-                ).sum(dim=1)
-
-            model.train(t)
-
-        return seq_logprobas.to("cpu")
-
-    ######################################################################
-
-    def generate_c_quizzes(self, nb, model_for_generation, procedure, recorder=None):
-        seq_logprobas = torch.zeros(nb, device=self.device)
-
-        c_quizzes = None
-
-        for n_step, setup in enumerate(procedure):
-            quad_order, quad_generate, model_modifier = setup
-            if c_quizzes is None:
-                c_quizzes = self.problem.create_empty_quizzes(nb, quad_order)
-                c_quizzes = c_quizzes.to(self.device)
-            elif quad_order != pred_quad_order:
-                c_quizzes = self.problem.reconfigure(c_quizzes, quad_order)
-            pred_quad_order = quad_order
-
-            if model_modifier is not None:
-                model_modifier(model_for_generation)
-
-            self.autoregression(
-                model=model_for_generation,
-                input=c_quizzes,
-                ar_mask=self.make_quiz_mask(
-                    quizzes=c_quizzes, quad_order=quad_order, quad_mask=quad_generate
-                ),
-                seq_logprobas=seq_logprobas,
-                progress_bar_desc=f"autoregression {n_step+1}/{len(procedure)}",
-            )
-
-            model_for_generation.reset_transformations()
-
-            if recorder is not None:
-                x = c_quizzes.clone()
-                t = torch.tensor(quad_generate, device=x.device)[None, :].expand(
-                    x.size(0), -1
-                )
-                recorder.append(
-                    self.problem.reconfigure([x, t], ("A", "f_A", "B", "f_B"))
-                )
-
-        c_quizzes = self.problem.reconfigure(c_quizzes, ("A", "f_A", "B", "f_B"))
-
-        return c_quizzes.to("cpu")
-
-    ######################################################################