From: Francois Fleuret Date: Fri, 15 Jul 2022 15:07:47 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=c8d0cf6842db19f84a78c1b3a4d2666b323a5d4a;p=mygpt.git Update. --- diff --git a/main.py b/main.py index 85cf4cf..11cf0a3 100755 --- a/main.py +++ b/main.py @@ -25,7 +25,7 @@ parser.add_argument('--log_filename', type = str, default = 'train.log') parser.add_argument('--download', - type = bool, default = False) + action='store_true', default = False) parser.add_argument('--seed', type = int, default = 0) @@ -67,11 +67,14 @@ parser.add_argument('--dropout', type = float, default = 0.1) parser.add_argument('--synthesis_sampling', - type = bool, default = True) + action='store_true', default = True) parser.add_argument('--checkpoint_name', type = str, default = 'checkpoint.pth') +parser.add_argument('--picoclvr_many_colors', + action='store_true', default = False) + ###################################################################### args = parser.parse_args() @@ -353,7 +356,7 @@ if args.data == 'wiki103': elif args.data == 'mnist': task = TaskMNIST(batch_size = args.batch_size, device = device) elif args.data == 'picoclvr': - task = TaskPicoCLVR(batch_size = args.batch_size, device = device) + task = TaskPicoCLVR(batch_size = args.batch_size, many_colors = args.picoclvr_many_colors, device = device) else: raise ValueError(f'Unknown dataset {args.data}.') diff --git a/picoclvr.py b/picoclvr.py index 601bdf7..6dd8114 100755 --- a/picoclvr.py +++ b/picoclvr.py @@ -71,6 +71,25 @@ color_tokens = dict( [ (n, c) for n, c in zip(color_names, colors) ] ) ###################################################################### +def all_properties(height, width, nb_squares, square_i, square_j, square_c): + s = [ ] + + for r, c in [ (k, color_names[square_c[k]]) for k in range(nb_squares) ]: + s += [ f'there is {c}' ] + + if square_i[r] >= height - height//3: s += [ f'{c} bottom' ] + if square_i[r] < height//3: s += [ f'{c} top' ] + if square_j[r] >= width - width//3: s += [ f'{c} right' ] + if square_j[r] < width//3: s += [ f'{c} left' ] + + for t, d in [ (k, color_names[square_c[k]]) for k in range(nb_squares) ]: + if square_i[r] > square_i[t]: s += [ f'{c} below {d}' ] + if square_i[r] < square_i[t]: s += [ f'{c} above {d}' ] + if square_j[r] > square_j[t]: s += [ f'{c} right of {d}' ] + if square_j[r] < square_j[t]: s += [ f'{c} left of {d}' ] + + return s + def generate(nb, height = 6, width = 8, max_nb_squares = 5, max_nb_statements = 10, many_colors = False): @@ -93,21 +112,7 @@ def generate(nb, height = 6, width = 8, # generates all the true relations - s = [ ] - - for r, c in [ (k, color_names[square_c[k]]) for k in range(nb_squares) ]: - s += [ f'there is {c}' ] - - if square_i[r] >= height - height//3: s += [ f'{c} bottom' ] - if square_i[r] < height//3: s += [ f'{c} top' ] - if square_j[r] >= width - width//3: s += [ f'{c} right' ] - if square_j[r] < width//3: s += [ f'{c} left' ] - - for t, d in [ (k, color_names[square_c[k]]) for k in range(nb_squares) ]: - if square_i[r] > square_i[t]: s += [ f'{c} below {d}' ] - if square_i[r] < square_i[t]: s += [ f'{c} above {d}' ] - if square_j[r] > square_j[t]: s += [ f'{c} right of {d}' ] - if square_j[r] < square_j[t]: s += [ f'{c} left of {d}' ] + s = all_properties(height, width, nb_squares, square_i, square_j, square_c) # pick at most max_nb_statements at random diff --git a/result_picoclvr_0007.png b/result_picoclvr_0007.png index e36efb6..569bfc3 100644 Binary files a/result_picoclvr_0007.png and b/result_picoclvr_0007.png differ