From: Francois Fleuret Date: Wed, 14 Jun 2017 16:27:51 +0000 (+0200) Subject: svrt.generate_vignettes now takes a 1d label tensor as arguments. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=24368498f3065e8a4be34c5e8e2b68f9d1220f7d;p=pysvrt.git svrt.generate_vignettes now takes a 1d label tensor as arguments. --- diff --git a/svrt.c b/svrt.c index fdee66f..1a2449b 100644 --- a/svrt.c +++ b/svrt.c @@ -27,14 +27,25 @@ #include "svrt_generator.h" -THByteTensor *generate_vignettes(long n_problem, long nb_vignettes) { +THByteTensor *generate_vignettes(long n_problem, THLongTensor *labels) { struct VignetteSet vs; + long nb_vignettes; long st0, st1, st2; long v, i, j; + long *m, *l; unsigned char *a, *b; - svrt_generate_vignettes(n_problem, nb_vignettes, &vs); - printf("SANITY %d %d %d\n", vs.nb_vignettes, vs.width, vs.height); + nb_vignettes = THLongTensor_size(labels, 0); + m = THLongTensor_storage(labels)->data + THLongTensor_storageOffset(labels); + st0 = THLongTensor_stride(labels, 0); + l = (long *) malloc(sizeof(long) * nb_vignettes); + for(v = 0; v < nb_vignettes; v++) { + l[v] = *m; + m += st0; + } + + svrt_generate_vignettes(n_problem, nb_vignettes, l, &vs); + free(l); THLongStorage *size = THLongStorage_newWithSize(3); size->data[0] = vs.nb_vignettes; @@ -61,5 +72,7 @@ THByteTensor *generate_vignettes(long n_problem, long nb_vignettes) { } } + free(vs.data); + return result; } diff --git a/svrt.h b/svrt.h index 4335df3..77b8b46 100644 --- a/svrt.h +++ b/svrt.h @@ -23,4 +23,4 @@ * */ -THByteTensor *generate_vignettes(long n_problem, long nb_images); +THByteTensor *generate_vignettes(long n_problem, THLongTensor *labels); diff --git a/svrt_generator.cc b/svrt_generator.cc index 80cfd12..90f781d 100644 --- a/svrt_generator.cc +++ b/svrt_generator.cc @@ -153,7 +153,8 @@ struct VignetteSet { unsigned char *data; }; -void svrt_generate_vignettes(int n_problem, int nb_vignettes, VignetteSet *result) { +void svrt_generate_vignettes(int n_problem, int nb_vignettes, long *labels, + VignetteSet *result) { Vignette tmp; VignetteGenerator *vg = new_generator(n_problem); @@ -165,7 +166,7 @@ void svrt_generate_vignettes(int n_problem, int nb_vignettes, VignetteSet *resul unsigned char *s = result->data; for(int i = 0; i < nb_vignettes; i++) { - vg->generate(drand48() < 0.5 ? 1 : 0, &tmp); + vg->generate(labels[i], &tmp); int *r = tmp.content; for(int k = 0; k < Vignette::width * Vignette::height; k++) { *s++ = *r++; diff --git a/svrt_generator.h b/svrt_generator.h index bdfe5c1..7f6a3ad 100644 --- a/svrt_generator.h +++ b/svrt_generator.h @@ -35,7 +35,8 @@ struct VignetteSet { unsigned char *data; }; -void svrt_generate_vignettes(int n_problem, int nb_vignettes, struct VignetteSet *result); + void svrt_generate_vignettes(int n_problem, int nb_vignettes, long *labels, + struct VignetteSet *result); #ifdef __cplusplus } diff --git a/test-svrt.py b/test-svrt.py index 6b5f826..9aa2d59 100755 --- a/test-svrt.py +++ b/test-svrt.py @@ -36,10 +36,15 @@ from torchvision import datasets, transforms, utils from _ext import svrt -train_set = svrt.generate_vignettes(12, 64) +labels = torch.LongTensor(12).zero_() +labels.narrow(0, 0, labels.size(0)//2).fill_(1) + +train_set = svrt.generate_vignettes(4, labels) print(str(type(train_set)), train_set.size()) +train_set = train_set.view(train_set.size(0), 1, train_set.size(1), train_set.size(2)) + train_set.div_(255) -torchvision.utils.save_image(train_set.view(train_set.size(0), 1, train_set.size(1), train_set.size(2)), 'example.png') +torchvision.utils.save_image(train_set, 'example.png')