From 6095c01cb91f79f2986edaf5fef2467faacbb8ac Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 19 Sep 2024 13:22:58 +0200 Subject: [PATCH] Update. --- tasks.py | 374 ------------------------------------------------------- 1 file changed, 374 deletions(-) delete mode 100755 tasks.py diff --git a/tasks.py b/tasks.py deleted file mode 100755 index 80ffdbb..0000000 --- a/tasks.py +++ /dev/null @@ -1,374 +0,0 @@ -#!/usr/bin/env python - -# Any copyright is dedicated to the Public Domain. -# https://creativecommons.org/publicdomain/zero/1.0/ - -# Written by Francois Fleuret - -import math, os, tqdm, warnings - -import torch, torchvision - -from torch import nn -from torch.nn import functional as F - -from mygpt import BracketedSequence - -###################################################################### - - -def masked_inplace_autoregression( - model, - batch_size, - input, - ar_mask, - summed_logits, - temperature, - deterministic_synthesis, - forbidden_tokens=None, - logit_biases=None, - progress_bar_desc="autoregression", - device=torch.device("cpu"), -): - assert input.size() == ar_mask.size() - - batches = zip(input.split(batch_size), ar_mask.split(batch_size)) - - if progress_bar_desc is not None: - batches = tqdm.tqdm( - batches, - dynamic_ncols=True, - desc=progress_bar_desc, - total=(input.size(0) + batch_size - 1) // batch_size, - ) - - with torch.autograd.no_grad(): - t = model.training - model.eval() - - for input, ar_mask in batches: - model.masked_inplace_autoregression( - input=input, - ar_mask=ar_mask, - summed_logits=summed_logits, - temperature=temperature, - deterministic_synthesis=deterministic_synthesis, - forbidden_tokens=forbidden_tokens, - forced_biases=logit_biases, - ) - - model.train(t) - - -###################################################################### - - -class Task: - def batches(self, split="train", nb_to_use=-1, desc=None): - pass - - def vocabulary_size(self): - pass - - def produce_results( - self, n_epoch, model, result_dir, logger, deterministic_synthesis - ): - pass - - -###################################################################### - -import world - - -class World(Task): - def save_image(self, input, result_dir, filename, logger): - img = world.seq2img(input.to("cpu"), self.height, self.width) - image_name = os.path.join(result_dir, filename) - torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4) - logger(f"wrote {image_name}") - - def make_ar_mask(self, input): - b = torch.arange(input.size(1), device=input.device) > input.size(1) // 2 - return b.long()[None, :].expand_as(input) - - def __init__( - self, - nb_train_samples, - nb_test_samples, - batch_size, - result_dir=None, - logger=None, - device=torch.device("cpu"), - ): - super().__init__() - - self.batch_size = batch_size - self.device = device - self.height = 6 - self.width = 8 - - self.train_input = world.generate_seq( - nb_train_samples, height=self.height, width=self.width - ).to(device) - - self.test_input = world.generate_seq( - nb_test_samples, height=self.height, width=self.width - ).to(device) - - self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1 - - self.train_quizzes = [] - self.test_quizzes = [] - - if result_dir is not None: - self.save_image( - self.train_input[:72], result_dir, f"world_train.png", logger - ) - - def batches(self, split="train", desc=None): - assert split in {"train", "test"} - if split == "train": - input = self.train_input - quizzes = self.train_quizzes - else: - input = self.test_input - quizzes = self.test_quizzes - - if len(quizzes) > 0: - quizzes = torch.cat(quizzes, dim=0) - if quizzes.size(0) > input.size(0) // 2: - i = torch.randperm(input.size(0))[: input.size(0) // 2] - quizzes = quizzes[i] - - i = torch.randperm(input.size(0))[: input.size(0) - quizzes.size(0)] - input = input[i] - - self.nb_batch_samples_world = input.size(0) - self.nb_batch_samples_quizzes = quizzes.size(0) - - input = torch.cat([input, quizzes], dim=0) - else: - self.nb_batch_samples_world = input.size(0) - self.nb_batch_samples_quizzes = 0 - - # Shuffle - input = input[torch.randperm(input.size(0))] - - if desc is None: - desc = f"epoch-{split}" - for batch in tqdm.tqdm( - input.split(self.batch_size), dynamic_ncols=True, desc=desc - ): - yield batch - - def vocabulary_size(self): - return self.nb_codes - - def produce_results( - self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000 - ): - def compute_accuracy(input, logger=None): - input = input[:nmax] - ar_mask = self.make_ar_mask(input) - result = input.clone() * (1 - ar_mask) - - masked_inplace_autoregression( - model=model, - batch_size=self.batch_size, - input=result, - ar_mask=ar_mask, - summed_logits=None, - temperature=1.0, - deterministic_synthesis=deterministic_synthesis, - progress_bar_desc=None, - device=self.device, - ) - - nb_total, nb_correct = ( - input.size(0), - (input == result).long().min(dim=1).values.sum(), - ) - - return nb_total, nb_correct - - train_nb_total, train_nb_correct = compute_accuracy(self.train_input) - - logger( - f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%" - ) - - test_nb_total, test_nb_correct = compute_accuracy(self.test_input, logger) - - logger( - f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%" - ) - - main_test_accuracy = test_nb_correct / test_nb_total - logger(f"main_test_accuracy {n_epoch} {main_test_accuracy}") - - ############################## - - input = self.test_input[:96] - ar_mask = self.make_ar_mask(input) - result = input.clone() * (1 - ar_mask) - - masked_inplace_autoregression( - model=model, - batch_size=self.batch_size, - input=result, - ar_mask=ar_mask, - summed_logits=None, - temperature=1.0, - deterministic_synthesis=deterministic_synthesis, - progress_bar_desc=None, - device=self.device, - ) - - self.save_image( - result[:72], - result_dir, - f"world_prediction_{n_epoch:04d}_{model.id:02d}.png", - logger, - ) - - return main_test_accuracy - - def renew_samples(self, nb, for_train=True): - input = self.train_input if for_train else self.test_input - nb = min(nb, input.size(0)) - input[:-nb] = input[nb:].clone() - input[-nb:] = world.generate_seq(nb, height=self.height, width=self.width).to( - self.device - ) - - def store_new_quizzes(self, new_quizzes, for_train=True): - if for_train: - self.train_quizzes.append(new_quizzes) - else: - self.test_quizzes.append(new_quizzes) - - def create_new_quizzes( - self, - n_epoch, - result_dir, - logger, - nb, - model, - other_models, - desired_average_logits=None, - ): - ############################################################### - # Generate quizzes with model - - quizzes = torch.empty( - nb, self.height * self.width * 2 + 1, device=self.device, dtype=torch.int64 - ) - - ar_mask = torch.full(quizzes.size(), 1, device=self.device) - summed_logits = torch.empty(nb, device=self.device) - - temperature = 1 - d_temperature = 1 - - while True: - summed_logits[...] = 0 - - masked_inplace_autoregression( - model=model, - batch_size=self.batch_size, - input=quizzes, - ar_mask=ar_mask, - summed_logits=summed_logits, - temperature=temperature, - deterministic_synthesis=False, - progress_bar_desc="creating quizzes", - device=self.device, - ) - - average_logits = summed_logits.mean() - - logger(f"{average_logits=} {desired_average_logits=}") - - if desired_average_logits is None: - break - - # Oh man that's ugly - if average_logits < desired_average_logits * 1.1: - if d_temperature > 0: - d_temperature *= -0.5 - temperature += d_temperature - elif average_logits > desired_average_logits: - if d_temperature < 0: - d_temperature *= -0.5 - temperature += d_temperature - else: - break - - logger(f"changing temperature to {temperature}") - - ############################################################### - # Create the reverse quizzes - - l = self.height * self.width - direction = quizzes[:, l : l + 1] - direction = world.token_forward * ( - direction == world.token_backward - ) + world.token_backward * (direction == world.token_forward) - reverse_quizzes = torch.cat( - [quizzes[:, l + 1 :], direction, quizzes[:, :l]], dim=1 - ) - - ar_mask = self.make_ar_mask(quizzes) - - ############################################################### - # Check how many of the other models can solve them in both - # directions - - nb_correct = [] - - for m in other_models: - result = quizzes.clone() - - masked_inplace_autoregression( - model=m, - batch_size=self.batch_size, - input=result, - ar_mask=ar_mask, - summed_logits=None, - temperature=1.0, - deterministic_synthesis=True, - progress_bar_desc="solving quizzes", - device=self.device, - ) - - correct = (quizzes == result).long().min(dim=-1).values - - reverse_result = reverse_quizzes.clone() - - masked_inplace_autoregression( - model=m, - batch_size=self.batch_size, - input=reverse_result, - ar_mask=ar_mask, - summed_logits=None, - temperature=1.0, - deterministic_synthesis=True, - progress_bar_desc="solving reversed quizzes", - device=self.device, - ) - - reverse_correct = ( - (reverse_quizzes == reverse_result).long().min(dim=-1).values - ) - - nb_correct.append((correct * reverse_correct)[None, :]) - - nb_correct = torch.cat(nb_correct, dim=0) - - # filename = os.path.join(result_dir, "correct_{n_epoch:04d}.dat") - # with open(filename, "w") as f: - # for k in nb_correct: - # f.write(f"{k}\n") - - return quizzes, nb_correct.sum(dim=0), summed_logits.mean() -- 2.39.5