From: Francois Fleuret Date: Thu, 22 Jun 2017 06:05:25 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=1b7eb64f1a3de3761ff887b4cfbc25a81a60b00e;p=pysvrt.git Update. --- diff --git a/cnn-svrt.py b/cnn-svrt.py index a41d42c..d6c7169 100755 --- a/cnn-svrt.py +++ b/cnn-svrt.py @@ -32,6 +32,7 @@ from colorama import Fore, Back, Style # Pytorch import torch +import torchvision from torch import optim from torch import FloatTensor as Tensor @@ -73,6 +74,9 @@ parser.add_argument('--batch_size', parser.add_argument('--log_file', type = str, default = 'default.log') +parser.add_argument('--nb_exemplar_vignettes', + type = int, default = -1) + parser.add_argument('--compress_vignettes', type = distutils.util.strtobool, default = 'True', help = 'Use lossless compression to reduce the memory footprint') @@ -295,6 +299,21 @@ class vignette_logger(): ) self.last_t = t +def save_examplar_vignettes(data_set, nb, name): + n = torch.randperm(data_set.nb_samples).narrow(0, 0, nb) + + for k in range(0, nb): + b = n[k] // data_set.batch_size + m = n[k] % data_set.batch_size + i, t = data_set.get_batch(b) + i = i[m].float() + i.sub_(i.min()) + i.div_(i.max()) + if k == 0: patchwork = Tensor(nb, 1, i.size(1), i.size(2)) + patchwork[k].copy_(i) + + torchvision.utils.save_image(patchwork, name) + ###################################################################### if args.nb_train_samples%args.batch_size > 0 or args.nb_test_samples%args.batch_size > 0: @@ -357,6 +376,10 @@ for problem_number in map(int, args.problems.split(',')): train_set.nb_samples / (time.time() - t)) ) + if args.nb_exemplar_vignettes > 0: + save_examplar_vignettes(train_set, args.nb_exemplar_vignettes, + 'examplar_{:d}.png'.format(problem_number)) + if args.validation_error_threshold > 0.0: validation_set = VignetteSet(problem_number, args.nb_validation_samples, args.batch_size,