From: François Fleuret Date: Tue, 1 Oct 2024 05:12:59 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=f730f5fc1003f74ae7afea3451f17ad8925bd909;p=pytorch.git Update. --- diff --git a/tinygen.py b/tinygen.py new file mode 100755 index 0000000..66c005c --- /dev/null +++ b/tinygen.py @@ -0,0 +1,520 @@ +#!/usr/bin/env python + +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + +# @XREMOTE_HOST: elk.fleuret.org +# @XREMOTE_EXEC: python +# @XREMOTE_PRE: source ${HOME}/misc/venv/pytorch/bin/activate +# @XREMOTE_PRE: ln -sf ${HOME}/data/pytorch ./data +# @XREMOTE_GET: *.png + +import sys, argparse, time, math + +import torch, torchvision + +from torch import optim, nn +from torch.nn import functional as F +from tqdm import tqdm + +###################################################################### + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +###################################################################### + +parser = argparse.ArgumentParser(description="Tiny LeNet-like auto-encoder.") + +parser.add_argument("--nb_epochs", type=int, default=250) + +parser.add_argument("--batch_size", type=int, default=100) + +parser.add_argument("--data_dir", type=str, default="./data/") + +parser.add_argument("--log_filename", type=str, default="train.log") + +parser.add_argument("--embedding_dim", type=int, default=64) + +parser.add_argument("--nb_channels", type=int, default=64) + +args = parser.parse_args() + +log_file = open(args.log_filename, "w") + +###################################################################### + + +def log_string(s): + t = time.strftime("%Y-%m-%d_%H:%M:%S - ", time.localtime()) + + if log_file is not None: + log_file.write(t + s + "\n") + log_file.flush() + + print(t + s) + sys.stdout.flush() + + +###################################################################### + + +class VaswaniPositionalEncoding(nn.Module): + def __init__(self, len_max): + super().__init__() + self.len_max = len_max + + def forward(self, x): + t = torch.arange(x.size(1), dtype=x.dtype, device=x.device)[:, None] + j = torch.arange(x.size(2), dtype=x.dtype, device=x.device)[None, :] + k = j % 2 # works with float, weird + pe = torch.sin(t / (self.len_max ** ((j - k) / x.size(2))) + math.pi / 2 * k) + y = x + pe + return y + + +###################################################################### + + +class WithResidual(nn.Module): + def __init__(self, *f): + super().__init__() + self.f = f[0] if len(f) == 1 else nn.Sequential(*f) + + def forward(self, x): + return x + self.f(x) + + +###################################################################### + + +def vanilla_attention(q, k, v): + a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3)) + a = a.softmax(dim=3) + y = torch.einsum("nhts,nhsd->nhtd", a, v) + return y + + +###################################################################### + + +class MHAttention(nn.Module): + def __init__( + self, + dim_model, + dim_qk, + dim_v, + nb_heads=1, + attention=vanilla_attention, + attention_dropout=0.0, + ): + super().__init__() + + def randw(*d): + return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1])) + + self.attention = attention + self.attention_dropout = attention_dropout + self.w_q = randw(nb_heads, dim_qk, dim_model) + self.w_k = randw(nb_heads, dim_qk, dim_model) + self.w_v = randw(nb_heads, dim_v, dim_model) + self.w_o = randw(nb_heads, dim_v, dim_model) + + def forward(self, x_q, x_kv=None): + if x_kv is None: + x_kv = x_q + + q = torch.einsum("ntc,hdc->nhtd", x_q, self.w_q) + k = torch.einsum("nsc,hdc->nhsd", x_kv, self.w_k) + v = torch.einsum("nsc,hdc->nhsd", x_kv, self.w_v) + y = self.attention(q, k, v) + y = torch.einsum("nhtd,hdc->ntc", y, self.w_o) + + return y + + +###################################################################### + + +class AttentionAE(nn.Module): + def __init__( + self, + dim_model, + dim_keys, + dim_hidden, + nb_heads, + nb_blocks, + dropout=0.0, + len_max=1e5, + ): + super().__init__() + + assert dim_model % nb_heads == 0 + + self.embedding = nn.Sequential( + nn.Linear(2, dim_model), + nn.Dropout(dropout), + ) + + self.positional_encoding = VaswaniPositionalEncoding(len_max) + + trunk_blocks = [] + + for b in range(nb_blocks): + trunk_blocks += [ + WithResidual( + nn.LayerNorm((dim_model,)), + MHAttention( + dim_model=dim_model, + dim_qk=dim_keys, + dim_v=dim_model // nb_heads, + nb_heads=nb_heads, + attention=vanilla_attention, + attention_dropout=dropout, + ), + ), + WithResidual( + nn.LayerNorm((dim_model,)), + nn.Linear(in_features=dim_model, out_features=dim_hidden), + nn.ReLU(), + nn.Linear(in_features=dim_hidden, out_features=dim_model), + nn.Dropout(dropout), + ), + ] + + self.trunk = nn.Sequential(*trunk_blocks) + + self.readout = nn.Linear(in_features=dim_model, out_features=1) + + with torch.no_grad(): + for m in self.modules(): + if isinstance(m, nn.Embedding): + m.weight.normal_(mean=0, std=2e-2) + elif isinstance(m, nn.LayerNorm): + m.bias.zero_() + m.weight.fill_(1.0) + + def forward(self, x): + x = x.reshape(-1, 2, 28 * 28).permute(0, 2, 1) + x = self.embedding(x) + x = self.positional_encoding(x) + x = self.trunk(x) + x = self.readout(x).reshape(-1, 1, 28, 28) + return x + + +###################################################################### + + +class WithMaskedResidual(nn.Module): + def __init__(self, masker, *f): + super().__init__() + self.f = f[0] if len(f) == 1 else nn.Sequential(*f) + self.masker = masker + self.mask = None + + def forward(self, x): + if self.mask is None: + self.mask = self.masker(x) + return self.mask * x + self.f(x) + + +###################################################################### + + +class FunctionalAttentionAE(nn.Module): + def __init__( + self, + vocabulary_size, + dim_model, + dim_keys, + dim_hidden, + nb_heads, + nb_blocks, + nb_work_tokens=100, + dropout=0.0, + len_max=1e5, + ): + super().__init__() + + assert dim_model % nb_heads == 0 + + self.nb_work_tokens = nb_work_tokens + + self.embedding = nn.Sequential( + nn.Embedding(2 * vocabulary_size, dim_model), + nn.Dropout(dropout), + ) + + self.positional_encoding = VaswaniPositionalEncoding(len_max) + + trunk_blocks = [] + + def no_peek_attention(q, k, v): + a = torch.einsum("nhtd,nhsd->nhts", q, k) / math.sqrt(q.size(3)) + n = self.nb_work_tokens + s = (q.size(2) - n) // 2 + a[:, :, n + 1 * s : n + 2 * s, n + 0 * s : n + 1 * s] = float("-inf") + a[:, :, n + 0 * s : n + 1 * s, n + 1 * s : n + 2 * s] = float("-inf") + a = a.softmax(dim=3) + y = torch.einsum("nhts,nhsd->nhtd", a, v) + return y + + def masker(x): + m = torch.arange(x.size(1), device=x.device) >= self.nb_work_tokens + return m[None, :, None] + + for b in range(nb_blocks): + trunk_blocks += [ + WithMaskedResidual( + masker, + nn.LayerNorm((dim_model,)), + MHAttention( + dim_model=dim_model, + dim_qk=dim_keys, + dim_v=dim_model // nb_heads, + nb_heads=nb_heads, + attention=no_peek_attention, + attention_dropout=dropout, + ), + ), + WithMaskedResidual( + masker, + nn.LayerNorm((dim_model,)), + nn.Linear(in_features=dim_model, out_features=dim_hidden), + nn.ReLU(), + nn.Linear(in_features=dim_hidden, out_features=dim_model), + nn.Dropout(dropout), + ), + ] + + self.trunk = nn.Sequential(*trunk_blocks) + + self.readout = nn.Linear(in_features=dim_model, out_features=vocabulary_size) + + with torch.no_grad(): + for m in self.modules(): + if isinstance(m, nn.Embedding): + m.weight.normal_(mean=0, std=2e-2) + elif isinstance(m, nn.LayerNorm): + m.bias.zero_() + m.weight.fill_(1.0) + + def forward(self, x): + x = self.embedding(x) + x = F.pad(x, (0, 0, self.nb_work_tokens, 0)) + x = self.positional_encoding(x) + x = self.trunk(x) + x = F.pad(x, (0, 0, -self.nb_work_tokens, 0)) + x = self.readout(x) + return x + + +###################################################################### + + +class FullAveragePooling(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + x = x.view(x.size(0), x.size(1), -1).mean(2).view(x.size(0), x.size(1), 1, 1) + return x + + +class ResNetBlock(nn.Module): + def __init__(self, nb_channels, kernel_size): + super().__init__() + + self.conv1 = nn.Conv2d( + nb_channels, + nb_channels, + kernel_size=kernel_size, + padding=(kernel_size - 1) // 2, + ) + + self.conv2 = nn.Conv2d( + nb_channels, + nb_channels, + kernel_size=kernel_size, + padding=(kernel_size - 1) // 2, + ) + + def forward(self, x): + y = F.relu(self.conv1(x)) + y = F.relu(x + self.conv2(y)) + return y + + +###################################################################### + + +class ResAutoEncoder(nn.Module): + def __init__(self, nb_channels, kernel_size): + super().__init__() + + self.encoder = nn.Conv2d( + 2, nb_channels, kernel_size=kernel_size, padding=kernel_size // 2 + ) + self.core = nn.Sequential( + *[ResNetBlock(nb_channels, kernel_size) for _ in range(20)] + ) + self.decoder = nn.Conv2d( + nb_channels, 1, kernel_size=kernel_size, padding=kernel_size // 2 + ) + + def forward(self, x): + x = self.encoder(x) + x = self.decoder(x) + return x + + +###################################################################### + + +class AutoEncoder(nn.Module): + def __init__(self, nb_channels, embedding_dim): + super().__init__() + + self.encoder = nn.Sequential( + nn.Conv2d(1, nb_channels, kernel_size=5), # to 24x24 + nn.ReLU(inplace=True), + nn.Conv2d(nb_channels, nb_channels, kernel_size=5), # to 20x20 + nn.ReLU(inplace=True), + nn.Conv2d(nb_channels, nb_channels, kernel_size=4, stride=2), # to 9x9 + nn.ReLU(inplace=True), + nn.Conv2d(nb_channels, nb_channels, kernel_size=3, stride=2), # to 4x4 + nn.ReLU(inplace=True), + nn.Conv2d(nb_channels, embedding_dim, kernel_size=4), + ) + + self.decoder = nn.Sequential( + nn.ConvTranspose2d(embedding_dim, nb_channels, kernel_size=4), + nn.ReLU(inplace=True), + nn.ConvTranspose2d( + nb_channels, nb_channels, kernel_size=3, stride=2 + ), # from 4x4 + nn.ReLU(inplace=True), + nn.ConvTranspose2d( + nb_channels, nb_channels, kernel_size=4, stride=2 + ), # from 9x9 + nn.ReLU(inplace=True), + nn.ConvTranspose2d(nb_channels, nb_channels, kernel_size=5), # from 20x20 + nn.ReLU(inplace=True), + nn.ConvTranspose2d(nb_channels, 1, kernel_size=5), # from 24x24 + ) + + def encode(self, x): + return self.encoder(x).view(x.size(0), -1) + + def decode(self, z): + return self.decoder(z.view(z.size(0), -1, 1, 1)) + + def forward(self, x): + x = self.encoder(x) + x = self.decoder(x) + return x + + +###################################################################### + +train_set = torchvision.datasets.MNIST( + args.data_dir + "/mnist/", train=True, download=True +) +train_input = train_set.data.view(-1, 1, 28, 28).float() + +test_set = torchvision.datasets.MNIST( + args.data_dir + "/mnist/", train=False, download=True +) +test_input = test_set.data.view(-1, 1, 28, 28).float() + +###################################################################### + +model = AutoEncoder(args.nb_channels, args.embedding_dim) + +# model = AttentionAE( +# dim_model=16, +# dim_keys=16, +# dim_hidden=16, +# nb_heads=4, +# nb_blocks=4, +# dropout=0.0, +# len_max=1e5, +# ) + +# model = ResAutoEncoder(nb_channels=128, kernel_size=9) + +print(model) + +optimizer = optim.Adam(model.parameters(), lr=1e-3) + +model.to(device) + +train_input, test_input = train_input.to(device), test_input.to(device) + +mu, std = train_input.mean(), train_input.std() +train_input.sub_(mu).div_(std) +test_input.sub_(mu).div_(std) + +nb_iterations = 10 + +###################################################################### + + +def dist(u, v): + return (u - v).pow(2).sum(dim=(1, 2, 3), keepdim=True).sqrt() + + +def pb(e, desc): + return tqdm( + e, + dynamic_ncols=True, + desc=desc, + total=train_input.size(0) // args.batch_size, + delay=10, + ) + + +for n_epoch in range(args.nb_epochs): + acc_loss = 0 + + for targets in pb(train_input.split(args.batch_size), "train"): + input = torch.randn(targets.size(), device=targets.device) + + loss = 0 + for n in range(nb_iterations): + output = model(input) + current_d = dist(targets, output) + nb_remain = nb_iterations - n + tolerated_d = dist(targets, input) * (nb_remain - 1) / nb_remain + a = (tolerated_d / (current_d + 1e-6)).clamp(max=1) + loss += (1 - a).mean() / nb_iterations + input = targets - a * (targets - output.detach()) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + acc_loss += loss.item() + + log_string(f"acc_loss {n_epoch} {acc_loss}") + + ###################################################################### + + input = test_input[:256] + model.eval() + + input = torch.randn(input.size(), device=input.device) + for _ in range(nb_iterations): + output = model(input) + input = output.detach() + + output = (output * std + mu) / 255 + + torchvision.utils.save_image( + 1 - output, f"output_{n_epoch:04d}.png", nrow=16, pad_value=0.8 + ) + + +######################################################################