From ee037c3687ac5a5f4ea5cd745f3833feaeb9e071 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 19 Sep 2024 13:14:57 +0200 Subject: [PATCH] Update. --- main.py | 142 +++++++--------- quiz_machine.py | 443 ------------------------------------------------ 2 files changed, 59 insertions(+), 526 deletions(-) delete mode 100755 quiz_machine.py diff --git a/main.py b/main.py index 750d1b1..0c40f95 100755 --- a/main.py +++ b/main.py @@ -14,9 +14,7 @@ from torch.nn import functional as F import ffutils import mygpt -import sky, grids, quiz_machine - -from quiz_machine import one_batch_masked_inplace_autoregression +import sky, grids import threading, subprocess @@ -254,26 +252,6 @@ assert args.nb_test_samples % args.batch_size == 0 ###################################################################### - -# ------------------------------------------------------ -alien_problem = grids.Grids( - max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100, - chunk_size=100, - nb_threads=args.nb_threads, - tasks="symmetry", -) - -alien_quiz_machine = quiz_machine.QuizMachine( - problem=alien_problem, - batch_size=args.eval_batch_size, - result_dir=args.result_dir, - logger=log_string, - device=main_device, -) -# ------------------------------------------------------ - -###################################################################### - problem = grids.Grids( max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100, chunk_size=100, @@ -284,19 +262,58 @@ problem = grids.Grids( if not args.resume: problem.save_some_examples(args.result_dir) -quiz_machine = quiz_machine.QuizMachine( - problem=problem, - batch_size=args.eval_batch_size, - result_dir=args.result_dir, - logger=log_string, - device=main_device, -) + +def pure_noise(nb, device): + r = problem.pure_noise(nb, device) + r = r.view(r.size(0), 4, -1)[:, :, 1:].reshape(r.size(0), -1) + return r + + +def quiz_set(nb_samples, c_quizzes, c_quiz_multiplier=1): + if c_quizzes is None: + quizzes = problem.generate_w_quizzes(nb_samples) + quizzes = quizzes.view(quizzes.size(0), 4, -1)[:, :, 1:].reshape( + quizzes.size(0), -1 + ) + nb_w_quizzes = quizzes.size(0) + nb_c_quizzes = 0 + else: + if c_quiz_multiplier > 1: + n = min(c_quiz_multiplier, (nb_samples // 2) // c_quizzes.size(0)) + body = c_quizzes.repeat(n, 1) + if n < c_quiz_multiplier: + tail = c_quizzes[ + torch.randperm(c_quizzes.size(0))[: nb_samples // 2 - body.size(0)] + ] + c_quizzes = torch.cat([body, tail], dim=0) + else: + c_quizzes = body + + if c_quizzes.size(0) > nb_samples // 2: + i = torch.randperm(c_quizzes.size(0))[: nb_samples // 2] + c_quizzes = c_quizzes[i] + + w_quizzes = problem.generate_w_quizzes(nb_samples - c_quizzes.size(0)) + w_quizzes = w_quizzes.view(w_quizzes.size(0), 4, -1)[:, :, 1:].reshape( + w_quizzes.size(0), -1 + ) + quizzes = torch.cat([w_quizzes, c_quizzes], dim=0) + nb_w_quizzes = w_quizzes.size(0) + nb_c_quizzes = c_quizzes.size(0) + + i = torch.randperm(quizzes.size(0), device=quizzes.device) + quizzes = quizzes[i].contiguous() + + log_string(f"quiz_set nb_w_quizzes {nb_w_quizzes} nb_c_quizzes {nb_c_quizzes}") + + return quizzes + ###################################################################### log_string(f"main_device {main_device} gpus {[ str(g) for g in gpus]}") -vocabulary_size = quiz_machine.vocabulary_size() +vocabulary_size = problem.nb_token_values log_string(f"vocabulary_size {vocabulary_size}") @@ -344,7 +361,7 @@ def add_noise_imt(imt_set): """Replace every component of the input by a random value with probability args.proba_prompt_noise.""" input, masks, targets = imt_set[:, 0], imt_set[:, 1], imt_set[:, 2] - noise = quiz_machine.pure_noise(input.size(0), input.device) + noise = pure_noise(input.size(0), input.device) change = (1 - masks) * ( torch.rand(input.size(), device=input.device) < args.proba_prompt_noise ).long() @@ -432,7 +449,7 @@ def samples_for_generation_imt(input): proba_erased = 1 - (1 - args.diffusion_proba_corruption) ** t mask_erased = (r <= proba_erased[:, None]).long() - noise = quiz_machine.pure_noise(nb, input.device) + noise = pure_noise(nb, input.device) targets = input input = (1 - mask_erased) * input + mask_erased * noise masks = input.new_full(input.size(), 1) @@ -456,7 +473,7 @@ def ae_generate(model, nb, local_device=main_device): # mini-batches second so that we keep only the samples that have # not stabilized - all_input = quiz_machine.pure_noise(nb, local_device) + all_input = pure_noise(nb, local_device) all_masks = all_input.new_full(all_input.size(), 1) all_changed = torch.full((all_input.size(0),), True, device=all_input.device) @@ -499,7 +516,7 @@ def ae_generate(model, nb, local_device=main_device): def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True): - quizzes = quiz_machine.quiz_set( + quizzes = quiz_set( args.nb_train_samples if train else args.nb_test_samples, c_quizzes, args.c_quiz_multiplier, @@ -572,22 +589,18 @@ def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device): # Save some original world quizzes and the full prediction (the four grids) - quizzes = quiz_machine.quiz_set(25, c_quizzes, args.c_quiz_multiplier).to( - local_device - ) - quiz_machine.problem.save_quizzes_as_image( + quizzes = quiz_set(25, c_quizzes, args.c_quiz_multiplier).to(local_device) + problem.save_quizzes_as_image( args.result_dir, f"test_{n_epoch}_{model.id}.png", quizzes=quizzes ) result = predict_full(model=model, input=quizzes, local_device=local_device) - quiz_machine.problem.save_quizzes_as_image( + problem.save_quizzes_as_image( args.result_dir, f"test_{n_epoch}_{model.id}_predict_full.png", quizzes=result ) # Save some images of the prediction results - quizzes = quiz_machine.quiz_set( - args.nb_test_samples, c_quizzes, args.c_quiz_multiplier - ) + quizzes = quiz_set(args.nb_test_samples, c_quizzes, args.c_quiz_multiplier) imt_set = samples_for_prediction_imt(quizzes.to(local_device)) result = ae_predict(model, imt_set, local_device=local_device).to("cpu") masks = imt_set[:, 1].to("cpu") @@ -598,7 +611,7 @@ def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device): ] predicted_parts = correct_parts.abs() - quiz_machine.problem.save_quizzes_as_image( + problem.save_quizzes_as_image( args.result_dir, f"culture_prediction_{n_epoch}_{model.id}.png", quizzes=result[:128], @@ -618,7 +631,7 @@ def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device): # Save some images of the ex nihilo generation of the four grids result = ae_generate(model, 150, local_device=local_device).to("cpu") - quiz_machine.problem.save_quizzes_as_image( + problem.save_quizzes_as_image( args.result_dir, f"culture_generation_{n_epoch}_{model.id}.png", quizzes=result[:128], @@ -785,7 +798,7 @@ def save_quiz_image(models, c_quizzes, filename, local_device=main_device): comments = [f"nb_correct {c} nb_wrong {w}" for c, w in zip(nb_correct, nb_wrong)] - quiz_machine.problem.save_quizzes_as_image( + problem.save_quizzes_as_image( args.result_dir, filename, quizzes=c_quizzes, @@ -837,43 +850,6 @@ nb_parameters = sum(p.numel() for p in models[0].parameters()) log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)") -###################################################################### - -if args.quizzes is not None: - with open(args.quizzes, "r") as file: - txt = file.read() - - quizzes = quiz_machine.problem.text2quiz(txt) - - record = [] - - quizzes = quizzes.to(main_device) - for model in models: - log_string(f"processing {model.id} {args.quizzes}") - for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]: - mask_generate = quiz_machine.make_quiz_mask( - quizzes=quizzes, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad - ) - result = ae_generate(model, (1 - mask_generate) * quizzes, mask_generate) - record.append(result) - - result = torch.cat(record, dim=0) - - filename = "result.png" - - quiz_machine.problem.save_quizzes_as_image( - args.result_dir, - filename, - quizzes=result, - delta=True, - nrow=8, - ) - - log_string(f"wrote {filename}") - - exit(0) - - ###################################################################### c_quizzes = None diff --git a/quiz_machine.py b/quiz_machine.py deleted file mode 100755 index 72f1d16..0000000 --- a/quiz_machine.py +++ /dev/null @@ -1,443 +0,0 @@ -#!/usr/bin/env python - -# Any copyright is dedicated to the Public Domain. -# https://creativecommons.org/publicdomain/zero/1.0/ - -# Written by Francois Fleuret - -import math, os, tqdm, warnings, sys - -import torch, torchvision - -from torch import nn -from torch.nn import functional as F - -import mygpt -from mygpt import BracketedSequence - -import threading - -###################################################################### - -# ar_mask is a tensor with 0s and 1s, of same shape as input, with -# 1s where tokens should be generated. The others are kept -# unchanged. - - -def one_batch_masked_inplace_autoregression( - model, - input, - ar_mask, - acc_seq_logprobas, - deterministic_synthesis=False, -): - if input.size(0) == 0: - return - - to_generate = (ar_mask.sum(0) > 0).nonzero() - - if to_generate.min() > 0: - model( - BracketedSequence(input, 0, to_generate.min()) - ) # Needed to initialize the model's cache - for s in range(to_generate.min(), to_generate.max() + 1): - output = model(BracketedSequence(input, s, 1)).x - - logits = output[:, s] - - if deterministic_synthesis: - t_next = logits.argmax(-1) - else: - dist = torch.distributions.categorical.Categorical(logits=logits) - t_next = dist.sample() - - all_n = torch.arange(t_next.size(0)) - - acc_seq_logprobas += ar_mask[:, s] * logits.log_softmax(dim=1)[all_n, t_next] - - input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s] - - -###################################################################### - - -class QuizMachine: - def __init__( - self, - problem, - batch_size, - result_dir, - logger, - device=torch.device("cpu"), - ): - super().__init__() - - self.problem = problem - self.batch_size = batch_size - self.device = device - self.logger = logger - self.prompt_len = None - self.answer_len = None - - # quad_order, quad_generate, quad_noise, quad_loss - self.train_structures = [ - (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)), - (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)), - (("B", "f_B", "A", "f_A"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)), - (("f_B", "B", "f_A", "A"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)), - # (("B", "f_B", "A", "f_A"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 1)), - # (("f_B", "B", "f_A", "A"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 1)), - (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)), - ] - - self.test_structures = self.train_structures - - def vocabulary_size(self): - return self.problem.nb_token_values - - ###################################################################### - - def autoregression( - self, - model, - input, - ar_mask, - seq_logprobas, - progress_bar_desc=None, - ): - assert input.size() == ar_mask.size() - - batches = zip( - input.split(self.batch_size), - ar_mask.split(self.batch_size), - seq_logprobas.split(self.batch_size), - ) - - if progress_bar_desc is not None: - batches = tqdm.tqdm( - batches, - dynamic_ncols=True, - desc=progress_bar_desc, - total=(input.size(0) + self.batch_size - 1) // self.batch_size, - ) - - with torch.autograd.no_grad(): - t = model.training - model.eval() - - for input, ar_mask, seq_logprobas in batches: - one_batch_masked_inplace_autoregression( - model=model, - input=input, - ar_mask=ar_mask, - acc_seq_logprobas=seq_logprobas, - deterministic_synthesis=False, - ) - - model.train(t) - - ###################################################################### - - def data_input( - self, nb_samples, c_quiz_bags=[], c_quiz_multiplier=1, data_structures=None - ): - if data_structures is None: - data_structures = self.train_structures - - if len(c_quiz_bags) > 0: - c_quizzes = torch.cat(c_quiz_bags, dim=0) - - if c_quiz_multiplier > 1: - n = min(c_quiz_multiplier, (nb_samples // 2) // c_quizzes.size(0)) - body = c_quizzes.repeat(n, 1) - if n < c_quiz_multiplier: - tail = c_quizzes[ - torch.randperm(c_quizzes.size(0))[ - : nb_samples // 2 - body.size(0) - ] - ] - c_quizzes = torch.cat([body, tail], dim=0) - else: - c_quizzes = body - - if c_quizzes.size(0) > nb_samples // 2: - i = torch.randperm(c_quizzes.size(0))[: nb_samples // 2] - c_quizzes = c_quizzes[i] - - w_quizzes = self.problem.generate_w_quizzes(nb_samples - c_quizzes.size(0)) - quizzes = torch.cat([w_quizzes, c_quizzes], dim=0) - else: - quizzes = self.problem.generate_w_quizzes(nb_samples) - - # shuffle - - i = torch.randperm(quizzes.size(0), device=quizzes.device) - quizzes = quizzes[i] - - # Re-order and inject noise - - quiz_mask_generate = quizzes.new_full(quizzes.size(), 1) - quiz_mask_loss = quizzes.new_full(quizzes.size(), 1) - order_ids = torch.randint(len(data_structures), (quizzes.size(0),)) - - for j, s in enumerate(data_structures): - quad_order, quad_generate, quad_noise, quad_loss = s - i = order_ids == j - quizzes[i] = self.problem.reconfigure(quizzes[i], quad_order=quad_order) - quiz_mask_generate[i] = self.make_quiz_mask( - quizzes=quizzes[i], quad_order=quad_order, quad_mask=quad_generate - ) - quiz_mask_loss[i] = self.make_quiz_mask( - quizzes=quizzes[i], quad_order=quad_order, quad_mask=quad_loss - ) - - return quizzes, quiz_mask_generate, quiz_mask_loss - - ###################################################################### - - def pure_noise(self, nb, device): - r = self.problem.pure_noise(nb, device) - r = r.view(r.size(0), 4, -1)[:, :, 1:].reshape(r.size(0), -1) - return r - - def quiz_set(self, nb_samples, c_quizzes, c_quiz_multiplier=1): - if c_quizzes is None: - quizzes = self.problem.generate_w_quizzes(nb_samples) - quizzes = quizzes.view(quizzes.size(0), 4, -1)[:, :, 1:].reshape( - quizzes.size(0), -1 - ) - nb_w_quizzes = quizzes.size(0) - nb_c_quizzes = 0 - else: - if c_quiz_multiplier > 1: - n = min(c_quiz_multiplier, (nb_samples // 2) // c_quizzes.size(0)) - body = c_quizzes.repeat(n, 1) - if n < c_quiz_multiplier: - tail = c_quizzes[ - torch.randperm(c_quizzes.size(0))[ - : nb_samples // 2 - body.size(0) - ] - ] - c_quizzes = torch.cat([body, tail], dim=0) - else: - c_quizzes = body - - if c_quizzes.size(0) > nb_samples // 2: - i = torch.randperm(c_quizzes.size(0))[: nb_samples // 2] - c_quizzes = c_quizzes[i] - - w_quizzes = self.problem.generate_w_quizzes(nb_samples - c_quizzes.size(0)) - w_quizzes = w_quizzes.view(w_quizzes.size(0), 4, -1)[:, :, 1:].reshape( - w_quizzes.size(0), -1 - ) - quizzes = torch.cat([w_quizzes, c_quizzes], dim=0) - nb_w_quizzes = w_quizzes.size(0) - nb_c_quizzes = c_quizzes.size(0) - - i = torch.randperm(quizzes.size(0), device=quizzes.device) - quizzes = quizzes[i].contiguous() - - logger(f"quiz_set nb_w_quizzes {nb_w_quizzes} nb_c_quizzes {nb_c_quizzes}") - - return quizzes - - ###################################################################### - - def make_quiz_mask(self, quizzes, quad_order, quad_mask): - assert quad_order in [s for s, _, _, _ in self.train_structures] - return self.problem.make_quiz_mask( - quizzes, quad_order=quad_order, quad_mask=quad_mask - ) - - ###################################################################### - - def predict(self, model, quizzes, quad_order, quad_mask): - quizzes = quizzes.to(self.device) - ar_mask = self.make_quiz_mask( - quizzes=quizzes, quad_order=quad_order, quad_mask=quad_mask - ) - result = quizzes * (1 - ar_mask) - - seq_logprobas = torch.zeros(quizzes.size(0), device=self.device) - - self.autoregression( - model=model, - input=result, - ar_mask=ar_mask, - seq_logprobas=seq_logprobas, - progress_bar_desc="autoregression", - ) - - correct = (result == quizzes).min(dim=1).values.long() - - # result = result.to("cpu") - # correct = correct.to("cpu") - # seq_logprobas = seq_logprobas.to("cpu") - - return result, correct, seq_logprobas - - ###################################################################### - - def produce_results(self, n_epoch, model, input, result_dir): - input = input.to(self.device) - result = input.new(input.size()) - correct = input.new(input.size(0)) - predicted_parts = input.new(input.size(0), 4) - - nb = 0 - - # We consider all the configurations that we train for - for quad_order, quad_generate, _, _ in self.test_structures: - i = self.problem.indices_select(quizzes=input, quad_order=quad_order) - nb += i.long().sum() - result[i], correct[i], _ = self.predict( - model=model, quizzes=input[i], quad_order=quad_order, quad=quad_generate - ) - - predicted_parts[i] = torch.tensor(quad_generate, device=self.device)[ - None, : - ] - solution_is_deterministic = predicted_parts[i].sum(dim=-1) == 1 - correct[i] = (2 * correct[i] - 1) * (solution_is_deterministic).long() - - assert nb == input.size(0) - - nb_correct = (correct == 1).long().sum() - nb_total = (correct != 0).long().sum() - self.logger( - f"test_accuracy {n_epoch} model {model.id} val {nb_correct} / {nb_total}" - ) - - test_accuracy = (nb_correct / nb_total).item() - - ############################## - - correct_parts = predicted_parts * correct[:, None] - - result = result[:128] - predicted_parts = predicted_parts[:128] - correct_parts = correct_parts[:128] - - self.problem.save_quizzes_as_image( - result_dir, - f"culture_prediction_{n_epoch:04d}_{model.id:02d}.png", - quizzes=result, - predicted_parts=predicted_parts, - correct_parts=correct_parts, - ) - - return test_accuracy - - ###################################################################### - - def randomize_configuations_inplace(self, quizzes, quad_orders): - r = torch.randint(len(quad_orders), (quizzes.size(0),), device=quizzes.device) - for c in range(len(quad_orders)): - quizzes[r == c] = self.problem.reconfigure( - quizzes[r == c], quad_order=quad_orders[c] - ) - - ###################################################################### - - def store_c_quizzes(self, new_c_quizzes, for_train=True): - with self.LOCK_C_QUIZZES: - if for_train: - self.train_c_quizzes.append(new_c_quizzes.to("cpu")) - else: - self.test_c_quizzes.append(new_c_quizzes.to("cpu")) - - def save_c_quizzes(self, filename): - torch.save((self.train_c_quizzes, self.test_c_quizzes), filename) - - def load_c_quizzes(self, filename): - self.train_c_quizzes, self.test_c_quizzes = torch.load(filename) - - ###################################################################### - - def models_logprobas( - self, - model, - c_quizzes, - quad_order, - quad_loss, - quad_noise=None, - temperature=1.0, - device=None, - ): - if device is None: - device = self.device - - c_quizzes = self.problem.reconfigure(c_quizzes, quad_order) - - seq_logprobas = torch.zeros( - c_quizzes.size(0), - device=device, - ) - - with torch.autograd.no_grad(): - t = model.training - model.eval() - - for input, l in zip( - c_quizzes.split(self.batch_size), - seq_logprobas.split(self.batch_size), - ): - input = input.to(device) - quiz_mask_loss = self.make_quiz_mask( - input, quad_order=quad_order, quad_mask=quad_loss - ) - output = model(mygpt.BracketedSequence(input)).x / temperature - l[...] = ( - -F.cross_entropy(output.transpose(1, 2), input, reduction="none") - * quiz_mask_loss - ).sum(dim=1) - - model.train(t) - - return seq_logprobas.to("cpu") - - ###################################################################### - - def generate_c_quizzes(self, nb, model_for_generation, procedure, recorder=None): - seq_logprobas = torch.zeros(nb, device=self.device) - - c_quizzes = None - - for n_step, setup in enumerate(procedure): - quad_order, quad_generate, model_modifier = setup - if c_quizzes is None: - c_quizzes = self.problem.create_empty_quizzes(nb, quad_order) - c_quizzes = c_quizzes.to(self.device) - elif quad_order != pred_quad_order: - c_quizzes = self.problem.reconfigure(c_quizzes, quad_order) - pred_quad_order = quad_order - - if model_modifier is not None: - model_modifier(model_for_generation) - - self.autoregression( - model=model_for_generation, - input=c_quizzes, - ar_mask=self.make_quiz_mask( - quizzes=c_quizzes, quad_order=quad_order, quad_mask=quad_generate - ), - seq_logprobas=seq_logprobas, - progress_bar_desc=f"autoregression {n_step+1}/{len(procedure)}", - ) - - model_for_generation.reset_transformations() - - if recorder is not None: - x = c_quizzes.clone() - t = torch.tensor(quad_generate, device=x.device)[None, :].expand( - x.size(0), -1 - ) - recorder.append( - self.problem.reconfigure([x, t], ("A", "f_A", "B", "f_B")) - ) - - c_quizzes = self.problem.reconfigure(c_quizzes, ("A", "f_A", "B", "f_B")) - - return c_quizzes.to("cpu") - - ###################################################################### -- 2.39.5