From 226589286bd8701002102062394909a82f5e807e Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Wed, 3 Jul 2024 15:57:30 +0300 Subject: [PATCH] Update. --- lang.py | 243 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 243 insertions(+) create mode 100755 lang.py diff --git a/lang.py b/lang.py new file mode 100755 index 0000000..d53386c --- /dev/null +++ b/lang.py @@ -0,0 +1,243 @@ +#!/usr/bin/env python + +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + +import math, sys, tqdm, os, warnings + +import torch, torchvision + +from torch import nn +from torch.nn import functional as F + +###################################################################### + +import problem + + +class Lang(problem.Problem): + named_colors = [ + ("white", [255, 255, 255]), + ("red", [255, 0, 0]), + ("green", [0, 192, 0]), + ("blue", [0, 0, 255]), + ("orange", [255, 192, 0]), + ("cyan", [0, 255, 255]), + ("violet", [255, 0, 255]), + ("lightgreen", [192, 255, 192]), + ("pink", [255, 192, 192]), + ("lightblue", [192, 192, 255]), + ("gray", [192, 192, 192]), + ] + + def __init__( + self, + nb_iterations=2, + ): + self.colors = torch.tensor([c for _, c in self.named_colors]) + self.name2color = dict([(p[0], i) for i, p in enumerate(self.named_colors)]) + self.height = 10 + self.width = 10 + self.nb_iterations = nb_iterations + + ###################################################################### + + def frame2img(self, x, scale=15): + x = x.reshape(x.size(0), self.height, -1) + x = self.colors[x].permute(0, 3, 1, 2) + s = x.shape + x = x[:, :, :, None, :, None].expand(-1, -1, -1, scale, -1, scale) + x = x.reshape(s[0], s[1], s[2] * scale, s[3] * scale) + + x[:, :, :, torch.arange(0, x.size(3), scale)] = 0 + x[:, :, torch.arange(0, x.size(2), scale), :] = 0 + x = x[:, :, 1:, 1:] + + return x + + def save_image( + self, + result_dir, + filename, + prompts, + answers, + predicted_prompts=None, + predicted_answers=None, + ): + if predicted_prompts is None: + predicted_prompts = 255 + + if predicted_answers is None: + predicted_answers = 255 + + def add_frame(x, c, margin, bottom=False): + if bottom: + h, w, di, dj = x.size(2) + margin, x.size(3), 0, 0 + else: + h, w, di, dj = ( + x.size(2) + 2 * margin, + x.size(3) + 2 * margin, + margin, + margin, + ) + + y = x.new_full((x.size(0), x.size(1), h, w), 0) + + if type(c) is int: + y[...] = c + else: + c = c.long()[:, None] + c = c * torch.tensor([0, 0, 0], device=c.device) + ( + 1 - c + ) * torch.tensor([255, 255, 255], device=c.device) + y[...] = c[:, :, None, None] + + y[:, :, di : di + x.size(2), dj : dj + x.size(3)] = x + + return y + + margin = 4 + + img_prompts = add_frame(self.frame2img(prompts.to("cpu")), c=0, margin=1) + h = img_prompts.size(2) + img_answers = add_frame(self.frame2img(answers.to("cpu")), c=0, margin=1) + + img_prompts = add_frame(img_prompts, c=255, margin=margin, bottom=True) + img_answers = add_frame(img_answers, c=255, margin=margin, bottom=True) + + img_prompts = add_frame( + img_prompts, c=predicted_prompts, margin=margin, bottom=True + ) + img_answers = add_frame( + img_answers, c=predicted_answers, margin=margin, bottom=True + ) + + marker_size = 16 + + separator = img_prompts.new_full( + ( + img_prompts.size(0), + img_prompts.size(1), + img_prompts.size(2), + marker_size, + ), + 255, + ) + + separator[:, :, 0] = 0 + separator[:, :, h - 1] = 0 + + for k in range(1, 2 * marker_size - 8): + i = k - (marker_size - 4) + j = marker_size - 5 - abs(i) + separator[:, :, h // 2 - 1 + i, 2 + j] = 0 + separator[:, :, h // 2 - 1 + i + 1, 2 + j] = 0 + + img = torch.cat([img_prompts, separator, img_answers], dim=3) + + image_name = os.path.join(result_dir, filename) + torchvision.utils.save_image( + img.float() / 255.0, image_name, nrow=4, padding=margin * 4, pad_value=1.0 + ) + + ###################################################################### + + def nb_token_values(self): + return len(self.colors) + + def rec_coo(self, x): + while True: + i1, i2 = torch.randint(x.size(0), (2,)) + if i1 < i2 - 1: + break + while True: + j1, j2 = torch.randint(x.size(1), (2,)) + if j1 < j2 - 1: + break + return i1, j1, i2, j2 + + def task_red_to_green(self, A, f_A, B, f_B): + i1, j1, i2, j2 = self.rec_coo(A) + A[i1:i2, j1:j2] = self.name2color["red"] + f_A[i1:i2, j1:j2] = self.name2color["green"] + i1, j1, i2, j2 = self.rec_coo(B) + B[i1:i2, j1:j2] = self.name2color["red"] + f_B[i1:i2, j1:j2] = self.name2color["green"] + + def generate_prompts_and_answers(self, nb): + prompts = torch.zeros(nb, self.height, self.width * 3, dtype=torch.int64) + answers = torch.zeros(nb, self.height, self.width, dtype=torch.int64) + w = self.width + for prompt, answer in zip(prompts, answers): + self.task_red_to_green( + prompt[:, 0 * w : 1 * w], + prompt[:, 1 * w : 2 * w], + prompt[:, 2 * w : 3 * w], + answer, + ) + return prompts, answers + + def save_quizzes( + self, + result_dir, + filename_prefix, + prompts, + answers, + predicted_prompts=None, + predicted_answers=None, + ): + self.save_image( + result_dir, + filename_prefix + ".png", + prompts, + answers, + predicted_prompts, + predicted_answers, + ) + + +###################################################################### + +if __name__ == "__main__": + import time + + lang = Lang(nb_iterations=4) + + prompts, answers = lang.generate_prompts_and_answers(24) + + # predicted_prompts = torch.rand(prompts.size(0)) < 0.5 + # predicted_answers = torch.rand(answers.size(0)) < 0.5 + + lang.save_quizzes( + "/tmp", "test", prompts, answers # , predicted_prompts, predicted_answers + ) + + # start_time = time.perf_counter() + # token_sequences = lang.generate_token_sequences(nb=64) + # delay = time.perf_counter() - start_time + # print(f"{token_sequences.size(0)/delay:02f} seq/s") + + # print(lang.seq2str(seq[:4])) + + # for t in range(len(it[0])): + # img = torch.cat([lang.frame2img(f[t]) for f in it], dim=0) + # torchvision.utils.save_image( + # img.float() / 255.0, + # f"/tmp/frame_{t:03d}.png", + # nrow=8, + # padding=6, + # pad_value=0, + # ) + + # m = (torch.rand(seq.size()) < 0.05).long() + # seq = (1 - m) * seq + m * 23 + + # print(seq.size()) + # img = lang.seq2img(token_sequences) + # print(img.size()) + + # torchvision.utils.save_image( + # img.float() / 255.0, "/tmp/world.png", nrow=6, padding=6, pad_value=0 + # ) -- 2.39.5