From: Francois Fleuret Date: Mon, 20 Jun 2022 06:14:46 +0000 (+0200) Subject: Finalized PicoCLVR with "many colors". X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=046f35f38d629c9854104e855a53f0142449138f;p=mygpt.git Finalized PicoCLVR with "many colors". --- diff --git a/main.py b/main.py index a31284e..3bf7587 100755 --- a/main.py +++ b/main.py @@ -111,12 +111,20 @@ import picoclvr class TaskPicoCLVR(Task): - def __init__(self, batch_size, height = 6, width = 8, device = torch.device('cpu')): + def __init__(self, batch_size, + height = 6, width = 8, many_colors = False, + device = torch.device('cpu')): + self.batch_size = batch_size self.device = device nb = args.data_size if args.data_size > 0 else 250000 - descr = picoclvr.generate(nb, height = height, width = width) + descr = picoclvr.generate( + nb, + height = height, width = width, + many_colors = many_colors + ) + descr = [ s.strip().split(' ') for s in descr ] l = max([ len(s) for s in descr ]) descr = [ s + [ '' ] * (l - len(s)) for s in descr ] diff --git a/picoclvr.py b/picoclvr.py index f4d7a65..712da17 100755 --- a/picoclvr.py +++ b/picoclvr.py @@ -71,7 +71,9 @@ color_tokens = dict( [ (n, c) for n, c in zip(color_names, colors) ] ) ###################################################################### -def generate(nb, height = 6, width = 8, max_nb_squares = 5, max_nb_statements = 10, many_colors = False): +def generate(nb, height = 6, width = 8, + max_nb_squares = 5, max_nb_statements = 10, + many_colors = False): nb_colors = len(color_tokens) - 1 if many_colors else max_nb_squares