From: Francois Fleuret Date: Mon, 26 Jun 2017 13:40:52 +0000 (+0200) Subject: Added --save_test_mistakes. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=ffe0b4fed11bb356684d9faa1849c86997a3029a;p=pysvrt.git Added --save_test_mistakes. --- diff --git a/cnn-svrt.py b/cnn-svrt.py index ade87ce..3fe50d8 100755 --- a/cnn-svrt.py +++ b/cnn-svrt.py @@ -85,6 +85,9 @@ parser.add_argument('--compress_vignettes', type = distutils.util.strtobool, default = 'True', help = 'Use lossless compression to reduce the memory footprint') +parser.add_argument('--save_test_mistakes', + type = distutils.util.strtobool, default = 'False') + parser.add_argument('--model', type = str, default = 'deepnet', help = 'What model to use') @@ -338,7 +341,7 @@ class DeepNet3(nn.Module): ###################################################################### -def nb_errors(model, data_set): +def nb_errors(model, data_set, mistake_filename_pattern = None): ne = 0 for b in range(0, data_set.nb_batches): input, target = data_set.get_batch(b) @@ -348,6 +351,12 @@ def nb_errors(model, data_set): for i in range(0, data_set.batch_size): if wta_prediction[i] != target[i]: ne = ne + 1 + if mistake_filename_pattern is not None: + img = input[i].clone() + img.sub_(img.min()) + img.div_(img.max()) + torchvision.utils.save_image(img, + mistake_filename_pattern.format(b + i, target[i])) return ne @@ -550,7 +559,8 @@ for problem_number in map(int, args.problems.split(',')): args.nb_test_samples, args.batch_size, cuda = torch.cuda.is_available()) - nb_test_errors = nb_errors(model, test_set) + nb_test_errors = nb_errors(model, test_set, + mistake_filename_pattern = 'mistake_{:d}_{:06d}.png') log_string('test_error {:d} {:.02f}% {:d} {:d}'.format( problem_number,