From: François Fleuret Date: Fri, 25 Aug 2023 17:21:50 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=0e1e208852b83f6a3d59e5caabd2f0f1f4bde94e;p=picoclvr.git Update. --- diff --git a/grid.py b/grid.py index 433cfd5..f72c8e3 100755 --- a/grid.py +++ b/grid.py @@ -19,34 +19,31 @@ name_colors = ["red", "yellow", "blue", "green", "white", "purple"] class GridFactory: def __init__( self, - height=4, - width=4, + size=4, max_nb_items=4, - max_nb_transformations=4, + max_nb_transformations=3, nb_questions=4, ): - self.height = height - self.width = width + self.size = size self.max_nb_items = max_nb_items self.max_nb_transformations = max_nb_transformations self.nb_questions = nb_questions def generate_scene(self): nb_items = torch.randint(self.max_nb_items - 1, (1,)).item() + 2 - col = torch.full((self.height * self.width,), -1) - shp = torch.full((self.height * self.width,), -1) + col = torch.full((self.size * self.size,), -1) + shp = torch.full((self.size * self.size,), -1) a = torch.randperm(len(name_colors) * len(name_shapes))[:nb_items] col[:nb_items] = a % len(name_colors) shp[:nb_items] = a // len(name_colors) - i = torch.randperm(self.height * self.width) + i = torch.randperm(self.size * self.size) col = col[i] shp = shp[i] - return col.reshape(self.height, self.width), shp.reshape( - self.height, self.width - ) + return col.reshape(self.size, self.size), shp.reshape(self.size, self.size) def random_transformations(self, scene): col, shp = scene + descriptions = [] nb_transformations = torch.randint(self.max_nb_transformations + 1, (1,)).item() transformations = torch.randint(5, (nb_transformations,)) @@ -68,30 +65,32 @@ class GridFactory: col, shp = col.flip(1).t(), shp.flip(1).t() descriptions += [" rotate 270 degrees"] - return (col.contiguous(), shp.contiguous()), descriptions + col, shp = col.contiguous(), shp.contiguous() + + return (col, shp), descriptions def print_scene(self, scene): col, shp = scene - # for i in range(self.height): - # for j in range(self.width): + # for i in range(self.size): + # for j in range(self.size): # if col[i,j] >= 0: # print(f"at ({i},{j}) {name_colors[col[i,j]]} {name_shapes[shp[i,j]]}") - for i in range(self.height): - for j in range(self.width): + for i in range(self.size): + for j in range(self.size): if col[i, j] >= 0: print(f"{name_colors[col[i,j]][0]}{name_shapes[shp[i,j]]}", end="") elif j == 0: print(" +", end="") else: print("-+", end="") - if j < self.width - 1: + if j < self.size - 1: print("--", end="") else: print("") - if i < self.height - 1: - for j in range(self.width - 1): + if i < self.size - 1: + for j in range(self.size - 1): print(" | ", end="") print(" |") @@ -100,8 +99,8 @@ class GridFactory: properties = [] - for i in range(self.height): - for j in range(self.width): + for i in range(self.size): + for j in range(self.size): if col[i, j] >= 0: n = f"{name_colors[col[i,j]]} {name_shapes[shp[i,j]]}" properties += [f"a {n} at {i} {j}"] @@ -113,21 +112,21 @@ class GridFactory: properties = [] - for i1 in range(self.height): - for j1 in range(self.width): + for i1 in range(self.size): + for j1 in range(self.size): if col[i1, j1] >= 0: n1 = f"{name_colors[col[i1,j1]]} {name_shapes[shp[i1,j1]]}" properties += [f"there is a {n1}"] - if i1 < self.height // 2: + if i1 < self.size // 2: properties += [f"a {n1} is in the top half"] - if i1 >= self.height // 2: + if i1 >= self.size // 2: properties += [f"a {n1} is in the bottom half"] - if j1 < self.width // 2: + if j1 < self.size // 2: properties += [f"a {n1} is in the left half"] - if j1 >= self.width // 2: + if j1 >= self.size // 2: properties += [f"a {n1} is in the right half"] - for i2 in range(self.height): - for j2 in range(self.width): + for i2 in range(self.size): + for j2 in range(self.size): if col[i2, j2] >= 0: n2 = f"{name_colors[col[i2,j2]]} {name_shapes[shp[i2,j2]]}" if i1 > i2: @@ -153,22 +152,22 @@ class GridFactory: scene, transformations = self.random_transformations(scene) + # transformations=[] + for a in range(10): col, shp = scene col, shp = col.view(-1), shp.view(-1) p = torch.randperm(col.size(0)) col, shp = col[p], shp[p] other_scene = ( - col.view(self.height, self.width), - shp.view(self.height, self.width), + col.view(self.size, self.size), + shp.view(self.size, self.size), ) # other_scene = self.generate_scene() false = list(set(self.all_properties(other_scene)) - set(true)) if len(false) >= self.nb_questions: break - # print(f"{a=}") - if a < 10: break diff --git a/main.py b/main.py index 00e19ac..704dff5 100755 --- a/main.py +++ b/main.py @@ -99,6 +99,11 @@ parser.add_argument("--rpl_nb_runs", type=int, default=5) parser.add_argument("--rpl_no_prog", action="store_true", default=False) +############################## +# grid options + +parser.add_argument("--grid_size", type=int, default=6) + ############################## # picoclvr options @@ -517,8 +522,7 @@ elif args.task == "grid": nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, batch_size=args.batch_size, - height=args.picoclvr_height, - width=args.picoclvr_width, + size=args.grid_size, logger=log_string, device=device, ) diff --git a/tasks.py b/tasks.py index 0ab1823..2c2f914 100755 --- a/tasks.py +++ b/tasks.py @@ -1459,8 +1459,7 @@ class Grid(Task): nb_train_samples, nb_test_samples, batch_size, - height, - width, + size, logger=None, device=torch.device("cpu"), ): @@ -1468,7 +1467,7 @@ class Grid(Task): self.device = device self.batch_size = batch_size - self.grid_factory = grid.GridFactory(height=height, width=width) + self.grid_factory = grid.GridFactory(size=size) if logger is not None: logger(