From: Francois Fleuret Date: Wed, 18 Dec 2019 15:51:34 +0000 (+0100) Subject: Initial commit. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=0d0635ed4e6836ef2c48cd59fe3d25f7969e7bcf;p=pytorch.git Initial commit. --- diff --git a/denoising-ae-field.py b/denoising-ae-field.py new file mode 100755 index 0000000..175f344 --- /dev/null +++ b/denoising-ae-field.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python + +import math + +import torch, torchvision + +from torch import nn +from torch.nn import functional as F + +model = nn.Sequential( + nn.Linear(2, 100), + nn.ReLU(), + nn.Linear(100, 2) +) + +############################################################ + +def data_zigzag(nb): + a = torch.empty(nb).uniform_(0, 1).view(-1, 1) + # zigzag + x = 0.4 * ((a-0.5) * 5 * math.pi).cos() + y = a * 2.5 - 1.25 + data = torch.cat((y, x), 1) + data = data @ torch.tensor([[1., -1.], [1., 1.]]) + return data + +def data_spiral(nb): + a = torch.empty(nb).uniform_(0, 1).view(-1, 1) + x = (a * 2.25 * math.pi).cos() * (a * 0.8 + 0.5) + y = (a * 2.25 * math.pi).sin() * (a * 0.8 + 0.5) + data = torch.cat((y, x), 1) + return data + +###################################################################### + +data = data_spiral(1000) +# data = data_zigzag(1000) + +data = data - data.mean(0) + +batch_size, nb_epochs = 100, 1000 +optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3) +criterion = nn.MSELoss() + +for e in range(nb_epochs): + acc_loss = 0 + for input in data.split(batch_size): + noise = input.new(input.size()).normal_(0, 0.1) + output = model(input + noise) + loss = criterion(output, input) + acc_loss += loss.item() + optimizer.zero_grad() + loss.backward() + optimizer.step() + if (e+1)%10 == 0: print(e+1, acc_loss) + +###################################################################### + +a = torch.linspace(-1.5, 1.5, 30) +x = a.view( 1, -1, 1).expand(a.size(0), a.size(0), 1) +y = a.view(-1, 1, 1).expand(a.size(0), a.size(0), 1) +grid = torch.cat((y, x), 2).view(-1, 2) + +# Take the origins of the arrows on the part of grid closer than 0.1 +# from the data points +dist = (grid.view(-1, 1, 2) - data.view(1, -1, 2)).pow(2).sum(2).min(1)[0] +origins = grid[torch.arange(grid.size(0)).masked_select(dist < 0.1)] + +field = model(origins).detach() - origins + +###################################################################### + +import matplotlib.pyplot as plt + +fig = plt.figure() +ax = fig.add_subplot(1, 1, 1) + +ax.axis('off') +ax.set_xlim(-1.6, 1.6) +ax.set_ylim(-1.6, 1.6) +ax.set_aspect(1) + +plot_field = ax.quiver(origins[:, 0].numpy(), origins[:, 1].numpy(), + field[:, 0].numpy(), field[:, 1].numpy(), + units = 'xy', scale = 1, + width = 3e-3, headwidth = 25, headlength = 25) + +plot_data = ax.scatter(data[:, 0].numpy(), data[:, 1].numpy(), s = 1, color = 'tab:blue') + +fig.savefig('denoising_field.pdf', bbox_inches='tight') + +######################################################################