From d18a4193221e54d3d6235b62fc173b73ff7481bf Mon Sep 17 00:00:00 2001
From: François Fleuret
Date: Wed, 18 Sep 2024 09:28:16 +0200
Subject: [PATCH] Update.

---
 diffusion.py | 160 ---------------------------------------------------
 main.py      |   5 --
 2 files changed, 165 deletions(-)
 delete mode 100755 diffusion.py

diff --git a/diffusion.py b/diffusion.py
deleted file mode 100755
index 629113a..0000000
--- a/diffusion.py
+++ /dev/null
@@ -1,160 +0,0 @@
-#!/usr/bin/env python
-
-import math
-
-import torch, torchvision
-
-from torch import nn
-from torch.nn import functional as F
-
-
-def NTC_channel_cat(*x):
-    return torch.cat([a.expand_as(x[0])[:, :, None] for a in x], dim=2)
-
-
-class Diffuser:
-    def __init__(self, mu_T_sampler, nb_iterations, proba_corruption):
-        self.mu_T_sampler = mu_T_sampler
-        self.nb_iterations = nb_iterations
-        self.proba_corruption = proba_corruption
-
-    def sample_x_t_given_x_0(self, x_0, t):
-        noise = self.mu_T_sampler(x_0.size(), device=x_0.device)
-        r = torch.rand(x_0.size(), device=x_0.device)
-        proba_erased = 1 - (1 - self.proba_corruption) ** t
-        mask_erased = (r <= proba_erased[:, None]).long()
-        x_t = (1 - mask_erased) * x_0 + mask_erased * noise
-
-        return x_t
-
-    # This function returns a 2d tensor of the same shape as low,
-    # full of uniform random values in [0,1], such that, in every
-    # row, the values corresponding to the True entries of low are
-    # all less than the values corresponding to the False entries.
-
-    def prioritized_rand(self, low):
-        x = (
-            torch.rand(low.size(), device=low.device)
-            .sort(dim=1, descending=True)
-            .values
-        )
-        k = torch.rand(low.size(), device=low.device) + low.long()
-        k = k.sort(dim=1).indices
-        y = x.new(x.size())
-        y.scatter_(dim=1, index=k, src=x)
-        return y
-
-    def sample_x_t_minus_1_given_x_0_x_t(self, x_0, x_t):
-        r = self.prioritized_rand(x_0 != x_t)
-        mask_changes = (r <= self.proba_corruption).long()
-        x_t_minus_1 = (1 - mask_changes) * x_t + mask_changes * x_0
-        return x_t_minus_1
-
-    ######################################################################
-
-    def make_mask_hints(self, mask_generate, nb_hints):
-        if nb_hints is None:
-            mask_hints = torch.zeros(
-                mask_generate.size(),
-                device=mask_generate.device,
-                dtype=mask_generate.dtype,
-            )
-        else:
-            u = (
-                torch.rand(mask_generate.size(), device=mask_generate.device)
-                * mask_generate
-            )
-            v = u.sort(dim=1, descending=True).values.gather(
-                dim=1, index=nb_hints[:, None]
-            )
-            mask_hints = (u > v).long()
-
-        return mask_hints
-
-    # This function gets a clean target x_0 and a mask indicating
-    # which part to generate (conditionally on the rest), and returns
-    # the logits, starting from an x_t | X_0 = x_0 with t picked at random
-
-    def logits_hat_x_0_from_random_iteration(
-        self, model, x_0, mask_generate, nb_hints=None, prompt_noise=0.0
-    ):
-        noise = self.mu_T_sampler(x_0.size(), device=x_0.device)
-
-        single_iteration = (
-            mask_generate.sum(dim=1) < mask_generate.size(1) // 2
-        ).long()[:, None]
-
-        mask_hints = self.make_mask_hints(mask_generate, nb_hints) * single_iteration
-
-        # We favor iterations near the clean signal
-
-        probs_iterations = 0.1 ** torch.linspace(
-            0, 1, self.nb_iterations, device=x_0.device
-        )
-
-        probs_iterations = probs_iterations[None, :] / probs_iterations.sum()
-        probs_iterations = probs_iterations.expand(x_0.size(0), -1)
-        dist = torch.distributions.categorical.Categorical(probs=probs_iterations)
-
-        t = dist.sample() + 1
-
-        x_T_with_hints = mask_hints * x_0 + (1 - mask_hints) * noise
-        x_t = self.sample_x_t_given_x_0(x_0, t)
-        x_t = single_iteration * x_T_with_hints + (1 - single_iteration) * x_t
-        x_t = (1 - mask_generate) * x_0 + mask_generate * x_t
-
-        # We may inject noise to prevent high-complexity non-structured
-        # signal from being generated as a way of "increasing
-        # reasoning complexity"
-
-        if prompt_noise > 0:
-            mask_prompt_noise = (
-                torch.rand(x_t.size(), device=x_t.device) <= prompt_noise
-            ).long()
-            noise = self.mu_T_sampler(x_t.size(), device=x_t.device)
-            noisy_x_t = (1 - mask_prompt_noise) * x_t + mask_prompt_noise * noise
-            x_t = (1 - mask_generate) * noisy_x_t + mask_generate * x_t
-
-        x_t_with_mask = NTC_channel_cat(x_t, mask_generate)
-
-        with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
-            logits_hat_x_0 = model(x_t_with_mask)
-
-        return logits_hat_x_0
-
-    ######################################################################
-
-    def generate(self, model, x_0, mask_generate, nb_hints=None):
-        noise = self.mu_T_sampler(x_0.size(), device=x_0.device)
-
-        single_iteration = (
-            mask_generate.sum(dim=1) < mask_generate.size(1) // 2
-        ).long()[:, None]
-
-        mask_hints = self.make_mask_hints(mask_generate, nb_hints)
-
-        x_T_with_hints = mask_hints * x_0 + (1 - mask_hints) * noise
-        x_t = single_iteration * x_T_with_hints + (1 - single_iteration) * noise
-        x_t = (1 - mask_generate) * x_0 + mask_generate * x_t
-
-        changed = True
-
-        for it in range(self.nb_iterations):
-            x_t_with_mask = NTC_channel_cat(x_t, mask_generate)
-            with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
-                logits = model(x_t_with_mask)
-            dist = torch.distributions.categorical.Categorical(logits=logits)
-
-            hat_x_0 = (1 - mask_generate) * x_0 + mask_generate * dist.sample()
-
-            hat_x_t_minus_1 = single_iteration * hat_x_0 + (
-                1 - single_iteration
-            ) * self.sample_x_t_minus_1_given_x_0_x_t(hat_x_0, x_t)
-
-            if hat_x_t_minus_1.equal(x_t):
-                break
-            else:
-                changed = changed & (hat_x_t_minus_1 != x_t).max(dim=1).values
-                x_t[changed] = hat_x_t_minus_1[changed]
-
-        return x_t
diff --git a/main.py b/main.py
index 380be1e..772ef9f 100755
--- a/main.py
+++ b/main.py
@@ -3,9 +3,6 @@
 # Any copyright is dedicated to the Public Domain.
 # https://creativecommons.org/publicdomain/zero/1.0/
 
-# > A > f(A) > B ; > f(B)
-# < f(B) ; < B < f(A) < A
-
 # Written by Francois Fleuret
 
 import math, sys, argparse, time, tqdm, os, datetime, warnings, copy
@@ -29,8 +26,6 @@ import threading, subprocess
 
 # torch.set_default_dtype(torch.bfloat16)
 
-import diffusion
-
 ######################################################################
 
 parser = argparse.ArgumentParser(
-- 
2.39.5
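
Note for readers of the removal: the least obvious helper being deleted is
prioritized_rand. The sketch below is not part of the patch; it reproduces
the helper standalone and checks the invariant stated in its comment, under
the assumption that only PyTorch is available (the test values are made up).

    import torch

    def prioritized_rand(low):
        # Each row of x holds uniform values sorted in decreasing order.
        x = (
            torch.rand(low.size(), device=low.device)
            .sort(dim=1, descending=True)
            .values
        )
        # Sorting rand + {0, 1} ranks the False positions of low first,
        # so scatter_ sends the largest values of x to the False
        # positions and the smallest to the True positions.
        k = (torch.rand(low.size(), device=low.device) + low.long()).sort(dim=1).indices
        y = x.new(x.size())
        y.scatter_(dim=1, index=k, src=x)
        return y

    # Check the documented invariant: in every row, each value at a True
    # position of low is smaller than every value at a False position.
    low = torch.rand(4, 10) < 0.5
    y = prioritized_rand(low)
    for i in range(low.size(0)):
        if low[i].any() and (~low[i]).any():
            assert y[i][low[i]].max() < y[i][~low[i]].min()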