From: Francois Fleuret Date: Fri, 12 Aug 2022 07:57:09 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=f27d6083fbe7243f5896ddd49587fe1923fe9a79;p=pytorch.git Update. --- diff --git a/minidiffusion.py b/minidiffusion.py index ad1cda0..037ef11 100755 --- a/minidiffusion.py +++ b/minidiffusion.py @@ -5,6 +5,11 @@ # Written by Francois Fleuret +# Minimal implementation of Jonathan Ho, Ajay Jain, Pieter Abbeel +# "Denoising Diffusion Probabilistic Models" (2020) +# +# https://arxiv.org/abs/2006.11239 + import matplotlib.pyplot as plt import torch from torch import nn @@ -62,7 +67,7 @@ for k in range(nb_epochs): if k%10 == 0: print(k, loss.item()) ###################################################################### -# Plot +# Generate x = torch.randn(10000, 1) @@ -71,19 +76,27 @@ for t in range(T-1, -1, -1): input = torch.cat((x, torch.ones(x.size(0), 1) * 2 * t / T - 1), 1) x = 1 / alpha[t].sqrt() * (x - (1 - alpha[t])/(1 - alpha_bar[t]).sqrt() * model(input)) + sigma[t] * z +###################################################################### +# Plot + fig = plt.figure() ax = fig.add_subplot(1, 1, 1) ax.set_xlim(-1.25, 1.25) d = train_input.flatten().detach().numpy() -ax.hist(d, 25, (-1, 1), histtype = 'stepfilled', color = 'lightblue', density = True, label = 'Train') +ax.hist(d, 25, (-1, 1), + density = True, + histtype = 'stepfilled', color = 'lightblue', label = 'Train') d = x.flatten().detach().numpy() -ax.hist(d, 25, (-1, 1), histtype = 'step', color = 'red', density = True, label = 'Synthesis') +ax.hist(d, 25, (-1, 1), + density = True, + histtype = 'step', color = 'red', label = 'Synthesis') ax.legend(frameon = False, loc = 2) filename = 'diffusion.pdf' +print(f'saving {filename}') fig.savefig(filename, bbox_inches='tight') plt.show()