From: Francois Fleuret Date: Thu, 15 Jun 2017 19:23:59 +0000 (+0200) Subject: Added storage compression / decompression functions to prepare for sets of 1M samples. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=51d45bec8caf5896a113e79475abd7d2df4a646d;p=pysvrt.git Added storage compression / decompression functions to prepare for sets of 1M samples. --- diff --git a/svrt.c b/svrt.c index 307fcf6..102084b 100644 --- a/svrt.c +++ b/svrt.c @@ -27,6 +27,65 @@ #include "svrt_generator.h" +THByteStorage *compress(THByteStorage *x) { + long k, g, n; + + k = 0; n = 0; + while(k < x->size) { + g = 0; + while(k < x->size && x->data[k] == 255 && g < 255) { g++; k++; } + n++; + if(k < x->size && g < 255) { k++; } + } + + if(x->data[k-1] == 0) { + n++; + } + + THByteStorage *result = THByteStorage_newWithSize(n); + + k = 0; n = 0; + while(k < x->size) { + g = 0; + while(k < x->size && x->data[k] == 255 && g < 255) { g++; k++; } + result->data[n++] = g; + if(k < x->size && g < 255) { k++; } + } + if(x->data[k-1] == 0) { + result->data[n++] = 0; + } + + return result; +} + +THByteStorage *uncompress(THByteStorage *x) { + long k, g, n; + + k = 0; + for(n = 0; n < x->size - 1; n++) { + k = k + x->data[n]; + if(x->data[n] < 255) { k++; } + } + k = k + x->data[n]; + + THByteStorage *result = THByteStorage_newWithSize(k); + + k = 0; + for(n = 0; n < x->size - 1; n++) { + for(g = 0; g < x->data[n]; g++) { + result->data[k++] = 255; + } + if(x->data[n] < 255) { + result->data[k++] = 0; + } + } + for(g = 0; g < x->data[n]; g++) { + result->data[k++] = 255; + } + + return result; +} + THByteTensor *generate_vignettes(long n_problem, THLongTensor *labels) { struct VignetteSet vs; long nb_vignettes; diff --git a/svrt.h b/svrt.h index 77b8b46..94020c4 100644 --- a/svrt.h +++ b/svrt.h @@ -23,4 +23,8 @@ * */ +THByteStorage *compress(THByteStorage *x); + +THByteStorage *uncompress(THByteStorage *x); + THByteTensor *generate_vignettes(long n_problem, THLongTensor *labels); diff --git a/test-svrt.py b/test-svrt.py index 5c16069..5f38fa9 100755 --- a/test-svrt.py +++ b/test-svrt.py @@ -41,6 +41,8 @@ labels.narrow(0, 0, labels.size(0)//2).fill_(1) x = svrt.generate_vignettes(4, labels) +print('compression factor {:f}'.format(x.storage().size() / svrt.compress(x.storage()).size())) + x = x.view(x.size(0), 1, x.size(1), x.size(2)) x.div_(255)