From: Francois Fleuret Date: Fri, 29 Apr 2022 11:58:55 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=95d8b6bc41a753f7a12b2a4cd047ea11cdc2054f;p=mygpt.git Update. --- diff --git a/main.py b/main.py new file mode 100755 index 0000000..a6940f1 --- /dev/null +++ b/main.py @@ -0,0 +1,414 @@ +#!/usr/bin/env python + +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + +import math, sys, argparse, time, tqdm, itertools + +import torch, torchtext, torchvision +from torch import nn +from torch.nn import functional as F + +import mygpt + +###################################################################### + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +###################################################################### + +parser = argparse.ArgumentParser(description = 'My own GPT.') + +parser.add_argument('--log_filename', + type = str, default = 'train.log') + +parser.add_argument('--download', + type = bool, default = False) + +parser.add_argument('--seed', + type = int, default = 0) + +parser.add_argument('--nb_epochs', + type = int, default = 100) + +parser.add_argument('--batch_size', + type = int, default = 25) + +parser.add_argument('--data', + type = str, default = 'wiki103') + +parser.add_argument('--data_size', + type = int, default = -1) + +parser.add_argument('--optim', + type = str, default = 'adam') + +parser.add_argument('--learning_rate', + type = float, default = 1e-4) + +parser.add_argument('--dim_model', + type = int, default = 512) + +parser.add_argument('--dim_keys', + type = int, default = 64) + +parser.add_argument('--dim_hidden', + type = int, default = 2048) + +parser.add_argument('--nb_heads', + type = int, default = 8) + +parser.add_argument('--nb_blocks', + type = int, default = 12) + +parser.add_argument('--dropout', + type = float, default = 0.1) + +parser.add_argument('--synthesis_sampling', + type = bool, default = True) + +###################################################################### + +args = parser.parse_args() + +log_file = open(args.log_filename, 'w') + +if args.seed >= 0: + torch.manual_seed(args.seed) + +###################################################################### + +def log_string(s): + t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime()) + + if log_file is not None: + log_file.write(t + s + '\n') + log_file.flush() + + print(t + s) + sys.stdout.flush() + +for n in vars(args): + log_string(f'args.{n} {getattr(args, n)}') + +###################################################################### + +class Task: + def batches(self, split = 'train'): + pass + + def vocabulary_size(self): + pass + + def produce_results(self, n_epoch, model, nb_tokens = 50): + pass + +###################################################################### + +import picoclvr + +class TaskPicoCLVR(Task): + + def __init__(self, batch_size, height = 6, width = 8, device = torch.device('cpu')): + self.batch_size = batch_size + self.device = device + nb = args.data_size if args.data_size > 0 else 250000 + + descr = picoclvr.generate(nb, height = height, width = width) + descr = [ s.strip().split(' ') for s in descr ] + l = max([ len(s) for s in descr ]) + descr = [ s + [ '' ] * (l - len(s)) for s in descr ] + + tokens = set() + for s in descr: + for t in s: tokens.add(t) + self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ]) + self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ]) + + t = [ [ self.token2id[u] for u in s ] for s in descr ] + data_input = torch.tensor(t, device = self.device) + + self.test_input = data_input[:nb // 5] + self.train_input = data_input[nb // 5:] + + def batches(self, split = 'train'): + assert split in { 'train', 'test' } + if split == 'train': + for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = 'epoch'): + yield batch + else: + for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = 'epoch'): + yield batch + + def vocabulary_size(self): + return len(self.token2id) + + def produce_results(self, n_epoch, model, nb_tokens = 50): + img = [ ] + nb_per_primer = 8 + + for primer in [ + 'red above green green top blue right of red ', + 'there is red there is yellow there is blue ', + 'red below yellow yellow below green green below blue red right yellow left green right blue left ', + 'green bottom yellow bottom green left of blue yellow right of blue blue top ', + ]: + + for k in range(nb_per_primer): + t_primer = primer.strip().split(' ') + t_generated = [ ] + + for j in range(nb_tokens): + t = [ [ self.token2id[u] for u in t_primer + t_generated ] ] + input = torch.tensor(t, device = self.device) + output = model(input) + logits = output[0, -1] + if args.synthesis_sampling: + dist = torch.distributions.categorical.Categorical(logits = logits) + t = dist.sample() + else: + t = logits.argmax() + t_generated.append(self.id2token[t.item()]) + + descr = [ ' '.join(t_primer + t_generated) ] + img += [ picoclvr.descr2img(descr) ] + + img = torch.cat(img, 0) + file_name = f'result_picoclvr_{n_epoch:04d}.png' + torchvision.utils.save_image(img / 255., + file_name, nrow = nb_per_primer, pad_value = 0.8) + log_string(f'wrote {file_name}') + +###################################################################### + +class TaskWiki103(Task): + + def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100, + device = torch.device('cpu')): + + self.batch_size = batch_size + self.len_min = len_min + self.len_max = len_max + self.min_freq = min_freq + self.device = device + + self.tokenizer = torchtext.data.get_tokenizer('basic_english') + train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/') + + # Mostly for debug + if args.data_size > 0: + train_iter = itertools.islice(train_iter, args.data_size) + + def yield_tokens(): + for l in tqdm.tqdm(train_iter, desc = 'vocab'): + yield self.tokenizer(l) + + self.vocab = torchtext.vocab.build_vocab_from_iterator( + yield_tokens(), + specials = [ '', '' ], + min_freq = self.min_freq + ) + + self.vocab.set_default_index(self.vocab[ '' ]) + + def tensorize(self, s): + a = max(len(x) for x in s) + return torch.tensor([ self.vocab(x + [ '' ] * (a - len(x))) for x in s ]) + + def yield_batches(self, ds): + s = [ ] + for l in ds: + q = self.tokenizer(l) + if len(q) >= self.len_min and len(q) <= self.len_max: + s += [ q ] + if len(s) == self.batch_size: + yield self.tensorize(s) + s = [ ] + + if len(s) > 0: + yield self.tensorize(s) + + def batches(self, split = 'train'): + data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/') + + # Mostly for debug + if args.data_size > 0: + data_iter = itertools.islice(data_iter, args.data_size) + + return self.yield_batches(tqdm.tqdm(data_iter, desc = 'epoch')) + + def vocabulary_size(self): + return len(self.vocab) + + def produce_results(self, n_epoch, model, nb_tokens = 50): + file_name = f'result_wiki103_{n_epoch:04d}.txt' + + with open(file_name, 'w') as outfile: + for primer in [ + 'the cat is hunting a', + 'paris is the capital', + 'cars are convenient', + 'the difference between men and women is', + 'the object was blue all over and green all over it was', + 'cherries are red and lemons are', + 'cherries are sweet and lemons are', + 'two plus three equals', + 'deep learning is', + ]: + t_primer = self.tokenizer(primer) + t_generated = [ ] + + for j in range(nb_tokens): + + input = self.tensorize([ t_primer + t_generated ]).to(self.device) + output = model(input) + logits = output[0, -1] + if args.synthesis_sampling: + dist = torch.distributions.categorical.Categorical(logits = logits) + t = dist.sample() + else: + t = logits.argmax() + t_generated.append(self.vocab.lookup_token(t)) + if t_generated[-1] == '': break + + s = ' '.join(t_generated) + + outfile.write(f'<{primer}> {s}\n') + + log_string(f'wrote {file_name}') + +###################################################################### + +class TaskMNIST(Task): + + def __init__(self, batch_size, device = torch.device('cpu')): + self.device = device + self.batch_size = batch_size + + def batches(self, split = 'train'): + assert split in { 'train', 'test' } + data_set = torchvision.datasets.MNIST( + root = './data', train = (split == 'train'), + download = True + ) + data_input = data_set.data.view(-1, 28 * 28).long() + if args.data_size >= 0: + data_input = data_input[:args.data_size] + for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = 'epoch'): + yield batch + + def vocabulary_size(self): + return 256 + + def produce_results(self, n_epoch, model, nb_samples = 64): + results = torch.zeros(nb_samples, 28 * 28, dtype = torch.int64, device = self.device) + for input in results.split(self.batch_size): + for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'): + output = model(input) + logits = output[:, s] + if args.synthesis_sampling: + dist = torch.distributions.categorical.Categorical(logits = logits) + t = dist.sample() + else: + t = logits.argmax(1) + input[:, s + 1] = t + + image_name = f'result_mnist_{n_epoch:04d}.png' + torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255., + image_name, nrow = 16, pad_value = 0.8) + log_string(f'wrote {image_name}') + +###################################################################### + +def check_causality(model): + #m = model[1:] + input = torch.rand(1, 5, dim_model).requires_grad_() + output = m(input) + a = torch.zeros(output.size(1), input.size(1)) + for k in range(output.size(1)): + for d in range(output.size(2)): + g, = torch.autograd.grad(output[0, k, d], input, retain_graph = True) + a[k] += g.squeeze(0).pow(2).sum(1) + print(a) + +###################################################################### + +log_string(f'device {device}') + +if args.data == 'wiki103': + task = TaskWiki103(batch_size = args.batch_size, device = device) +elif args.data == 'mnist': + task = TaskMNIST(batch_size = args.batch_size, device = device) +elif args.data == 'picoclvr': + task = TaskPicoCLVR(batch_size = args.batch_size, device = device) +else: + raise ValueError(f'Unknown dataset {args.data}.') + +vocabulary_size = task.vocabulary_size() + +log_string(f'vocabulary_size {vocabulary_size}') + +############################## + +model = mygpt.MyGPT( + vocabulary_size = vocabulary_size, + dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden, + nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout +) + +nb_parameters = sum(p.numel() for p in model.parameters()) +log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)') + +model.to(device) + +###################################################################### + +if args.optim == 'sgd': + optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate) +elif args.optim == 'adam': + optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate) +elif args.optim == 'adamw': + optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate) +else: + raise ValueError(f'Unknown optimizer {args.optim}.') + +for k in range(args.nb_epochs): + + model.train() + + nb_train_samples, acc_train_loss = 0, 0.0 + + for input in task.batches(split = 'train'): + input = input.to(device) + output = model(input) + loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:]) + acc_train_loss += loss.item() * input.size(0) + nb_train_samples += input.size(0) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + with torch.autograd.no_grad(): + + model.eval() + + nb_test_samples, acc_test_loss = 0, 0.0 + + for input in task.batches(split = 'test'): + input = input.to(device) + output = model(input) + loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:]) + acc_test_loss += loss.item() * input.size(0) + nb_test_samples += input.size(0) + + train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples)) + test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples)) + + log_string(f'perplexity {k+1} train {train_perplexity} test {test_perplexity}') + + task.produce_results(k, model) + +###################################################################### diff --git a/mygpt.py b/mygpt.py index 13fbe8e..7bf25b5 100755 --- a/mygpt.py +++ b/mygpt.py @@ -5,92 +5,13 @@ # Written by Francois Fleuret -import math, sys, argparse, time, tqdm, itertools +import math + +import torch -import torch, torchtext, torchvision from torch import nn from torch.nn import functional as F -###################################################################### - -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - -###################################################################### - -parser = argparse.ArgumentParser(description = 'My own GPT.') - -parser.add_argument('--log_filename', - type = str, default = 'train.log') - -parser.add_argument('--download', - type = bool, default = False) - -parser.add_argument('--seed', - type = int, default = 0) - -parser.add_argument('--nb_epochs', - type = int, default = 100) - -parser.add_argument('--batch_size', - type = int, default = 25) - -parser.add_argument('--data', - type = str, default = 'wiki103') - -parser.add_argument('--data_size', - type = int, default = -1) - -parser.add_argument('--optim', - type = str, default = 'adam') - -parser.add_argument('--learning_rate', - type = float, default = 1e-4) - -parser.add_argument('--dim_model', - type = int, default = 512) - -parser.add_argument('--dim_keys', - type = int, default = 64) - -parser.add_argument('--dim_hidden', - type = int, default = 2048) - -parser.add_argument('--nb_heads', - type = int, default = 8) - -parser.add_argument('--nb_blocks', - type = int, default = 12) - -parser.add_argument('--dropout', - type = float, default = 0.1) - -parser.add_argument('--synthesis_sampling', - type = bool, default = True) - -###################################################################### - -args = parser.parse_args() - -log_file = open(args.log_filename, 'w') - -if args.seed >= 0: - torch.manual_seed(args.seed) - -###################################################################### - -def log_string(s): - t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime()) - - if log_file is not None: - log_file.write(t + s + '\n') - log_file.flush() - - print(t + s) - sys.stdout.flush() - -for n in vars(args): - log_string(f'args.{n} {getattr(args, n)}') - ############################## class Residual(nn.Module): @@ -198,321 +119,3 @@ class MyGPT(nn.Module): return x ###################################################################### - -class Task: - def batches(self, split = 'train'): - pass - - def vocabulary_size(self): - pass - - def produce_results(self, n_epoch, model, nb_tokens = 50): - pass - -###################################################################### - -import picoclvr - -class TaskPicoCLVR(Task): - - def __init__(self, batch_size, height = 6, width = 8, device = torch.device('cpu')): - self.batch_size = batch_size - self.device = device - nb = args.data_size if args.data_size > 0 else 250000 - - descr = picoclvr.generate(nb, height = height, width = width) - descr = [ s.strip().split(' ') for s in descr ] - l = max([ len(s) for s in descr ]) - descr = [ s + [ '' ] * (l - len(s)) for s in descr ] - - tokens = set() - for s in descr: - for t in s: tokens.add(t) - self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ]) - self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ]) - - t = [ [ self.token2id[u] for u in s ] for s in descr ] - data_input = torch.tensor(t, device = self.device) - - self.test_input = data_input[:nb // 5] - self.train_input = data_input[nb // 5:] - - def batches(self, split = 'train'): - assert split in { 'train', 'test' } - if split == 'train': - for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = 'epoch'): - yield batch - else: - for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = 'epoch'): - yield batch - - def vocabulary_size(self): - return len(self.token2id) - - def produce_results(self, n_epoch, model, nb_tokens = 50): - img = [ ] - nb_per_primer = 8 - - for primer in [ - 'red above green green top blue right of red ', - 'there is red there is yellow there is blue ', - 'red below yellow yellow below green green below blue red right yellow left green right blue left ', - 'green bottom yellow bottom green left of blue yellow right of blue blue top ', - ]: - - for k in range(nb_per_primer): - t_primer = primer.strip().split(' ') - t_generated = [ ] - - for j in range(nb_tokens): - t = [ [ self.token2id[u] for u in t_primer + t_generated ] ] - input = torch.tensor(t, device = self.device) - output = model(input) - logits = output[0, -1] - if args.synthesis_sampling: - dist = torch.distributions.categorical.Categorical(logits = logits) - t = dist.sample() - else: - t = logits.argmax() - t_generated.append(self.id2token[t.item()]) - - descr = [ ' '.join(t_primer + t_generated) ] - img += [ picoclvr.descr2img(descr) ] - - img = torch.cat(img, 0) - file_name = f'result_picoclvr_{n_epoch:04d}.png' - torchvision.utils.save_image(img / 255., - file_name, nrow = nb_per_primer, pad_value = 0.8) - log_string(f'wrote {file_name}') - -###################################################################### - -class TaskWiki103(Task): - - def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100, - device = torch.device('cpu')): - - self.batch_size = batch_size - self.len_min = len_min - self.len_max = len_max - self.min_freq = min_freq - self.device = device - - self.tokenizer = torchtext.data.get_tokenizer('basic_english') - train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/') - - # Mostly for debug - if args.data_size > 0: - train_iter = itertools.islice(train_iter, args.data_size) - - def yield_tokens(): - for l in tqdm.tqdm(train_iter, desc = 'vocab'): - yield self.tokenizer(l) - - self.vocab = torchtext.vocab.build_vocab_from_iterator( - yield_tokens(), - specials = [ '', '' ], - min_freq = self.min_freq - ) - - self.vocab.set_default_index(self.vocab[ '' ]) - - def tensorize(self, s): - a = max(len(x) for x in s) - return torch.tensor([ self.vocab(x + [ '' ] * (a - len(x))) for x in s ]) - - def yield_batches(self, ds): - s = [ ] - for l in ds: - q = self.tokenizer(l) - if len(q) >= self.len_min and len(q) <= self.len_max: - s += [ q ] - if len(s) == self.batch_size: - yield self.tensorize(s) - s = [ ] - - if len(s) > 0: - yield self.tensorize(s) - - def batches(self, split = 'train'): - data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/') - - # Mostly for debug - if args.data_size > 0: - data_iter = itertools.islice(data_iter, args.data_size) - - return self.yield_batches(tqdm.tqdm(data_iter, desc = 'epoch')) - - def vocabulary_size(self): - return len(self.vocab) - - def produce_results(self, n_epoch, model, nb_tokens = 50): - file_name = f'result_wiki103_{n_epoch:04d}.txt' - - with open(file_name, 'w') as outfile: - for primer in [ - 'the cat is hunting a', - 'paris is the capital', - 'cars are convenient', - 'the difference between men and women is', - 'the object was blue all over and green all over it was', - 'cherries are red and lemons are', - 'cherries are sweet and lemons are', - 'two plus three equals', - 'deep learning is', - ]: - t_primer = self.tokenizer(primer) - t_generated = [ ] - - for j in range(nb_tokens): - - input = self.tensorize([ t_primer + t_generated ]).to(self.device) - output = model(input) - logits = output[0, -1] - if args.synthesis_sampling: - dist = torch.distributions.categorical.Categorical(logits = logits) - t = dist.sample() - else: - t = logits.argmax() - t_generated.append(self.vocab.lookup_token(t)) - if t_generated[-1] == '': break - - s = ' '.join(t_generated) - - outfile.write(f'<{primer}> {s}\n') - - log_string(f'wrote {file_name}') - -###################################################################### - -class TaskMNIST(Task): - - def __init__(self, batch_size, device = torch.device('cpu')): - self.device = device - self.batch_size = batch_size - - def batches(self, split = 'train'): - assert split in { 'train', 'test' } - data_set = torchvision.datasets.MNIST( - root = './data', train = (split == 'train'), - download = True - ) - data_input = data_set.data.view(-1, 28 * 28).long() - if args.data_size >= 0: - data_input = data_input[:args.data_size] - for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = 'epoch'): - yield batch - - def vocabulary_size(self): - return 256 - - def produce_results(self, n_epoch, model, nb_samples = 64): - results = torch.zeros(nb_samples, 28 * 28, dtype = torch.int64, device = self.device) - for input in results.split(self.batch_size): - for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'): - output = model(input) - logits = output[:, s] - if args.synthesis_sampling: - dist = torch.distributions.categorical.Categorical(logits = logits) - t = dist.sample() - else: - t = logits.argmax(1) - input[:, s + 1] = t - - image_name = f'result_mnist_{n_epoch:04d}.png' - torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255., - image_name, nrow = 16, pad_value = 0.8) - log_string(f'wrote {image_name}') - -###################################################################### - -def check_causality(model): - #m = model[1:] - input = torch.rand(1, 5, dim_model).requires_grad_() - output = m(input) - a = torch.zeros(output.size(1), input.size(1)) - for k in range(output.size(1)): - for d in range(output.size(2)): - g, = torch.autograd.grad(output[0, k, d], input, retain_graph = True) - a[k] += g.squeeze(0).pow(2).sum(1) - print(a) - -###################################################################### - -log_string(f'device {device}') - -if args.data == 'wiki103': - task = TaskWiki103(batch_size = args.batch_size, device = device) -elif args.data == 'mnist': - task = TaskMNIST(batch_size = args.batch_size, device = device) -elif args.data == 'picoclvr': - task = TaskPicoCLVR(batch_size = args.batch_size, device = device) -else: - raise ValueError(f'Unknown dataset {args.data}.') - -vocabulary_size = task.vocabulary_size() - -log_string(f'vocabulary_size {vocabulary_size}') - -############################## - -model = MyGPT( - vocabulary_size = vocabulary_size, - dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden, - nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout -) - -nb_parameters = sum(p.numel() for p in model.parameters()) -log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)') - -model.to(device) - -###################################################################### - -if args.optim == 'sgd': - optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate) -elif args.optim == 'adam': - optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate) -elif args.optim == 'adamw': - optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate) -else: - raise ValueError(f'Unknown optimizer {args.optim}.') - -for k in range(args.nb_epochs): - - model.train() - - nb_train_samples, acc_train_loss = 0, 0.0 - - for input in task.batches(split = 'train'): - input = input.to(device) - output = model(input) - loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:]) - acc_train_loss += loss.item() * input.size(0) - nb_train_samples += input.size(0) - - optimizer.zero_grad() - loss.backward() - optimizer.step() - - with torch.autograd.no_grad(): - - model.eval() - - nb_test_samples, acc_test_loss = 0, 0.0 - - for input in task.batches(split = 'test'): - input = input.to(device) - output = model(input) - loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:]) - acc_test_loss += loss.item() * input.size(0) - nb_test_samples += input.size(0) - - train_perplexity = math.exp(min(100, acc_train_loss/nb_train_samples)) - test_perplexity = math.exp(min(100, acc_test_loss/nb_test_samples)) - - log_string(f'perplexity {k+1} train {train_perplexity} test {test_perplexity}') - - task.produce_results(k, model) - -######################################################################