Update.
author    François Fleuret <francois@fleuret.org>
Thu, 19 Sep 2024 11:22:58 +0000 (13:22 +0200)
committer François Fleuret <francois@fleuret.org>
Thu, 19 Sep 2024 11:22:58 +0000 (13:22 +0200)
tasks.py [deleted file]

diff --git a/tasks.py b/tasks.py
deleted file mode 100755 (executable)
index 80ffdbb..0000000
--- a/tasks.py
+++ /dev/null
@@ -1,390 +0,0 @@
-#!/usr/bin/env python
-
-# Any copyright is dedicated to the Public Domain.
-# https://creativecommons.org/publicdomain/zero/1.0/
-
-# Written by Francois Fleuret <francois@fleuret.org>
-
-import os, tqdm
-
-import torch, torchvision
-
-######################################################################
-
-
-def masked_inplace_autoregression(
-    model,
-    batch_size,
-    input,
-    ar_mask,
-    summed_logits,
-    temperature,
-    deterministic_synthesis,
-    forbidden_tokens=None,
-    logit_biases=None,
-    progress_bar_desc="autoregression",
-    device=torch.device("cpu"),
-):
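-    # Regenerate, in place, the entries of input for which ar_mask is 1,
-    # sampling them autoregressively from the model, batch_size samples
-    # at a time. The model is switched to eval mode for the generation
-    # and restored to its previous training mode afterward.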
-    assert input.size() == ar_mask.size()
-
-    batches = zip(input.split(batch_size), ar_mask.split(batch_size))
-
-    if progress_bar_desc is not None:
-        batches = tqdm.tqdm(
-            batches,
-            dynamic_ncols=True,
-            desc=progress_bar_desc,
-            total=(input.size(0) + batch_size - 1) // batch_size,
-        )
-
-    with torch.no_grad():
-        t = model.training
-        model.eval()
-
-        for input, ar_mask in batches:
-            model.masked_inplace_autoregression(
-                input=input,
-                ar_mask=ar_mask,
-                summed_logits=summed_logits,
-                temperature=temperature,
-                deterministic_synthesis=deterministic_synthesis,
-                forbidden_tokens=forbidden_tokens,
-                forced_biases=logit_biases,
-            )
-
-        model.train(t)
-
-
-######################################################################
-
-
-class Task:
-    def batches(self, split="train", nb_to_use=-1, desc=None):
-        pass
-
-    def vocabulary_size(self):
-        pass
-
-    def produce_results(
-        self, n_epoch, model, result_dir, logger, deterministic_synthesis
-    ):
-        pass
-
-
-######################################################################
-
-import world
-
-
-class World(Task):
-    def save_image(self, input, result_dir, filename, logger):
-        img = world.seq2img(input.to("cpu"), self.height, self.width)
-        image_name = os.path.join(result_dir, filename)
-        torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)
-        logger(f"wrote {image_name}")
-
-    def make_ar_mask(self, input):
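-        # Mask that is 1 on the second half of the sequence (the tokens
-        # to generate) and 0 on the first half (the prompt); e.g. for a
-        # length-5 sequence the mask is [0, 0, 0, 1, 1].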
-        b = torch.arange(input.size(1), device=input.device) > input.size(1) // 2
-        return b.long()[None, :].expand_as(input)
-
-    def __init__(
-        self,
-        nb_train_samples,
-        nb_test_samples,
-        batch_size,
-        result_dir=None,
-        logger=None,
-        device=torch.device("cpu"),
-    ):
-        super().__init__()
-
-        self.batch_size = batch_size
-        self.device = device
-        self.height = 6
-        self.width = 8
-
-        self.train_input = world.generate_seq(
-            nb_train_samples, height=self.height, width=self.width
-        ).to(device)
-
-        self.test_input = world.generate_seq(
-            nb_test_samples, height=self.height, width=self.width
-        ).to(device)
-
-        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
-
-        self.train_quizzes = []
-        self.test_quizzes = []
-
-        if result_dir is not None:
-            self.save_image(
-                self.train_input[:72], result_dir, "world_train.png", logger
-            )
-
-    def batches(self, split="train", desc=None):
-        assert split in {"train", "test"}
-        if split == "train":
-            input = self.train_input
-            quizzes = self.train_quizzes
-        else:
-            input = self.test_input
-            quizzes = self.test_quizzes
-
-        if len(quizzes) > 0:
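-            # Mix the stored quizzes with the world samples, capping the
-            # quizzes at half of the resulting pool.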
-            quizzes = torch.cat(quizzes, dim=0)
-            if quizzes.size(0) > input.size(0) // 2:
-                i = torch.randperm(input.size(0))[: input.size(0) // 2]
-                quizzes = quizzes[i]
-
-            i = torch.randperm(input.size(0))[: input.size(0) - quizzes.size(0)]
-            input = input[i]
-
-            self.nb_batch_samples_world = input.size(0)
-            self.nb_batch_samples_quizzes = quizzes.size(0)
-
-            input = torch.cat([input, quizzes], dim=0)
-        else:
-            self.nb_batch_samples_world = input.size(0)
-            self.nb_batch_samples_quizzes = 0
-
-        # Shuffle
-        input = input[torch.randperm(input.size(0))]
-
-        if desc is None:
-            desc = f"epoch-{split}"
-        for batch in tqdm.tqdm(
-            input.split(self.batch_size), dynamic_ncols=True, desc=desc
-        ):
-            yield batch
-
-    def vocabulary_size(self):
-        return self.nb_codes
-
-    def produce_results(
-        self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000
-    ):
-        def compute_accuracy(input, logger=None):
-            input = input[:nmax]
-            ar_mask = self.make_ar_mask(input)
-            result = input.clone() * (1 - ar_mask)
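-            # Zero out the tokens to be predicted, keeping the prompt half.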
-
-            masked_inplace_autoregression(
-                model=model,
-                batch_size=self.batch_size,
-                input=result,
-                ar_mask=ar_mask,
-                summed_logits=None,
-                temperature=1.0,
-                deterministic_synthesis=deterministic_synthesis,
-                progress_bar_desc=None,
-                device=self.device,
-            )
-
-            nb_total, nb_correct = (
-                input.size(0),
-                (input == result).long().min(dim=1).values.sum(),
-            )
-
-            return nb_total, nb_correct
-
-        train_nb_total, train_nb_correct = compute_accuracy(self.train_input)
-
-        logger(
-            f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
-        )
-
-        test_nb_total, test_nb_correct = compute_accuracy(self.test_input, logger)
-
-        logger(
-            f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
-        )
-
-        main_test_accuracy = test_nb_correct / test_nb_total
-        logger(f"main_test_accuracy {n_epoch} {main_test_accuracy}")
-
-        ##############################
-
-        input = self.test_input[:96]
-        ar_mask = self.make_ar_mask(input)
-        result = input.clone() * (1 - ar_mask)
-
-        masked_inplace_autoregression(
-            model=model,
-            batch_size=self.batch_size,
-            input=result,
-            ar_mask=ar_mask,
-            summed_logits=None,
-            temperature=1.0,
-            deterministic_synthesis=deterministic_synthesis,
-            progress_bar_desc=None,
-            device=self.device,
-        )
-
-        self.save_image(
-            result[:72],
-            result_dir,
-            f"world_prediction_{n_epoch:04d}_{model.id:02d}.png",
-            logger,
-        )
-
-        return main_test_accuracy
-
-    def renew_samples(self, nb, for_train=True):
-        input = self.train_input if for_train else self.test_input
-        nb = min(nb, input.size(0))
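-        # Drop the nb oldest samples, shift the rest toward the front,
-        # and append nb freshly generated sequences at the end.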
-        input[:-nb] = input[nb:].clone()
-        input[-nb:] = world.generate_seq(nb, height=self.height, width=self.width).to(
-            self.device
-        )
-
-    def store_new_quizzes(self, new_quizzes, for_train=True):
-        if for_train:
-            self.train_quizzes.append(new_quizzes)
-        else:
-            self.test_quizzes.append(new_quizzes)
-
-    def create_new_quizzes(
-        self,
-        n_epoch,
-        result_dir,
-        logger,
-        nb,
-        model,
-        other_models,
-        desired_average_logits=None,
-    ):
-        ###############################################################
-        # Generate quizzes with model
-
-        quizzes = torch.empty(
-            nb, self.height * self.width * 2 + 1, device=self.device, dtype=torch.int64
-        )
-
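-        # Every position is generated: the mask is 1 everywhere, and
-        # summed_logits receives, for each quiz, the sum of the logits
-        # of the sampled tokens.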
-        ar_mask = torch.full(quizzes.size(), 1, device=self.device)
-        summed_logits = torch.empty(nb, device=self.device)
-
-        temperature = 1
-        d_temperature = 1
-
-        while True:
-            summed_logits[...] = 0
-
-            masked_inplace_autoregression(
-                model=model,
-                batch_size=self.batch_size,
-                input=quizzes,
-                ar_mask=ar_mask,
-                summed_logits=summed_logits,
-                temperature=temperature,
-                deterministic_synthesis=False,
-                progress_bar_desc="creating quizzes",
-                device=self.device,
-            )
-
-            average_logits = summed_logits.mean()
-
-            logger(f"{average_logits=} {desired_average_logits=}")
-
-            if desired_average_logits is None:
-                break
-
-            # Damped search on the temperature: the step is halved and
-            # reversed each time the average crosses the target range,
-            # until the average logits land inside it.
-            if average_logits < desired_average_logits * 1.1:
-                if d_temperature > 0:
-                    d_temperature *= -0.5
-                temperature += d_temperature
-            elif average_logits > desired_average_logits:
-                if d_temperature < 0:
-                    d_temperature *= -0.5
-                temperature += d_temperature
-            else:
-                break
-
-            logger(f"changing temperature to {temperature}")
-
-        ###############################################################
-        # Create the reverse quizzes
-
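-        # A quiz is [grid A | direction token | grid B]; the reverse quiz
-        # swaps the two grids and flips the direction token.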
-        hw = self.height * self.width
-        direction = quizzes[:, hw : hw + 1]
-        direction = world.token_forward * (
-            direction == world.token_backward
-        ) + world.token_backward * (direction == world.token_forward)
-        reverse_quizzes = torch.cat(
-            [quizzes[:, hw + 1 :], direction, quizzes[:, :hw]], dim=1
-        )
-
-        ar_mask = self.make_ar_mask(quizzes)
-
-        ###############################################################
-        # Check how many of the other models can solve them in both
-        # directions
-
-        nb_correct = []
-
-        for m in other_models:
-            result = quizzes.clone()
-
-            masked_inplace_autoregression(
-                model=m,
-                batch_size=self.batch_size,
-                input=result,
-                ar_mask=ar_mask,
-                summed_logits=None,
-                temperature=1.0,
-                deterministic_synthesis=True,
-                progress_bar_desc="solving quizzes",
-                device=self.device,
-            )
-
-            correct = (quizzes == result).long().min(dim=-1).values
-
-            reverse_result = reverse_quizzes.clone()
-
-            masked_inplace_autoregression(
-                model=m,
-                batch_size=self.batch_size,
-                input=reverse_result,
-                ar_mask=ar_mask,
-                summed_logits=None,
-                temperature=1.0,
-                deterministic_synthesis=True,
-                progress_bar_desc="solving reversed quizzes",
-                device=self.device,
-            )
-
-            reverse_correct = (
-                (reverse_quizzes == reverse_result).long().min(dim=-1).values
-            )
-
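-            # A quiz counts as correct only if the model solved it in
-            # both directions.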
-            nb_correct.append((correct * reverse_correct)[None, :])
-
-        nb_correct = torch.cat(nb_correct, dim=0)
-
-        # filename = os.path.join(result_dir, f"correct_{n_epoch:04d}.dat")
-        # with open(filename, "w") as f:
-        #     for k in nb_correct:
-        #         f.write(f"{k}\n")
-
-        return quizzes, nb_correct.sum(dim=0), summed_logits.mean()