From 756b8e59a361755f88917aaf4bc214c5f6cce94b Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Tue, 9 Jan 2018 15:50:23 +0100 Subject: [PATCH] Added a script to general samples as images in a pytorch-structured ImageFolder. --- generate.py | 85 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 85 insertions(+) create mode 100755 generate.py diff --git a/generate.py b/generate.py new file mode 100755 index 0000000..7f29683 --- /dev/null +++ b/generate.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python + +# svrt is the ``Synthetic Visual Reasoning Test'', an image +# generator for evaluating classification performance of machine +# learning systems, humans and primates. +# +# Copyright (c) 2017 Idiap Research Institute, http://www.idiap.ch/ +# Written by Francois Fleuret +# +# This file is part of svrt. +# +# svrt is free software: you can redistribute it and/or modify it +# under the terms of the GNU General Public License version 3 as +# published by the Free Software Foundation. +# +# svrt is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with svrt. If not, see . + +import time +import argparse + +import torch +import torchvision, os + +from torch import optim +from torch import FloatTensor as Tensor +from torch.autograd import Variable +from torch import nn +from torch.nn import functional as fn + +from torchvision import datasets, transforms, utils + +import svrt + +###################################################################### +# Parsing arguments +###################################################################### + +parser = argparse.ArgumentParser( + description='SVRT sample generator.', + formatter_class = argparse.ArgumentDefaultsHelpFormatter +) + +parser.add_argument('--nb_samples', + type = int, + default = 1000, + help='How many samples to generate') + +parser.add_argument('--problem', + type = int, + default = 1, + help='Problem to generate samples from') + +parser.add_argument('--data_dir', + type = str, + default = '', + help='Where to generate the samples') + +###################################################################### + +args = parser.parse_args() + +if os.path.isdir(args.data_dir): + name = 'problem_{:02d}/class_'.format(args.problem) + os.makedirs(args.data_dir + '/' + name + '0', exist_ok = True) + os.makedirs(args.data_dir + '/' + name + '1', exist_ok = True) +else: + raise FileNotFoundError('Cannot find ' + args.data_dir) + +labels = torch.LongTensor(args.nb_samples).zero_() +labels.narrow(0, 0, labels.size(0)//2).fill_(1) +x = svrt.generate_vignettes(args.problem, labels).float() + +x.sub_(128).div_(64) + +print('MEAN', x.mean(), 'STD', x.std()) + +for k in range(x.size(0)): + filename = args.data_dir + '/problem_{:02d}/class_{:d}/img_{:06d}.png'.format(args.problem, labels[k], k) + torchvision.utils.save_image(x[k].view(1, x.size(1), x.size(2)), filename) -- 2.39.5