From: Francois Fleuret Date: Wed, 14 Jun 2017 16:06:35 +0000 (+0200) Subject: Test now saves an example image. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=44bd6cf57f00009c7049dcd3e5600f12f2f41de5;p=pysvrt.git Test now saves an example image. --- diff --git a/svrt.c b/svrt.c index 0f53642..fdee66f 100644 --- a/svrt.c +++ b/svrt.c @@ -29,21 +29,37 @@ THByteTensor *generate_vignettes(long n_problem, long nb_vignettes) { struct VignetteSet vs; + long st0, st1, st2; + long v, i, j; + unsigned char *a, *b; svrt_generate_vignettes(n_problem, nb_vignettes, &vs); printf("SANITY %d %d %d\n", vs.nb_vignettes, vs.width, vs.height); THLongStorage *size = THLongStorage_newWithSize(3); - size->data[0] = nb_vignettes; + size->data[0] = vs.nb_vignettes; size->data[1] = vs.height; size->data[2] = vs.width; THByteTensor *result = THByteTensor_newWithSize(size, NULL); THLongStorage_free(size); - /* st0 = THByteTensor_stride(result, 0); */ - /* st1 = THByteTensor_stride(result, 1); */ - /* st2 = THByteTensor_stride(result, 2); */ + st0 = THByteTensor_stride(result, 0); + st1 = THByteTensor_stride(result, 1); + st2 = THByteTensor_stride(result, 2); + + unsigned char *r = vs.data; + for(v = 0; v < vs.nb_vignettes; v++) { + a = THByteTensor_storage(result)->data + THByteTensor_storageOffset(result) + v * st0; + for(i = 0; i < vs.height; i++) { + b = a + i * st1; + for(j = 0; j < vs.width; j++) { + *b = (unsigned char) (*r); + r++; + b += st2; + } + } + } return result; } diff --git a/svrt_generator.cc b/svrt_generator.cc index 82b7c3b..80cfd12 100644 --- a/svrt_generator.cc +++ b/svrt_generator.cc @@ -145,22 +145,34 @@ VignetteGenerator *new_generator(int nb) { extern "C" { - struct VignetteSet { - int n_problem; - int nb_vignettes; - int width; - int height; - unsigned char *data; - }; - - void svrt_generate_vignettes(int n_problem, int nb_vignettes, VignetteSet *result) { - VignetteGenerator *vg = new_generator(n_problem); - result->n_problem = n_problem; - result->nb_vignettes = nb_vignettes; - result->width = Vignette::width; - result->height = Vignette::height; - result->data = (unsigned char *) malloc(sizeof(unsigned char) * result->nb_vignettes * result->width * result->height); - delete vg; +struct VignetteSet { + int n_problem; + int nb_vignettes; + int width; + int height; + unsigned char *data; +}; + +void svrt_generate_vignettes(int n_problem, int nb_vignettes, VignetteSet *result) { + Vignette tmp; + + VignetteGenerator *vg = new_generator(n_problem); + result->n_problem = n_problem; + result->nb_vignettes = nb_vignettes; + result->width = Vignette::width; + result->height = Vignette::height; + result->data = (unsigned char *) malloc(sizeof(unsigned char) * result->nb_vignettes * result->width * result->height); + + unsigned char *s = result->data; + for(int i = 0; i < nb_vignettes; i++) { + vg->generate(drand48() < 0.5 ? 1 : 0, &tmp); + int *r = tmp.content; + for(int k = 0; k < Vignette::width * Vignette::height; k++) { + *s++ = *r++; + } } + delete vg; +} + } diff --git a/test-svrt.py b/test-svrt.py index 92fc554..6b5f826 100755 --- a/test-svrt.py +++ b/test-svrt.py @@ -24,16 +24,22 @@ import time import torch +import torchvision from torch import optim from torch import FloatTensor as Tensor from torch.autograd import Variable from torch import nn from torch.nn import functional as fn + from torchvision import datasets, transforms, utils from _ext import svrt -train_set = svrt.generate_vignettes(12, 1234) +train_set = svrt.generate_vignettes(12, 64) print(str(type(train_set)), train_set.size()) + +train_set.div_(255) + +torchvision.utils.save_image(train_set.view(train_set.size(0), 1, train_set.size(1), train_set.size(2)), 'example.png')