######################################################################
 
 def autoregression(
-        model,
+        model, batch_size,
         nb_samples, nb_tokens_to_generate, starting_input = None,
         device = torch.device('cpu')
 ):
         first = starting_input.size(1)
         results = torch.cat((starting_input, results), 1)
 
-    for input in results.split(args.batch_size):
+    for input in results.split(batch_size):
         for s in tqdm.tqdm(range(first, input.size(1)), desc = 'synth'):
             output = model(input)
             logits = output[:, s]
         return 256
 
     def produce_results(self, n_epoch, model, nb_samples = 64):
-        results = autoregression(model, nb_samples, 28 * 28, device = self.device)
+        results = autoregression(model, self.batch_size, nb_samples, 28 * 28, device = self.device)
         image_name = f'result_mnist_{n_epoch:04d}.png'
         torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255.,
                                      image_name, nrow = 16, pad_value = 0.8)