# https://creativecommons.org/publicdomain/zero/1.0/
 
 
-
 # Written by Francois Fleuret <francois@fleuret.org>
 
 # This code implement a simple system to manipulate formal
             else:
                 return False
 
-        if match("([1-9]) top"):
+        if match("([1-9]) is_in_top_half"):
             (a,) = g[0]
-            return self.row[:, a] < self.grid_height // 4
-        elif match("([1-9]) bottom"):
+            return self.row[:, a] < self.grid_height // 2
+        elif match("([1-9]) is_in_bottom_half"):
             (a,) = g[0]
-            return self.row[:, a] >= (self.grid_height * 3) // 4
-        elif match("([1-9]) left"):
+            return self.row[:, a] >= self.grid_height // 2
+        elif match("([1-9]) is_on_left_side"):
             (a,) = g[0]
-            return self.col[:, a] < self.grid_width // 4
-        elif match("([1-9]) right"):
+            return self.col[:, a] < self.grid_width // 2
+        elif match("([1-9]) is_on_right_side"):
             (a,) = g[0]
-            return self.col[:, a] >= (self.grid_width * 3) // 4
+            return self.col[:, a] >= self.grid_width // 2
         elif match("([1-9]) next_to ([1-9])"):
             a, b = g[0]
             return (self.row[:, a] - self.row[:, b]).abs() + (
                 self.col[:, a] - self.col[:, b]
             ).abs() <= 1
-        elif match("([1-9]) below_of ([1-9])"):
+        elif match("([1-9]) is_below ([1-9])"):
             a, b = g[0]
             return self.row[:, a] > self.row[:, b]
-        elif match("([1-9]) above ([1-9])"):
+        elif match("([1-9]) is_above ([1-9])"):
             a, b = g[0]
             return self.row[:, a] < self.row[:, b]
-        elif match("([1-9]) left_of ([1-9])"):
+        elif match("([1-9]) is_left_of ([1-9])"):
             a, b = g[0]
             return self.col[:, a] < self.col[:, b]
-        elif match("([1-9]) right_of ([1-9])"):
+        elif match("([1-9]) is_right_of ([1-9])"):
             a, b = g[0]
             return self.col[:, a] > self.col[:, b]
-        elif match("([1-9]) ([1-9]) diagonal"):
+        elif match("([1-9]) ([1-9]) parallel_to_diagonal"):
             a, b = g[0]
             return (self.col[:, a] - self.col[:, b]).abs() == (
                 self.row[:, a] - self.row[:, b]
             a, b = g[0]
             return self.row[:, a] == self.row[:, b]
 
-        elif match("([1-9]) ([1-9]) ([1-9]) aligned"):
+        elif match("([1-9]) ([1-9]) ([1-9]) are_aligned"):
             a, b, c = g[0]
             return (self.col[:, a] - self.col[:, b]) * (
                 self.row[:, a] - self.row[:, c]
                 & (self.row[:, a] + self.row[:, c] == 2 * self.row[:, b])
             )
 
-        elif match("([1-9]) further_away_from ([1-9]) than ([1-9])"):
+        elif match("([1-9]) is_equidistant_from ([1-9]) and ([1-9])"):
+            a, b, c = g[0]
+            return (self.col[:, a] - self.col[:, b]) ** 2 + (
+                self.row[:, a] - self.row[:, b]
+            ) ** 2 == (self.col[:, a] - self.col[:, c]) ** 2 + (
+                self.row[:, a] - self.row[:, c]
+            ) ** 2
+
+        elif match("([1-9]) is_further_away_from ([1-9]) than ([1-9])"):
             a, b, c = g[0]
             return (self.col[:, a] - self.col[:, b]) ** 2 + (
                 self.row[:, a] - self.row[:, b]
                 self.row[:, a] - self.row[:, c]
             ) ** 2
 
-        elif match("([1-9]) ([1-9]) ([1-9]) right_angle"):
+        elif match("([1-9]) ([1-9]) ([1-9]) make_right_angle"):
             a, b, c = g[0]
             return (self.col[:, a] - self.col[:, b]) * (
                 self.col[:, c] - self.col[:, b]
 
 ######################################################################
 
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+if __name__ == "__main__":
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-grid = FormalGrid(device=device)
+    grid = FormalGrid(grid_height=8, grid_width=8, nb_symbols=4, device=device)
 
-grid_set = grid.new_grid_set(
-    [
-        "4 top",
-        "4 right",
-        "1 top",
-        "1 left",
-        "1 left_of 2",
-        "2 left_of 3",
-        "1 2 4 right_angle",
-        "1 2 3 aligned",
-        "2 further_away_from 3 than 4",
-    ],
-)
+    grid_set = grid.new_grid_set(
+        [
+            "1 2 3 make_right_angle",
+            "2 3 4 make_right_angle",
+            "3 4 1 make_right_angle",
+            "2 is_equidistant_from 1 and 3",
+            "1 is_above 4",
+        ],
+    )
 
-print(f"There are {grid_set.long().sum().item()} configurations")
+    print(f"There are {grid_set.long().sum().item()} configurations")
 
-for v in grid.views(grid_set):
-    print(v)
+    for v in grid.views(grid_set):
+        print(v)