From 823ed2babf4a7144a1832487e7c911e6933d5647 Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Mon, 25 Jul 2022 18:15:00 +0200 Subject: [PATCH] Update. --- main.py | 41 ++++++++++++++++++++++------------------- 1 file changed, 22 insertions(+), 19 deletions(-) diff --git a/main.py b/main.py index b579177..77c4b9e 100755 --- a/main.py +++ b/main.py @@ -130,36 +130,39 @@ class TaskPicoCLVR(Task): height, width, many_colors = False, device = torch.device('cpu')): + def generate_descr(nb): + descr = picoclvr.generate( + nb, + height = self.height, width = self.width, + many_colors = many_colors + ) + + descr = [ s.strip().split(' ') for s in descr ] + l = max([ len(s) for s in descr ]) + descr = [ s + [ '' ] * (l - len(s)) for s in descr ] + + return descr + self.height = height self.width = width self.batch_size = batch_size self.device = device nb = args.data_size if args.data_size > 0 else 250000 - descr = picoclvr.generate( - nb, - height = self.height, width = self.width, - many_colors = many_colors - ) - - # self.test_descr = descr[:nb // 5] - # self.train_descr = descr[nb // 5:] - - descr = [ s.strip().split(' ') for s in descr ] - l = max([ len(s) for s in descr ]) - descr = [ s + [ '' ] * (l - len(s)) for s in descr ] + self.train_descr = generate_descr((nb * 4) // 5) + self.test_descr = generate_descr((nb * 1) // 5) tokens = set() - for s in descr: - for t in s: tokens.add(t) + for d in [ self.train_descr, self.test_descr ]: + for s in d: + for t in s: tokens.add(t) self.token2id = dict([ (t, n) for n, t in enumerate(tokens) ]) self.id2token = dict([ (n, t) for n, t in enumerate(tokens) ]) - t = [ [ self.token2id[u] for u in s ] for s in descr ] - data_input = torch.tensor(t, device = self.device) - - self.test_input = data_input[:nb // 5] - self.train_input = data_input[nb // 5:] + t = [ [ self.token2id[u] for u in s ] for s in self.train_descr ] + self.train_input = torch.tensor(t, device = self.device) + t = [ [ self.token2id[u] for u in s ] for s in self.test_descr ] + self.test_input = torch.tensor(t, device = self.device) def batches(self, split = 'train'): assert split in { 'train', 'test' } -- 2.39.5