From ed60c541ca2225d69df96c2a382bb83c947bfe0e Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Thu, 22 Jun 2023 08:29:10 +0200 Subject: [PATCH] Update. --- main.py | 30 ++++++++++++++++++++---------- mygpt.py | 31 ++++++++++++++++++------------- 2 files changed, 38 insertions(+), 23 deletions(-) diff --git a/main.py b/main.py index 7cb8d4f..db982ca 100755 --- a/main.py +++ b/main.py @@ -173,15 +173,27 @@ for n in vars(args): ###################################################################### +# ra_mask is boolean, with 1s on the values to generate + + def masked_inplace_autoregression( - model, batch_size, input, ar_mask, forbidden_tokens=None, device=torch.device("cpu") + model, + batch_size, + input, + ar_mask, + forbidden_tokens=None, + progress_bar_desc="autoregression", + device=torch.device("cpu"), ): - for input, ar_mask in tqdm.tqdm( - zip(input.split(batch_size), ar_mask.split(batch_size)), - dynamic_ncols=True, - desc="autoregression", - total=input.size(0) // batch_size, - ): + batches = zip(input.split(batch_size), ar_mask.split(batch_size)) + if progress_bar_desc is not None: + tqdm.tqdm( + batches, + dynamic_ncols=True, + desc=progress_bar_desc, + total=input.size(0) // batch_size, + ) + for input, ar_mask in batches: i = (ar_mask.sum(0) > 0).nonzero() if i.min() > 0: model( @@ -317,6 +329,7 @@ class TaskPicoCLVR(Task): input, ar_masks, forbidden_tokens, + progress_bar_desc=None, device=self.device, ) model.train(t) @@ -975,9 +988,6 @@ for n_epoch in range(nb_epochs_finished, nb_epochs): for input in task.batches(split="test"): input = input.to(device) - # input, loss_masks, true_images = task.excise_last_image(input) - # input, loss_masks = task.add_true_image(input, true_images, loss_masks) - output = model(mygpt.BracketedSequence(input)).x loss = F.cross_entropy(output.transpose(1, 2), input) acc_test_loss += loss.item() * input.size(0) diff --git a/mygpt.py b/mygpt.py index b4446c6..6a12a5a 100755 --- a/mygpt.py +++ b/mygpt.py @@ -5,6 +5,11 @@ # Written by Francois Fleuret +# This is an implementation from scratch of a "GPT", that is a model +# composed of several causal self-attention blocks. It is equipped +# with a caching mechanism for keys and values to avoid a O(N^3) cost +# for auto-regression. + import math import torch @@ -14,19 +19,6 @@ from torch.nn import functional as F ###################################################################### - -class WithResidual(nn.Module): - def __init__(self, *f): - super().__init__() - self.f = f[0] if len(f) == 1 else nn.Sequential(*f) - - def forward(self, bs): - bs.x = bs.x + self.f(bs).x - return bs - - -###################################################################### - # A BracketedSequence is a BxTx... tensor with a first and a nb time # steps to compute. @@ -78,6 +70,19 @@ class CacheWrapper(nn.Module): ############################## +class WithResidual(nn.Module): + def __init__(self, *f): + super().__init__() + self.f = f[0] if len(f) == 1 else nn.Sequential(*f) + + def forward(self, bs): + bs.x = bs.x + self.f(bs).x + return bs + + +############################## + + class AddPositionalEncoding(nn.Module): def __init__(self, len_max): super().__init__() -- 2.39.5