From: François Fleuret Date: Sat, 26 Oct 2024 21:53:32 +0000 (+0200) Subject: Update. X-Git-Url: https://ant.fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=03f7803f6770063c07b8000e06cbe6c694cdaf41;p=pytorch.git Update. --- diff --git a/grid.py b/grid.py index 991c44f..7168491 100755 --- a/grid.py +++ b/grid.py @@ -4,7 +4,6 @@ # https://creativecommons.org/publicdomain/zero/1.0/ - # Written by Francois Fleuret # This code implement a simple system to manipulate formal @@ -72,36 +71,36 @@ class FormalGrid: else: return False - if match("([1-9]) top"): + if match("([1-9]) is_in_top_half"): (a,) = g[0] - return self.row[:, a] < self.grid_height // 4 - elif match("([1-9]) bottom"): + return self.row[:, a] < self.grid_height // 2 + elif match("([1-9]) is_in_bottom_half"): (a,) = g[0] - return self.row[:, a] >= (self.grid_height * 3) // 4 - elif match("([1-9]) left"): + return self.row[:, a] >= self.grid_height // 2 + elif match("([1-9]) is_on_left_side"): (a,) = g[0] - return self.col[:, a] < self.grid_width // 4 - elif match("([1-9]) right"): + return self.col[:, a] < self.grid_width // 2 + elif match("([1-9]) is_on_right_side"): (a,) = g[0] - return self.col[:, a] >= (self.grid_width * 3) // 4 + return self.col[:, a] >= self.grid_width // 2 elif match("([1-9]) next_to ([1-9])"): a, b = g[0] return (self.row[:, a] - self.row[:, b]).abs() + ( self.col[:, a] - self.col[:, b] ).abs() <= 1 - elif match("([1-9]) below_of ([1-9])"): + elif match("([1-9]) is_below ([1-9])"): a, b = g[0] return self.row[:, a] > self.row[:, b] - elif match("([1-9]) above ([1-9])"): + elif match("([1-9]) is_above ([1-9])"): a, b = g[0] return self.row[:, a] < self.row[:, b] - elif match("([1-9]) left_of ([1-9])"): + elif match("([1-9]) is_left_of ([1-9])"): a, b = g[0] return self.col[:, a] < self.col[:, b] - elif match("([1-9]) right_of ([1-9])"): + elif match("([1-9]) is_right_of ([1-9])"): a, b = g[0] return self.col[:, a] > self.col[:, b] - elif match("([1-9]) ([1-9]) diagonal"): + elif match("([1-9]) ([1-9]) parallel_to_diagonal"): a, b = g[0] return (self.col[:, a] - self.col[:, b]).abs() == ( self.row[:, a] - self.row[:, b] @@ -113,7 +112,7 @@ class FormalGrid: a, b = g[0] return self.row[:, a] == self.row[:, b] - elif match("([1-9]) ([1-9]) ([1-9]) aligned"): + elif match("([1-9]) ([1-9]) ([1-9]) are_aligned"): a, b, c = g[0] return (self.col[:, a] - self.col[:, b]) * ( self.row[:, a] - self.row[:, c] @@ -129,7 +128,15 @@ class FormalGrid: & (self.row[:, a] + self.row[:, c] == 2 * self.row[:, b]) ) - elif match("([1-9]) further_away_from ([1-9]) than ([1-9])"): + elif match("([1-9]) is_equidistant_from ([1-9]) and ([1-9])"): + a, b, c = g[0] + return (self.col[:, a] - self.col[:, b]) ** 2 + ( + self.row[:, a] - self.row[:, b] + ) ** 2 == (self.col[:, a] - self.col[:, c]) ** 2 + ( + self.row[:, a] - self.row[:, c] + ) ** 2 + + elif match("([1-9]) is_further_away_from ([1-9]) than ([1-9])"): a, b, c = g[0] return (self.col[:, a] - self.col[:, b]) ** 2 + ( self.row[:, a] - self.row[:, b] @@ -137,7 +144,7 @@ class FormalGrid: self.row[:, a] - self.row[:, c] ) ** 2 - elif match("([1-9]) ([1-9]) ([1-9]) right_angle"): + elif match("([1-9]) ([1-9]) ([1-9]) make_right_angle"): a, b, c = g[0] return (self.col[:, a] - self.col[:, b]) * ( self.col[:, c] - self.col[:, b] @@ -185,25 +192,22 @@ class FormalGrid: ###################################################################### -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +if __name__ == "__main__": + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -grid = FormalGrid(device=device) + grid = FormalGrid(grid_height=8, grid_width=8, nb_symbols=4, device=device) -grid_set = grid.new_grid_set( - [ - "4 top", - "4 right", - "1 top", - "1 left", - "1 left_of 2", - "2 left_of 3", - "1 2 4 right_angle", - "1 2 3 aligned", - "2 further_away_from 3 than 4", - ], -) + grid_set = grid.new_grid_set( + [ + "1 2 3 make_right_angle", + "2 3 4 make_right_angle", + "3 4 1 make_right_angle", + "2 is_equidistant_from 1 and 3", + "1 is_above 4", + ], + ) -print(f"There are {grid_set.long().sum().item()} configurations") + print(f"There are {grid_set.long().sum().item()} configurations") -for v in grid.views(grid_set): - print(v) + for v in grid.views(grid_set): + print(v) diff --git a/tinymnist.py b/tinymnist.py index 896477e..f662be6 100755 --- a/tinymnist.py +++ b/tinymnist.py @@ -70,14 +70,14 @@ test_input.sub_(mu).div_(std) start_time = time.perf_counter() for k in range(nb_epochs): - acc_loss = 0.0 + acc_train_loss = 0.0 for input, targets in zip( train_input.split(batch_size), train_targets.split(batch_size) ): output = model(input) loss = criterion(output, targets) - acc_loss += loss.item() + acc_train_loss += loss.item() * input.size(0) optimizer.zero_grad() loss.backward() @@ -92,6 +92,8 @@ for k in range(nb_epochs): test_error = nb_test_errors / test_input.size(0) duration = time.perf_counter() - start_time - print(f"loss {k} {duration:.02f}s {acc_loss:.02f} {test_error*100:.02f}%") + print( + f"loss {k} {duration:.02f}s {acc_train_loss/train_input.size(0):.02f} {test_error*100:.02f}%" + ) ######################################################################