From a4bc783e87679b297f544433b4a7f005c1e115a9 Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Mon, 22 Jun 2020 09:48:26 +0200 Subject: [PATCH] Cleaning up the code a bit. --- ddpol.py | 87 +++++++++++++++++++++++++++++++++++++------------------- 1 file changed, 58 insertions(+), 29 deletions(-) diff --git a/ddpol.py b/ddpol.py index 6812fdf..51e7636 100755 --- a/ddpol.py +++ b/ddpol.py @@ -5,14 +5,35 @@ # Written by Francois Fleuret -import math +import math, argparse import matplotlib.pyplot as plt + import torch -nb_train_samples = 8 -D_max = 16 -nb_runs = 250 -train_noise_std = 0 +###################################################################### + +parser = argparse.ArgumentParser(description='Example of double descent with polynomial regression.') + +parser.add_argument('--D-max', + type = int, default = 16) + +parser.add_argument('--nb-runs', + type = int, default = 250) + +parser.add_argument('--nb-train-samples', + type = int, default = 8) + +parser.add_argument('--train-noise-std', + type = float, default = 0.) + +parser.add_argument('--seed', + type = int, default = 0, + help = 'Random seed (default 0, < 0 is no seeding)') + +args = parser.parse_args() + +if args.seed >= 0: + torch.manual_seed(args.seed) ###################################################################### @@ -30,7 +51,7 @@ def fit_alpha(x, y, D, a = 0, b = 1, rho = 1e-12): beta = x.new_zeros(D + 1, D + 1) beta[2:, 2:] = (q-1) * q * (r-1) * r * (b**(q+r-3) - a**(q+r-3))/(q+r-3) l, U = beta.eig(eigenvectors = True) - Q = U @ torch.diag(l[:, 0].pow(0.5)) + Q = U @ torch.diag(l[:, 0].clamp(min = 0) ** 0.5) B = torch.cat((B, y.new_zeros(Q.size(0))), 0) M = torch.cat((M, math.sqrt(rho) * Q.t()), 0) @@ -43,26 +64,30 @@ def phi(x): ###################################################################### -torch.manual_seed(0) +def compute_mse(nb_train_samples): + mse_train = torch.zeros(args.nb_runs, args.D_max + 1) + mse_test = torch.zeros(args.nb_runs, args.D_max + 1) + + for k in range(args.nb_runs): + x_train = torch.rand(nb_train_samples, dtype = torch.float64) + y_train = phi(x_train) + if args.train_noise_std > 0: + y_train = y_train + torch.empty_like(y_train).normal_(0, args.train_noise_std) + x_test = torch.linspace(0, 1, 100, dtype = x_train.dtype) + y_test = phi(x_test) + + for D in range(args.D_max + 1): + alpha = fit_alpha(x_train, y_train, D) + mse_train[k, D] = ((pol_value(alpha, x_train) - y_train)**2).mean() + mse_test[k, D] = ((pol_value(alpha, x_test) - y_test)**2).mean() -mse_train = torch.zeros(nb_runs, D_max + 1) -mse_test = torch.zeros(nb_runs, D_max + 1) + return mse_train.median(0).values, mse_test.median(0).values -for k in range(nb_runs): - x_train = torch.rand(nb_train_samples, dtype = torch.float64) - y_train = phi(x_train) - if train_noise_std > 0: - y_train = y_train + torch.empty_like(y_train).normal_(0, train_noise_std) - x_test = torch.linspace(0, 1, 100, dtype = x_train.dtype) - y_test = phi(x_test) +###################################################################### - for D in range(D_max + 1): - alpha = fit_alpha(x_train, y_train, D) - mse_train[k, D] = ((pol_value(alpha, x_train) - y_train)**2).mean() - mse_test[k, D] = ((pol_value(alpha, x_test) - y_test)**2).mean() +torch.manual_seed(0) -mse_train = mse_train.median(0).values -mse_test = mse_test.median(0).values +mse_train, mse_test = compute_mse(args.nb_train_samples) ###################################################################### # Plot the MSE vs. degree curves @@ -75,27 +100,29 @@ ax.set_ylim(1e-5, 1) ax.set_xlabel('Polynomial degree', labelpad = 10) ax.set_ylabel('MSE', labelpad = 10) -ax.axvline(x = nb_train_samples - 1, color = 'gray', linewidth = 0.5) -ax.plot(torch.arange(D_max + 1), mse_train, color = 'blue', label = 'Train error') -ax.plot(torch.arange(D_max + 1), mse_test, color = 'red', label = 'Test error') +ax.axvline(x = args.nb_train_samples - 1, color = 'gray', linewidth = 0.5) +ax.plot(torch.arange(args.D_max + 1), mse_train, color = 'blue', label = 'Train error') +ax.plot(torch.arange(args.D_max + 1), mse_test, color = 'red', label = 'Test error') ax.legend(frameon = False) fig.savefig('dd-mse.pdf', bbox_inches='tight') +plt.close(fig) + ###################################################################### # Plot some examples of train / test torch.manual_seed(9) # I picked that for pretty -x_train = torch.rand(nb_train_samples, dtype = torch.float64) +x_train = torch.rand(args.nb_train_samples, dtype = torch.float64) y_train = phi(x_train) -if train_noise_std > 0: - y_train = y_train + torch.empty_like(y_train).normal_(0, train_noise_std) +if args.train_noise_std > 0: + y_train = y_train + torch.empty_like(y_train).normal_(0, args.train_noise_std) x_test = torch.linspace(0, 1, 100, dtype = x_train.dtype) y_test = phi(x_test) -for D in range(D_max + 1): +for D in range(args.D_max + 1): fig = plt.figure() ax = fig.add_subplot(1, 1, 1) @@ -111,4 +138,6 @@ for D in range(D_max + 1): fig.savefig(f'dd-example-{D:02d}.pdf', bbox_inches='tight') + plt.close(fig) + ###################################################################### -- 2.39.5