######################################################################
 
 
-def compute_distance(walls, i, j):
+def compute_distance(walls, goal_i, goal_j):
     max_length = walls.numel()
     dist = torch.full_like(walls, max_length)
 
-    dist[i, j] = 0
+    dist[goal_i, goal_j] = 0
     pred_dist = torch.empty_like(dist)
 
     while True:
 ######################################################################
 
 
-def compute_policy(walls, i, j):
-    distance = compute_distance(walls, i, j)
+def compute_policy(walls, goal_i, goal_j):
+    distance = compute_distance(walls, goal_i, goal_j)
     distance = distance + walls.numel() * walls
 
     value = distance.new_full((4,) + distance.size(), walls.numel())
     return proba
 
 
+def stationary_density(policy, start_i, start_j):
+    # Compute the stationary cell-occupancy density induced by `policy`,
+    # with the start cell clamped to probability 1.0 as a constant source,
+    # by iterating the transition operator until the grid stops changing.
+    #
+    # `policy` is a (4, H, W) tensor of per-cell action probabilities
+    # (shape inferred from compute_policy's (4,) + distance.size() value
+    # tensor and the directional slicing below; axis meanings presumed to
+    # be 0=down, 1=up, 2=right, 3=left -- confirm against caller).
+    #
+    # Fixes vs. previous version:
+    #  - the density grid must be (H, W), i.e. policy.size()[1:];
+    #    policy.size()[:-1] is (4, H) for a (4, H, W) policy;
+    #  - mass arriving from the four directions must accumulate (+=):
+    #    plain assignment let each slice overwrite the previous ones on
+    #    overlapping cells;
+    #  - the computed density is now returned.
+    probas = policy.new_zeros(policy.size()[1:])
+    pred_probas = probas.clone()
+    probas[start_i, start_j] = 1.0
+
+    while not pred_probas.equal(probas):
+        pred_probas.copy_(probas)
+        probas.zero_()
+        probas[1:, :] += pred_probas[:-1, :] * policy[0, :-1, :]
+        probas[:-1, :] += pred_probas[1:, :] * policy[1, 1:, :]
+        probas[:, 1:] += pred_probas[:, :-1] * policy[2, :, :-1]
+        probas[:, :-1] += pred_probas[:, 1:] * policy[3, :, 1:]
+        probas[start_i, start_j] = 1.0
+
+    return probas
+
+
 ######################################################################