From: François Fleuret Date: Thu, 19 Sep 2024 11:20:34 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=refs%2Fheads%2Fdiffusion;p=culture.git Update. --- diff --git a/grids.py b/grids.py index 4254b32..5e623cb 100755 --- a/grids.py +++ b/grids.py @@ -384,6 +384,9 @@ class Grids(problem.Problem): ###################################################################### + def vocabulary_size(self): + return self.nb_token_values + def grid2img(self, x, scale=15, grids=True): m = torch.logical_and(x >= 0, x < self.nb_colors).long() y = self.colors[x * m].permute(0, 3, 1, 2) diff --git a/main.py b/main.py index 0c40f95..ef340ea 100755 --- a/main.py +++ b/main.py @@ -11,10 +11,7 @@ import torch, torchvision from torch import nn from torch.nn import functional as F -import ffutils - -import mygpt -import sky, grids +import ffutils, grids, attae import threading, subprocess @@ -313,7 +310,7 @@ def quiz_set(nb_samples, c_quizzes, c_quiz_multiplier=1): log_string(f"main_device {main_device} gpus {[ str(g) for g in gpus]}") -vocabulary_size = problem.nb_token_values +vocabulary_size = problem.vocabulary_size() log_string(f"vocabulary_size {vocabulary_size}") @@ -640,8 +637,6 @@ def one_complete_epoch(model, n_epoch, c_quizzes, local_device=main_device): ###################################################################### -import attae - models = [] for i in range(args.nb_models): diff --git a/mygpt.py b/mygpt.py deleted file mode 100755 index 5b56264..0000000 --- a/mygpt.py +++ /dev/null @@ -1,475 +0,0 @@ -#!/usr/bin/env python - -# Any copyright is dedicated to the Public Domain. -# https://creativecommons.org/publicdomain/zero/1.0/ - -# Written by Francois Fleuret - -# This is an implementation from scratch of a "GPT", that is a model -# composed of several causal self-attention blocks. It is equipped -# with a caching mechanism for keys and values to avoid a O(N^3) cost -# for auto-regression. - -import math - -import torch - -from torch import nn -from torch.nn import functional as F - -###################################################################### - - -class BSQ(nn.Module): - def __init__(self, L): - super().__init__() - self.L = L - - def forward(self, input, indexes=False): - norm = input.pow(2).sum(dim=2, keepdim=True).sqrt() - u = input / norm - - if indexes: - return ((u >= 0).long() * (2 ** torch.arange(self.L))[None, :]).sum(dim=1) - - hat_u = 1 / math.sqrt(self.L) * (2 * (u >= 0).float() - 1) - if self.training: - self.loss += u.mean(dim=0).tanh().pow(2).mean() - return hat_u + u - u.detach() - else: - return hat_u - - -class RandomBypass(nn.Module): - def __init__(self, m, p): - super().__init__() - self.m = m - self.p = p - - def forward(self, x): - y = self.m(x) - - if self.training: - u = (torch.rand(x.size(0), device=x.device) <= self.p).long()[:, None] - return (u * x.flatten(1) + (1 - u) * y.flatten(1)).reshape(x.size()) - else: - return y - - -###################################################################### - -# A BracketedSequence is a BxTx... tensor with a first and a nb time -# steps to compute. - -# Modules able to process it expect that they will have to process a -# first bracket starting at t=0, followed by a succession of brackets -# that move forward in time, do not overlap, and cover the axis T with -# no holes. -# -# Although it is more general, for a classical prompt-conditioned -# auto-regressive process it will be a first bracket starting at 0 and -# of arbitrary length for the "prompt", followed by brackets of length -# 1 for the successive tokens. -# -# Modules able to process brackets may implement a cache that is -# resetted when the input bracket starts at t=0 - - -class BracketedSequence: - def __init__(self, x, first=None, nb=None): - self.x = x - self.first = 0 if first is None else first - self.nb = x.size(1) if nb is None else nb - - def slice(self): - return self.x[:, self.first : self.first + self.nb] - - def complete(self): - return self.first == 0 and self.nb == self.x.size(1) - - -###################################################################### - - -class CacheWrapper(nn.Module): - def __init__(self, *f): - super().__init__() - self.f = f[0] if len(f) == 1 else nn.Sequential(*f) - - def forward(self, bs): - if bs.first == 0: - y = self.f(bs.slice()) - self.cache_y = y.new(*((y.size(0), bs.x.size(1)) + y.size()[2:])) - self.cache_y[:, bs.first : bs.first + bs.nb] = y - else: - self.cache_y[:, bs.first : bs.first + bs.nb] = self.f(bs.slice()) - - return BracketedSequence(self.cache_y, bs.first, bs.nb) - - -############################## - - -class CachedWithResidual(nn.Module): - def __init__(self, *f): - super().__init__() - self.f = f[0] if len(f) == 1 else nn.Sequential(*f) - - def forward(self, bs): - return BracketedSequence(bs.x + self.f(bs).x, bs.first, bs.nb) - - -############################## - - -class CachedVaswaniPositionalEncoding(nn.Module): - def __init__(self, len_max): - super().__init__() - self.len_max = len_max - - # [Vaswani et al 2018] PE_{t,2i} = sin(t/(L^{2i/D})), PE_{t,2i+1} = cos(t/(L^{2i/D})) - - def forward(self, bs): - if bs.first == 0: - t = torch.arange(bs.x.size(1), dtype=bs.x.dtype, device=bs.x.device)[ - :, None - ] - j = torch.arange(bs.x.size(2), dtype=bs.x.dtype, device=bs.x.device)[ - None, : - ] - k = j % 2 - self.pe = torch.sin( - t / (self.len_max ** ((j - k) / bs.x.size(2))) + math.pi / 2 * k - ) - self.cache_y = bs.x.new(bs.x.size()) - - self.cache_y[:, bs.first : bs.first + bs.nb] = ( - bs.slice() + self.pe[bs.first : bs.first + bs.nb] - ) - - return BracketedSequence(self.cache_y, bs.first, bs.nb) - - -############################## - - -class TrainablePositionalEncoding(nn.Module): - def __init__(self, dim, len_max): - super().__init__() - self.len_max = len_max - self.pe = nn.Parameter(torch.randn(1, len_max, dim) / math.sqrt(dim)) - - def forward(self, bs): - if bs.first == 0: - self.cache_y = bs.x.new(bs.x.size()) - - self.cache_y[:, bs.first : bs.first + bs.nb] = ( - bs.slice() + self.pe[:, bs.first : bs.first + bs.nb, :] - ) - - return BracketedSequence(self.cache_y, bs.first, bs.nb) - - -############################## - - -class EncoderHead(nn.Module): - def __init__(self, dim_in, dim_out): - super().__init__() - self.fc = nn.Linear(dim_in, dim_out) - - def forward(self, bs): - z = self.fc(bs.x).mean(dim=1) - return z, bs.x.shape - - -class DecoderBottom(nn.Module): - def __init__(self, dim_in, dim_out): - super().__init__() - self.fc = nn.Linear(dim_in, dim_out) - - def forward(self, z_shape): - z, shape = z_shape - y = self.fc(z)[:, None, :].expand(shape) - return BracketedSequence(y) - - -############################## - - -class QKVAttention(nn.Module): - def __init__( - self, - dim_in, - dim_qk, - dim_v, - nb_heads=1, - compute_attzero=None, - attention_dropout=0.0, - ): - super().__init__() - - def randw(*d): - return nn.Parameter(torch.randn(*d) / math.sqrt(d[-1])) - - self.compute_attzero = compute_attzero - self.attention_dropout = attention_dropout - self.record_attention = False - - self.w_q = randw(nb_heads, dim_qk, dim_in) - self.w_k = randw(nb_heads, dim_qk, dim_in) - self.w_v = randw(nb_heads, dim_v, dim_in) - self.w_o = randw(dim_v * nb_heads, dim_in) - - def forward(self, bs_q, bs_kv=None): - if bs_kv is None: - bs_kv = bs_q - - x_q = bs_q.x - x_kv = bs_kv.x - - if bs_kv.first == 0: - self.cache_k = x_kv.new_zeros( - x_kv.size(0), self.w_k.size(0), x_kv.size(1), self.w_k.size(1) - ) - self.cache_v = x_kv.new_zeros( - x_kv.size(0), self.w_v.size(0), x_kv.size(1), self.w_v.size(1) - ) - - if bs_q.first == 0: - self.cache_y = x_q.new_zeros(x_q.size(0), x_q.size(1), self.w_o.size(1)) - - q = torch.einsum( - "ntc,hdc->nhtd", x_q[:, bs_q.first : bs_q.first + bs_q.nb], self.w_q - ) - - self.cache_k[:, :, bs_kv.first : bs_kv.first + bs_kv.nb] = torch.einsum( - "ntc,hdc->nhtd", x_kv[:, bs_kv.first : bs_kv.first + bs_kv.nb], self.w_k - ) - self.cache_v[:, :, bs_kv.first : bs_kv.first + bs_kv.nb] = torch.einsum( - "ntc,hdc->nhtd", x_kv[:, bs_kv.first : bs_kv.first + bs_kv.nb], self.w_v - ) - - a = torch.einsum( - "nhtd,nhsd->nhts", q, self.cache_k[:, :, : bs_kv.first + bs_kv.nb] - ) / math.sqrt(self.w_q.size(1)) - - if self.compute_attzero is not None: - if bs_q.first == 0: - self.cache_attzero = self.compute_attzero( - torch.arange(x_q.size(1), device=q.device)[:, None], - torch.arange(x_kv.size(1), device=q.device)[None, :], - )[None, None, :, :] - a = a.masked_fill( - self.cache_attzero[ - :, :, bs_q.first : bs_q.first + bs_q.nb, : bs_kv.first + bs_kv.nb - ], - float("-inf"), - ) - - a = a.softmax(dim=3) - - if self.record_attention: - self.a = a - - a = F.dropout(a, self.attention_dropout, self.training) - - y = torch.einsum( - "nhts,nhsd->nthd", a, self.cache_v[:, :, : bs_kv.first + bs_kv.nb] - ).flatten(2) - - self.cache_y[:, bs_q.first : bs_q.first + bs_q.nb] = y @ self.w_o - - return BracketedSequence(self.cache_y, bs_q.first, bs_q.nb) - - -############################## - - -class NoiseInjector(nn.Module): - def __init__(self, identifier=None): - super().__init__() - self.noise_std = 0.0 - self.identifier = identifier - - def forward(self, x): - if self.noise_std > 0: - x = x * ( - 1 - 2 * (torch.rand(x.size(), device=x.device) < self.noise_std).long() - ) - return x - - -############################## - - -class BlockSummarizer(nn.Module): - def __init__(self, nb_blocks, nb_tokens, dim_keys, dim_model): - self.nb_blocks = nb_blocks - self.static_q = nn.Parameter(nb_blocks - 1, nb_tokens, dim_keys) - - def compute_block_attzero(t_q, t_k): - block_size = t_q.size(0) - return (t_q // block_size) <= (t_k // block_size) - - self.qkv = QKVAttention( - dim_in=dim_model, - dim_qk=dim_keys, - dim_v=dim_model // nb_heads, - nb_heads=nb_heads, - compute_attzero=compute_attzero, - attention_dropout=dropout, - ) - - def forward(self, bs): - pass - - -class ShiftByOne(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, bs): - return BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb) - - -class MyGPT(nn.Module): - def __init__( - self, - vocabulary_size, - dim_model, - dim_keys, - dim_hidden, - nb_heads, - nb_blocks, - compute_attzero=None, - dropout=0.0, - len_max=1e5, - ): - super().__init__() - - assert dim_model % nb_heads == 0 - - self.temperature = 1.0 - - self.shifter = ShiftByOne() - - self.embedding = nn.Sequential( - CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)), - ) - - self.positional_encoding = CachedVaswaniPositionalEncoding(len_max) - - trunk_blocks = [] - - for b in range(nb_blocks): - trunk_blocks += [ - CachedWithResidual( - CacheWrapper( - nn.LayerNorm((dim_model,)), - NoiseInjector(identifier=("attention", b)), - ), - QKVAttention( - dim_in=dim_model, - dim_qk=dim_keys, - dim_v=dim_model // nb_heads, - nb_heads=nb_heads, - compute_attzero=compute_attzero, - attention_dropout=dropout, - ), - ), - CachedWithResidual( - CacheWrapper( - nn.LayerNorm((dim_model,)), - NoiseInjector(identifier=("ffw", b)), - nn.Linear(in_features=dim_model, out_features=dim_hidden), - nn.ReLU(), - nn.Linear(in_features=dim_hidden, out_features=dim_model), - nn.Dropout(dropout), - ), - ), - ] - - self.trunk = nn.Sequential(*trunk_blocks) - - self.readout = CacheWrapper( - nn.Linear(in_features=dim_model, out_features=vocabulary_size) - ) - - with torch.no_grad(): - for m in self.modules(): - if isinstance(m, nn.Embedding): - m.weight.normal_(mean=0, std=2e-2) - elif isinstance(m, nn.LayerNorm): - m.bias.zero_() - m.weight.fill_(1.0) - - def forward(self, bs): - for m in self.modules(): - m.loss = 0 - - bs = self.shifter(bs) - bs = self.embedding(bs) - bs = self.positional_encoding(bs) - bs = self.trunk(bs) - bs = self.readout(bs) - bs.x[:, bs.first : bs.first + bs.nb] /= self.temperature - - for m in self.modules(): - self.loss += m.loss - - return bs - - def reset_transformations(self): - self.temperature = 1.0 - for m in self.modules(): - if isinstance(m, NoiseInjector): - m.noise_std = 0.0 - - def set_noise_injection(self, noise_std, identifier=None): - for m in self.modules(): - if isinstance(m, NoiseInjector): - if identifier is None or identifier == m.identifier: - m.noise_std = noise_std - - def record_attention(self, v=True): - for m in self.modules(): - if isinstance(m, QKVAttention): - m.record_attention = v - - def retrieve_attention(self): - a = [] - for m in self.modules(): - if isinstance(m, QKVAttention): - a.append(m.a) - return a - - -###################################################################### - -if __name__ == "__main__": - print("Basic check.") - - vocabulary_size = 3 - x = torch.randint(vocabulary_size, (1, 5)) - - model = MyGPT( - vocabulary_size=vocabulary_size, - dim_model=4, - dim_keys=2, - dim_hidden=2, - nb_heads=2, - nb_blocks=2, - dropout=0.1, - ) - - model.eval() - y1 = model(BracketedSequence(x)).x - y2 = torch.randn_like(y1) - for s in range(x.size(1)): - z = model(BracketedSequence(x, s, 1)) - y2[:, s] = z.slice() - - print(f"error={((y1 - y2).norm() / (y1.norm() + y2.norm())).item()}") - -###################################################################### diff --git a/sky.py b/sky.py deleted file mode 100755 index cc5bd4f..0000000 --- a/sky.py +++ /dev/null @@ -1,364 +0,0 @@ -#!/usr/bin/env python - -# Any copyright is dedicated to the Public Domain. -# https://creativecommons.org/publicdomain/zero/1.0/ - -# Written by Francois Fleuret - -import math, sys, tqdm, os, warnings - -import torch, torchvision - -from torch import nn -from torch.nn import functional as F - -###################################################################### - -import problem - - -class Sky(problem.Problem): - colors = torch.tensor( - [ - [255, 255, 255], - [255, 0, 0], - [0, 192, 0], - [0, 0, 255], - [255, 192, 0], - [0, 255, 255], - [255, 0, 255], - [192, 255, 192], - [255, 192, 192], - [192, 192, 255], - [192, 192, 192], - ] - ) - - token_background = 0 - first_bird_token = 1 - nb_bird_tokens = colors.size(0) - 1 - - token2char = ( - "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><" - ) - - def __init__( - self, - height=6, - width=8, - nb_birds=3, - speed=2, - nb_iterations=2, - avoid_collision=True, - max_nb_cached_chunks=None, - chunk_size=None, - nb_threads=-1, - ): - super().__init__(max_nb_cached_chunks, chunk_size, nb_threads) - self.height = height - self.width = width - self.nb_birds = nb_birds - self.speed = speed - self.nb_iterations = nb_iterations - self.avoid_collision = avoid_collision - - def generate_frame_sequences(self, nb): - frame_sequences = [] - - for _ in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world generation"): - i, j, vi, vj = ( - torch.empty(self.nb_birds, dtype=torch.int64), - torch.empty(self.nb_birds, dtype=torch.int64), - torch.empty(self.nb_birds, dtype=torch.int64), - torch.empty(self.nb_birds, dtype=torch.int64), - ) - - def collision_okay(): - if not self.avoid_collision: - return True - - count = torch.zeros(self.height, self.width, dtype=torch.int64) - - for n in range(self.nb_birds): - count[i[n], j[n]] += 1 - count[i[n] - vi[n], j[n]] += 1 - count[i[n], j[n] - vj[n]] += 1 - - return count.max() <= 1 - - col = ( - torch.randperm(self.colors.size(0) - 1)[: self.nb_birds].sort().values - + 1 - ) - - while True: - while True: - for n in range(self.nb_birds): - while True: - i[n] = torch.randint(self.height, (1,)) - j[n] = torch.randint(self.width, (1,)) - vm = torch.randint(4, (1,)) - vi[n], vj[n] = (vm % 2) * 2 - 1, (vm // 2) * 2 - 1 - if ( - i[n] - vi[n] >= 0 - and i[n] - vi[n] < self.height - and j[n] - vj[n] >= 0 - and j[n] - vj[n] < self.width - ): - break - - if collision_okay(): - break - - result = torch.zeros( - self.nb_iterations * self.speed, - self.height, - self.width, - dtype=torch.int64, - ) - - fine = torch.empty(self.nb_iterations * self.speed) - - t_to_keep = ( - torch.arange(self.nb_iterations, device=result.device) * self.speed - ) - - for l in range(self.nb_iterations * self.speed): - fine[l] = collision_okay() - for n in range(self.nb_birds): - c = col[n] - result[l, i[n], j[n]] = c - result[l, i[n] - vi[n], j[n]] = c - result[l, i[n], j[n] - vj[n]] = c - - if (i[n] == 0 and vi[n] == -1) or ( - i[n] == self.height - 1 and vi[n] == 1 - ): - vi[n] = -vi[n] - - if (j[n] == 0 and vj[n] == -1) or ( - j[n] == self.width - 1 and vj[n] == 1 - ): - vj[n] = -vj[n] - - i[n] += vi[n] - j[n] += vj[n] - - result = result[t_to_keep] - fine = fine[t_to_keep] - - if fine[-1]: - break - - frame_sequences.append(result) - - return frame_sequences - - ###################################################################### - - def frame2img(self, x, scale=15): - x = x.reshape(x.size(0), self.height, -1) - m = torch.logical_and( - x >= 0, x < self.first_bird_token + self.nb_bird_tokens - ).long() - x = self.colors[x * m].permute(0, 3, 1, 2) - s = x.shape - x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale) - x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale) - - x[:, :, :, torch.arange(0, x.size(3), scale)] = 0 - x[:, :, torch.arange(0, x.size(2), scale), :] = 0 - x = x[:, :, 1:, 1:] - - for n in range(m.size(0)): - for i in range(m.size(1)): - for j in range(m.size(2)): - if m[n, i, j] == 0: - for k in range(2, scale - 2): - for l in [0, 1]: - x[n, :, i * scale + k, j * scale + k - l] = 0 - x[ - n, :, i * scale + scale - 1 - k, j * scale + k - l - ] = 0 - - return x - - def seq2str(self, seq): - result = [] - for s in seq: - result.append("".join([self.token2char[v] for v in s])) - return result - - def save_image( - self, - result_dir, - filename, - prompts, - answers, - predicted_prompts=None, - predicted_answers=None, - ): - if predicted_prompts is None: - predicted_prompts = 255 - - if predicted_answers is None: - predicted_answers = 255 - - def add_frame(x, c, margin, bottom=False): - if bottom: - h, w, di, dj = x.size(2) + margin, x.size(3), 0, 0 - else: - h, w, di, dj = ( - x.size(2) + 2 * margin, - x.size(3) + 2 * margin, - margin, - margin, - ) - - y = x.new_full((x.size(0), x.size(1), h, w), 0) - - if type(c) is int: - y[...] = c - else: - c = c.long()[:, None] - c = ( - (c == 1).long() * torch.tensor([0, 255, 0], device=c.device) - + (c == 0).long() * torch.tensor([255, 255, 255], device=c.device) - + (c == -1).long() * torch.tensor([255, 0, 0], device=c.device) - ) - y[...] = c[:, :, None, None] - - y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x - - return y - - margin = 4 - - img_prompts = add_frame(self.frame2img(prompts.to("cpu")), c=0, margin=1) - h = img_prompts.size(2) - img_answers = add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1) - - img_prompts = add_frame(img_prompts, c=255, margin=margin, bottom=True) - img_answers = add_frame(img_answers, c=255, margin=margin, bottom=True) - - img_prompts = add_frame( - img_prompts, c=predicted_prompts, margin=margin, bottom=True - ) - img_answers = add_frame( - img_answers, c=predicted_answers, margin=margin, bottom=True - ) - - marker_size = 16 - - separator = img_prompts.new_full( - ( - img_prompts.size(0), - img_prompts.size(1), - img_prompts.size(2), - marker_size, - ), - 255, - ) - - separator[:, :, 0] = 0 - separator[:, :, h - 1] = 0 - - for k in range(1, 2 * marker_size - 8): - i = k - (marker_size - 4) - j = marker_size - 5 - abs(i) - separator[:, :, h // 2 - 1 + i, 2 + j] = 0 - separator[:, :, h // 2 - 1 + i + 1, 2 + j] = 0 - - img = torch.cat([img_prompts, separator, img_answers], dim=3) - - image_name = os.path.join(result_dir, filename) - torchvision.utils.save_image( - img.float() / 255.0, image_name, nrow=6, padding=margin * 4, pad_value=1.0 - ) - - ###################################################################### - - def nb_token_values(self): - return len(self.colors) - - def generate_prompts_and_answers(self, nb): - frame_sequences = self.generate_frame_sequences(nb) - frame_sequences = torch.cat([x[None] for x in frame_sequences], dim=0) - - prompts = frame_sequences[:, : frame_sequences.size(1) // 2].flatten(1) - - answers = frame_sequences[:, frame_sequences.size(1) // 2 :].flatten(1) - - # warnings.warn("dirty test with longer answer", RuntimeWarning) - # answers = torch.cat( - # [ - # frame_sequences[:, frame_sequences.size(1) // 2 :], - # frame_sequences[:, frame_sequences.size(1) // 2 :], - # ], - # dim=3, - # ).flatten(1) - - return prompts, answers - - def save_quiz_illustrations( - self, - result_dir, - filename_prefix, - prompts, - answers, - predicted_prompts=None, - predicted_answers=None, - ): - self.save_image( - result_dir, - filename_prefix + ".png", - prompts, - answers, - predicted_prompts, - predicted_answers, - ) - - -###################################################################### - -if __name__ == "__main__": - import time - - sky = Sky(height=6, width=8, speed=1, nb_iterations=4) - - prompts, answers = sky.generate_prompts_and_answers(4) - - predicted_prompts = torch.randint(3, (prompts.size(0),)) - 1 - predicted_answers = torch.randint(3, (prompts.size(0),)) - 1 - - sky.save_quiz_illustrations( - "/tmp", "test", prompts, answers, predicted_prompts, predicted_answers - ) - - # start_time = time.perf_counter() - # token_sequences = sky.generate_token_sequences(nb=64) - # delay = time.perf_counter() - start_time - # print(f"{token_sequences.size(0)/delay:02f} seq/s") - - # print(sky.seq2str(seq[:4])) - - # for t in range(len(it[0])): - # img = torch.cat([sky.frame2img(f[t]) for f in it], dim=0) - # torchvision.utils.save_image( - # img.float() / 255.0, - # f"/tmp/frame_{t:03d}.png", - # nrow=8, - # padding=6, - # pad_value=0, - # ) - - # m = (torch.rand(seq.size()) < 0.05).long() - # seq = (1 - m) * seq + m * 23 - - # print(seq.size()) - # img = sky.seq2img(token_sequences) - # print(img.size()) - - # torchvision.utils.save_image( - # img.float() / 255.0, "/tmp/world.png", nrow=6, padding=6, pad_value=0 - # ) diff --git a/wireworld.py b/wireworld.py deleted file mode 100755 index 8257cad..0000000 --- a/wireworld.py +++ /dev/null @@ -1,357 +0,0 @@ -#!/usr/bin/env python - -# Any copyright is dedicated to the Public Domain. -# https://creativecommons.org/publicdomain/zero/1.0/ - -# Written by Francois Fleuret - -import math, sys, tqdm, os - -import torch, torchvision - -from torch import nn -from torch.nn import functional as F - -###################################################################### - -import problem - - -class Wireworld(problem.Problem): - colors = torch.tensor( - [ - [128, 128, 128], - [128, 128, 255], - [255, 0, 0], - [255, 255, 0], - ] - ) - - token_empty = 0 - token_head = 1 - token_tail = 2 - token_conductor = 3 - token_forward = 4 - token_backward = 5 - - token2char = ( - "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><" - ) - - def __init__( - self, height=6, width=8, nb_objects=2, nb_walls=2, speed=1, nb_iterations=4 - ): - self.height = height - self.width = width - self.nb_objects = nb_objects - self.nb_walls = nb_walls - self.speed = speed - self.nb_iterations = nb_iterations - - def direction_tokens(self): - return self.token_forward, self.token_backward - - def generate_frame_sequences(self, nb): - result = [] - N = 100 - for _ in tqdm.tqdm( - range(0, nb + N, N), dynamic_ncols=True, desc="world generation" - ): - result.append(self.generate_frame_sequences_hard(100)) - return torch.cat(result, dim=0)[:nb] - - def generate_frame_sequences_hard(self, nb): - frame_sequences = [] - nb_frames = (self.nb_iterations - 1) * self.speed + 1 - - result = torch.full( - (nb * 4, nb_frames, self.height, self.width), - self.token_empty, - ) - - for n in range(result.size(0)): - while True: - i = torch.randint(self.height, (1,)) - j = torch.randint(self.width, (1,)) - v = torch.randint(2, (2,)) - vi = v[0] * (v[1] * 2 - 1) - vj = (1 - v[0]) * (v[1] * 2 - 1) - while True: - if i < 0 or i >= self.height or j < 0 or j >= self.width: - break - o = 0 - if i > 0: - o += (result[n, 0, i - 1, j] == self.token_conductor).long() - if i < self.height - 1: - o += (result[n, 0, i + 1, j] == self.token_conductor).long() - if j > 0: - o += (result[n, 0, i, j - 1] == self.token_conductor).long() - if j < self.width - 1: - o += (result[n, 0, i, j + 1] == self.token_conductor).long() - if o > 1: - break - result[n, 0, i, j] = self.token_conductor - i += vi - j += vj - if ( - result[n, 0] == self.token_conductor - ).long().sum() > self.width and torch.rand(1) < 0.5: - break - - while True: - for _ in range(self.height * self.width): - i = torch.randint(self.height, (1,)) - j = torch.randint(self.width, (1,)) - v = torch.randint(2, (2,)) - vi = v[0] * (v[1] * 2 - 1) - vj = (1 - v[0]) * (v[1] * 2 - 1) - if ( - i + vi >= 0 - and i + vi < self.height - and j + vj >= 0 - and j + vj < self.width - and result[n, 0, i, j] == self.token_conductor - and result[n, 0, i + vi, j + vj] == self.token_conductor - ): - result[n, 0, i, j] = self.token_head - result[n, 0, i + vi, j + vj] = self.token_tail - break - - # if torch.rand(1) < 0.75: - break - - weight = torch.full((1, 1, 3, 3), 1.0) - - mask = (torch.rand(result[:, 0].size()) < 0.01).long() - rand = torch.randint(4, mask.size()) - result[:, 0] = mask * rand + (1 - mask) * result[:, 0] - - # empty->empty - # head->tail - # tail->conductor - # conductor->head if 1 or 2 head in the neighborhood, or remains conductor - - nb_heads = (result[:, 0] == self.token_head).flatten(1).long().sum(dim=1) - valid = nb_heads > 0 - - for l in range(nb_frames - 1): - nb_head_neighbors = ( - F.conv2d( - input=(result[:, l] == self.token_head).float()[:, None, :, :], - weight=weight, - padding=1, - ) - .long() - .squeeze(1) - ) - mask_1_or_2_heads = (nb_head_neighbors == 1).long() + ( - nb_head_neighbors == 2 - ).long() - result[:, l + 1] = ( - (result[:, l] == self.token_empty).long() * self.token_empty - + (result[:, l] == self.token_head).long() * self.token_tail - + (result[:, l] == self.token_tail).long() * self.token_conductor - + (result[:, l] == self.token_conductor).long() - * ( - mask_1_or_2_heads * self.token_head - + (1 - mask_1_or_2_heads) * self.token_conductor - ) - ) - pred_nb_heads = nb_heads - nb_heads = ( - (result[:, l + 1] == self.token_head).flatten(1).long().sum(dim=1) - ) - valid = torch.logical_and(valid, (nb_heads >= pred_nb_heads)) - - result = result[valid] - - result = result[ - :, torch.arange(self.nb_iterations, device=result.device) * self.speed - ] - - i = (result[:, -1] == self.token_head).flatten(1).max(dim=1).values > 0 - result = result[i] - - # print(f"{result.size(0)=} {nb=}") - - if result.size(0) < nb: - # print(result.size(0)) - result = torch.cat( - [result, self.generate_frame_sequences(nb - result.size(0))], dim=0 - ) - - return result[:nb] - - def generate_token_sequences(self, nb): - frame_sequences = self.generate_frame_sequences(nb) - - result = [] - - for frame_sequence in frame_sequences: - a = [] - if torch.rand(1) < 0.5: - for frame in frame_sequence: - if len(a) > 0: - a.append(torch.tensor([self.token_forward])) - a.append(frame.flatten()) - else: - for frame in reversed(frame_sequence): - if len(a) > 0: - a.append(torch.tensor([self.token_backward])) - a.append(frame.flatten()) - - result.append(torch.cat(a, dim=0)[None, :]) - - return torch.cat(result, dim=0) - - ###################################################################### - - def frame2img(self, x, scale=15): - x = x.reshape(-1, self.height, self.width) - m = torch.logical_and(x >= 0, x < 4).long() - - x = self.colors[x * m].permute(0, 3, 1, 2) - s = x.shape - x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale) - x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale) - - x[:, :, :, torch.arange(0, x.size(3), scale)] = 0 - x[:, :, torch.arange(0, x.size(2), scale), :] = 0 - x = x[:, :, 1:, 1:] - - for n in range(m.size(0)): - for i in range(m.size(1)): - for j in range(m.size(2)): - if m[n, i, j] == 0: - for k in range(2, scale - 2): - for l in [0, 1]: - x[n, :, i * scale + k, j * scale + k - l] = 0 - x[ - n, :, i * scale + scale - 1 - k, j * scale + k - l - ] = 0 - - return x - - def seq2img(self, seq, scale=15): - all = [ - self.frame2img( - seq[:, : self.height * self.width].reshape(-1, self.height, self.width), - scale, - ) - ] - - separator = torch.full((seq.size(0), 3, self.height * scale - 1, 1), 0) - - t = self.height * self.width - - while t < seq.size(1): - direction_tokens = seq[:, t] - t += 1 - - direction_images = self.colors[ - torch.full( - (direction_tokens.size(0), self.height * scale - 1, scale), 0 - ) - ].permute(0, 3, 1, 2) - - for n in range(direction_tokens.size(0)): - if direction_tokens[n] == self.token_forward: - for k in range(scale): - for l in [0, 1]: - direction_images[ - n, - :, - (self.height * scale) // 2 - scale // 2 + k - l, - 3 + scale // 2 - abs(k - scale // 2), - ] = 0 - elif direction_tokens[n] == self.token_backward: - for k in range(scale): - for l in [0, 1]: - direction_images[ - n, - :, - (self.height * scale) // 2 - scale // 2 + k - l, - 3 + abs(k - scale // 2), - ] = 0 - else: - for k in range(2, scale - 2): - for l in [0, 1]: - direction_images[ - n, - :, - (self.height * scale) // 2 - scale // 2 + k - l, - k, - ] = 0 - direction_images[ - n, - :, - (self.height * scale) // 2 - scale // 2 + k - l, - scale - 1 - k, - ] = 0 - - all += [ - separator, - direction_images, - separator, - self.frame2img( - seq[:, t : t + self.height * self.width].reshape( - -1, self.height, self.width - ), - scale, - ), - ] - - t += self.height * self.width - - return torch.cat(all, dim=3) - - def seq2str(self, seq): - result = [] - for s in seq: - result.append("".join([self.token2char[v] for v in s])) - return result - - def save_image(self, input, result_dir, filename): - img = self.seq2img(input.to("cpu")) - image_name = os.path.join(result_dir, filename) - torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4) - - def save_quizzes(self, input, result_dir, filename_prefix): - self.save_image(input, result_dir, filename_prefix + ".png") - - -###################################################################### - -if __name__ == "__main__": - import time - - wireworld = Wireworld(height=8, width=10, nb_iterations=5, speed=1) - - start_time = time.perf_counter() - frame_sequences = wireworld.generate_frame_sequences(nb=96) - delay = time.perf_counter() - start_time - print(f"{frame_sequences.size(0)/delay:02f} seq/s") - - # print(wireworld.seq2str(seq[:4])) - - for t in range(frame_sequences.size(1)): - img = wireworld.seq2img(frame_sequences[:, t]) - torchvision.utils.save_image( - img.float() / 255.0, - f"/tmp/frame_{t:03d}.png", - nrow=8, - padding=6, - pad_value=0, - ) - - # m = (torch.rand(seq.size()) < 0.05).long() - # seq = (1 - m) * seq + m * 23 - - wireworld = Wireworld(height=8, width=10, nb_iterations=2, speed=5) - token_sequences = wireworld.generate_token_sequences(32) - wireworld.save_quizzes(token_sequences, "/tmp", "seq") - # img = wireworld.seq2img(frame_sequences[:60]) - - # torchvision.utils.save_image( - # img.float() / 255.0, "/tmp/world.png", nrow=6, padding=10, pad_value=0.1 - # )