From: Francois Fleuret Date: Sun, 24 Apr 2022 08:18:51 +0000 (+0200) Subject: Initial commit X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=68c17359790a9b8ac931a3679f08ad6a82a4e640;p=mygpt.git Initial commit --- 68c17359790a9b8ac931a3679f08ad6a82a4e640 diff --git a/mygpt.py b/mygpt.py new file mode 100755 index 0000000..970ee7b --- /dev/null +++ b/mygpt.py @@ -0,0 +1,514 @@ +#!/usr/bin/env python + +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + +import math, sys, argparse, time, tqdm, itertools + +import torch, torchtext, torchvision +from torch import nn +from torch.nn import functional as F + +###################################################################### + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +###################################################################### + +parser = argparse.ArgumentParser(description = 'My own GPT.') + +parser.add_argument('--log_filename', + type = str, default = 'train.log') + +parser.add_argument('--download', + type = bool, default = False) + +parser.add_argument('--seed', + type = int, default = 0) + +parser.add_argument('--nb_epochs', + type = int, default = 100) + +parser.add_argument('--batch_size', + type = int, default = 25) + +parser.add_argument('--data', + type = str, default = 'wiki103') + +parser.add_argument('--data_size', + type = int, default = -1) + +parser.add_argument('--optim', + type = str, default = 'adam') + +parser.add_argument('--learning_rate', + type = float, default = 1e-4) + +parser.add_argument('--dim_model', + type = int, default = 512) + +parser.add_argument('--dim_keys', + type = int, default = 64) + +parser.add_argument('--dim_hidden', + type = int, default = 2048) + +parser.add_argument('--nb_heads', + type = int, default = 8) + +parser.add_argument('--nb_blocks', + type = int, default = 12) + +parser.add_argument('--dropout', + type = float, default = 0.1) + +parser.add_argument('--synthesis_sampling', + type = bool, default = True) + +###################################################################### + +args = parser.parse_args() + +log_file = open(args.log_filename, 'w') + +if args.seed >= 0: + torch.manual_seed(args.seed) + +###################################################################### + +def log_string(s): + t = time.strftime('%Y%m%d-%H:%M:%S ', time.localtime()) + + if log_file is not None: + log_file.write(t + s + '\n') + log_file.flush() + + print(t + s) + sys.stdout.flush() + +for n in vars(args): + log_string(f'args.{n} {getattr(args, n)}') + +############################## + +class Residual(nn.Module): + def __init__(self, *f): + super().__init__() + self.f = f[0] if len(f) == 1 else nn.Sequential(*f) + + def forward(self, x): + return x + self.f(x) + +############################## + +class PositionalEncoding(nn.Module): + def __init__(self, len_max): + super().__init__() + self.len_max = len_max + + # From Vaswani et al 2018 + # PE_{t,2i} = sin(t/(L^{2i/D})) + # PE_{t,2i+1} = cos(t/(L^{2i/D})) + def forward(self, x): + t = torch.arange(x.size(1), dtype = x.dtype, device = x.device)[:, None] + j = torch.arange(x.size(2), dtype = x.dtype, device = x.device)[None, :] + k = j%2 + return x + torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi/2 * k)[None, :, :] + +############################## + +class QKVAttention(nn.Module): + def __init__(self, dim_in, dim_qk, dim_v, nb_heads = 1, causal = False, attention_dropout = 0.0): + super().__init__() + + def randw(*d): + return nn.Parameter(torch.empty(*d).normal_(0, 1 / math.sqrt(d[-1]))) + + self.wq = randw(nb_heads, dim_qk, dim_in) + self.wk = randw(nb_heads, dim_qk, dim_in) + self.wv = randw(nb_heads, dim_v, dim_in) + self.causal = causal + self.attention_dropout = attention_dropout + + def forward(self, x): + q = torch.einsum('ntc,hdc->nhtd', x, self.wq) + k = torch.einsum('ntc,hdc->nhtd', x, self.wk) + v = torch.einsum('ntc,hdc->nhtd', x, self.wv) + r = math.sqrt(q.size(3)) + a = torch.einsum('nhtd,nhsd->nhts', q, k).div(r) + if self.causal: + mask = torch.tril(q.new_ones(a.size(2), a.size(3)))[None, None, :, :] == 0 + a = a.masked_fill(mask, float('-inf')) + a = a.softmax(dim = 3) + a = F.dropout(a, self.attention_dropout, self.training) + y = torch.einsum('nhts,nhsd->nhtd', a, v) + return y.permute(0, 2, 1, 3).flatten(2) # nhtd -> nt(hd) + +############################## + +class MyGPT(nn.Module): + def __init__(self, + vocabulary_size, + dim_model, dim_keys, dim_hidden, + nb_heads, nb_blocks, dropout = 0.): + + super().__init__() + + assert dim_model % nb_heads == 0 + + self.embedding = nn.Sequential( + nn.Embedding(vocabulary_size, dim_model), + nn.Dropout(dropout), + PositionalEncoding(len_max = 1e5), + ) + + trunk_blocks = [ ] + + for _ in range(nb_blocks): + trunk_blocks += [ + Residual( + nn.LayerNorm(dim_model), + QKVAttention( + dim_in = dim_model, + dim_qk = dim_keys, dim_v = dim_model // nb_heads, + nb_heads = nb_heads, + causal = True, attention_dropout = dropout + ), + nn.Linear(in_features = dim_model, out_features = dim_model), + ), + Residual( + nn.LayerNorm(dim_model), + nn.Linear(in_features = dim_model, out_features = dim_hidden), + nn.ReLU(), + nn.Linear(in_features = dim_hidden, out_features = dim_model), + nn.Dropout(dropout), + ), + ] + + self.trunk = nn.Sequential(*trunk_blocks) + + self.readout = nn.Linear(in_features = dim_model, out_features = vocabulary_size) + + def forward(self, x): + x = self.embedding(x) + x = self.trunk(x) + x = self.readout(x) + return x + +###################################################################### + +class Task: + def batches(self, split = 'train'): + pass + + def vocabulary_size(self): + pass + + def produce_results(self, n_epoch, model, nb_tokens = 50): + pass + +###################################################################### + +import picoclvr + +class TaskPicoCLVR(Task): + + def __init__(self, batch_size, height = 6, width = 8, device = torch.device('cpu')): + self.batch_size = batch_size + self.device = device + nb = args.data_size if args.data_size > 0 else 250000 + + descr = picoclvr.generate(nb, height = height, width = width) + descr = [ s.strip().split(' ') for s in descr ] + l = max([ len(s) for s in descr ]) + descr = [ s + [ '' ] * (l - len(s)) for s in descr ] + + tokens = set() + for s in descr: + for t in s: tokens.add(t) + self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ]) + self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ]) + + t = [ [ self.token2id[u] for u in s ] for s in descr ] + data_input = torch.tensor(t, device = self.device) + + self.test_input = data_input[:nb // 5] + self.train_input = data_input[nb // 5:] + + def batches(self, split = 'train'): + assert split in { 'train', 'test' } + if split == 'train': + for batch in tqdm.tqdm(self.train_input.split(self.batch_size), desc = 'epoch'): + yield batch + else: + for batch in tqdm.tqdm(self.test_input.split(self.batch_size), desc = 'epoch'): + yield batch + + def vocabulary_size(self): + return len(self.token2id) + + def produce_results(self, n_epoch, model, nb_tokens = 50): + img = [ ] + nb_per_primer = 8 + for primer in [ + 'red above green green top blue right of red ', + 'there is red there is yellow there is blue ', + 'red below yellow yellow below green green below blue red right yellow left green right blue left ', + 'green bottom yellow bottom green left of blue yellow right of blue blue top ', + ]: + + for k in range(nb_per_primer): + t_primer = primer.strip().split(' ') + t_generated = [ ] + + for j in range(nb_tokens): + t = [ [ self.token2id[u] for u in t_primer + t_generated ] ] + input = torch.tensor(t, device = self.device) + output = model(input) + logits = output[0, -1] + if args.synthesis_sampling: + dist = torch.distributions.categorical.Categorical(logits = logits) + t = dist.sample() + else: + t = logits.argmax() + t_generated.append(self.id2token[t.item()]) + + descr = [ ' '.join(t_primer + t_generated) ] + img += [ picoclvr.descr2img(descr) ] + + img = torch.cat(img, 0) + file_name = f'result_picoclvr_{n_epoch:04d}.png' + torchvision.utils.save_image(img / 255., + file_name, nrow = nb_per_primer, pad_value = 0.8) + log_string(f'wrote {file_name}') + +###################################################################### + +class TaskWiki103(Task): + + def __init__(self, batch_size, len_min = 10, len_max = 200, min_freq = 100, + device = torch.device('cpu')): + + self.batch_size = batch_size + self.len_min = len_min + self.len_max = len_max + self.min_freq = min_freq + self.device = device + + self.tokenizer = torchtext.data.get_tokenizer('basic_english') + train_iter = torchtext.datasets.WikiText103(split = 'train', root = './data/nlp/') + + # Mostly for debug + if args.data_size > 0: + train_iter = itertools.islice(train_iter, args.data_size) + + def yield_tokens(): + for l in tqdm.tqdm(train_iter, desc = 'vocab'): + yield self.tokenizer(l) + + self.vocab = torchtext.vocab.build_vocab_from_iterator( + yield_tokens(), + specials = [ '', '' ], + min_freq = self.min_freq + ) + + self.vocab.set_default_index(self.vocab[ '' ]) + + def tensorize(self, s): + a = max(len(x) for x in s) + return torch.tensor([ self.vocab(x + [ '' ] * (a - len(x))) for x in s ]) + + def yield_batches(self, ds): + s = [ ] + for l in ds: + q = self.tokenizer(l) + if len(q) >= self.len_min and len(q) <= self.len_max: + s += [ q ] + if len(s) == self.batch_size: + yield self.tensorize(s) + s = [ ] + + if len(s) > 0: + yield self.tensorize(s) + + def batches(self, split = 'train'): + data_iter = torchtext.datasets.WikiText103(split = split, root = './data/nlp/') + + # Mostly for debug + if args.data_size > 0: + data_iter = itertools.islice(data_iter, args.data_size) + + return self.yield_batches(tqdm.tqdm(data_iter, desc = 'epoch')) + + def vocabulary_size(self): + return len(self.vocab) + + def produce_results(self, n_epoch, model, nb_tokens = 50): + file_name = f'result_wiki103_{n_epoch:04d}.txt' + + with open(file_name, 'w') as outfile: + for primer in [ + 'the cat is hunting a', + 'paris is the capital', + 'cars are convenient', + 'the difference between men and women is', + 'the object was blue all over and green all over it was', + 'cherries are red and lemons are', + 'cherries are sweet and lemons are', + 'two plus three equals', + 'deep learning is', + ]: + t_primer = self.tokenizer(primer) + t_generated = [ ] + + for j in range(nb_tokens): + + input = self.tensorize([ t_primer + t_generated ]).to(self.device) + output = model(input) + logits = output[0, -1] + if args.synthesis_sampling: + dist = torch.distributions.categorical.Categorical(logits = logits) + t = dist.sample() + else: + t = logits.argmax() + t_generated.append(self.vocab.lookup_token(t)) + if t_generated[-1] == '': break + + s = ' '.join(t_generated) + + outfile.write(f'<{primer}> {s}\n') + + log_string(f'wrote {file_name}') + +###################################################################### + +class TaskMNIST(Task): + + def __init__(self, batch_size, device = torch.device('cpu')): + self.device = device + self.batch_size = batch_size + + def batches(self, split = 'train'): + assert split in { 'train', 'test' } + data_set = torchvision.datasets.MNIST( + root = './data', train = (split == 'train'), + download = True + ) + data_input = data_set.data.view(-1, 28 * 28).long() + if args.data_size >= 0: + data_input = data_input[:args.data_size] + for batch in tqdm.tqdm(data_input.split(self.batch_size), desc = 'epoch'): + yield batch + + def vocabulary_size(self): + return 256 + + def produce_results(self, n_epoch, model, nb_samples = 64): + results = torch.zeros(nb_samples, 28 * 28, dtype = torch.int64, device = self.device) + for input in results.split(self.batch_size): + for s in tqdm.tqdm(range(input.size(1) - 1), desc = 'synth'): + output = model(input) + logits = output[:, s] + if args.synthesis_sampling: + dist = torch.distributions.categorical.Categorical(logits = logits) + t = dist.sample() + else: + t = logits.argmax(1) + input[:, s + 1] = t + + image_name = f'result_mnist_{n_epoch:04d}.png' + torchvision.utils.save_image(1 - results.reshape(-1, 1, 28, 28) / 255., + image_name, nrow = 16, pad_value = 0.8) + log_string(f'wrote {image_name}') + +###################################################################### + +def check_causality(model): + #m = model[1:] + input = torch.rand(1, 5, dim_model).requires_grad_() + output = m(input) + a = torch.zeros(output.size(1), input.size(1)) + for k in range(output.size(1)): + for d in range(output.size(2)): + g, = torch.autograd.grad(output[0, k, d], input, retain_graph = True) + a[k] += g.squeeze(0).pow(2).sum(1) + print(a) + +###################################################################### + +log_string(f'device {device}') + +if args.data == 'wiki103': + task = TaskWiki103(batch_size = args.batch_size, device = device) +elif args.data == 'mnist': + task = TaskMNIST(batch_size = args.batch_size, device = device) +elif args.data == 'picoclvr': + task = TaskPicoCLVR(batch_size = args.batch_size, device = device) +else: + raise ValueError(f'Unknown dataset {args.data}.') + +vocabulary_size = task.vocabulary_size() + +log_string(f'vocabulary_size {vocabulary_size}') + +############################## + +model = MyGPT( + vocabulary_size = vocabulary_size, + dim_model = args.dim_model, dim_keys = args.dim_keys, dim_hidden = args.dim_hidden, + nb_heads = args.nb_heads, nb_blocks = args.nb_blocks, dropout = args.dropout +) + +nb_parameters = sum(p.numel() for p in model.parameters()) +log_string(f'nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)') + +model.to(device) + +###################################################################### + +if args.optim == 'sgd': + optimizer = torch.optim.SGD(model.parameters(), lr = args.learning_rate) +elif args.optim == 'adam': + optimizer = torch.optim.Adam(model.parameters(), lr = args.learning_rate) +elif args.optim == 'adamw': + optimizer = torch.optim.AdamW(model.parameters(), lr = args.learning_rate) +else: + raise ValueError(f'Unknown optimizer {args.optim}.') + +for k in range(args.nb_epochs): + + model.train() + + nb_train_samples, acc_train_loss = 0, 0.0 + + for input in task.batches(split = 'train'): + input = input.to(device) + output = model(input) + loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:]) + acc_train_loss += loss.item() * input.size(0) + nb_train_samples += input.size(0) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + with torch.autograd.no_grad(): + + model.eval() + + nb_test_samples, acc_test_loss = 0, 0.0 + + for input in task.batches(split = 'test'): + input = input.to(device) + output = model(input) + loss = F.cross_entropy(output[:, :-1].transpose(1, 2), input[:, 1:]) + acc_test_loss += loss.item() * input.size(0) + nb_test_samples += input.size(0) + + log_string(f'perplexity {k+1} train {math.exp(min(100, acc_train_loss/nb_train_samples))} test {math.exp(min(100, acc_test_loss/nb_test_samples))}') + + task.produce_results(k, model) + +###################################################################### diff --git a/picoclvr.py b/picoclvr.py new file mode 100755 index 0000000..a194a1c --- /dev/null +++ b/picoclvr.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python + +import torch, torchvision + +colors = [ + [ 255, 255, 255 ], + [ 255, 0, 0 ], + [ 0, 255, 0 ], + [ 0, 0, 255 ], + [ 255, 255, 0 ], + [ 0, 0, 0 ], +] + +color_names = [ + 'white', + 'red', + 'green', + 'blue', + 'yellow', + 'black', +] + +color_tokens = dict( [ (n, c) for n, c in zip(color_names, colors) ] ) + +def generate(nb, height = 6, width = 8, max_nb_statements = 10): + + descr = [ ] + + for n in range(nb): + nb = torch.randint(5, (1,)) + 1 + shape_position = torch.randperm(height * width)[:nb] + shape_c = torch.randperm(5)[:nb] + 1 + shape_i = shape_position.div(width, rounding_mode = 'floor') + shape_j = shape_position % width + + img = [ 0 ] * height * width + for k in range(nb): img[shape_position[k]] = shape_c[k] + + s = [ ] + + for r, c in [ (k, color_names[shape_c[k]]) for k in range(nb) ]: + s += [ f'there is {c}' ] + + if shape_i[r] >= height - height/4: s += [ f'{c} bottom' ] + if shape_i[r] < height/4: s += [ f'{c} top' ] + if shape_j[r] >= width - width/4: s += [ f'{c} right' ] + if shape_j[r] < width/4: s += [ f'{c} left' ] + + for t, d in [ (k, color_names[shape_c[k]]) for k in range(nb) ]: + if shape_i[r] > shape_i[t]: s += [ f'{c} below {d}' ] + if shape_i[r] < shape_i[t]: s += [ f'{c} above {d}' ] + if shape_j[r] > shape_j[t]: s += [ f'{c} right of {d}' ] + if shape_j[r] < shape_j[t]: s += [ f'{c} left of {d}' ] + + nb_statements = torch.randint(max_nb_statements, (1,)) + 1 + s = ' '.join([ s[k] for k in torch.randperm(len(s))[:nb_statements] ] ) + s += ' ' + ' '.join([ f'{color_names[n]}' for n in img ]) + descr += [ s ] + + return descr + +###################################################################### + +def descr2img(descr, height = 6, width = 8): + + def token2color(t): + try: + return color_tokens[t] + except KeyError: + return [ 128, 128, 128 ] + + def img_descr(x): + u = x.split('', 1) + return u[1] if len(u) > 1 else '' + + img = torch.full((len(descr), 3, height, width), 255) + d = [ img_descr(x) for x in descr ] + d = [ u.strip().split(' ')[:height * width] for u in d ] + d = [ u + [ '' ] * (height * width - len(u)) for u in d ] + d = [ [ token2color(t) for t in u ] for u in d ] + img = torch.tensor(d).permute(0, 2, 1) + img = img.reshape(img.size(0), 3, height, width) + + return img + +###################################################################### + +if __name__ == '__main__': + descr = generate(5) + img = descr2img(descr) + print(descr, img.size()) + torchvision.utils.save_image(img / 255., + 'example.png', nrow = 16, pad_value = 0.8) + + import time + + start_time = time.perf_counter() + descr = generate(10000) + end_time = time.perf_counter() + print(f'{len(descr) / (end_time - start_time):.02f} samples per second') + +######################################################################