-            v = (
-                (trigger.cumsum(dim=1) - trigger).cumsum(dim=1)
-                + torch.randint(
-                    input.size(1) - memex_len, (input.size(0), 1), device=t.device
-                )
-            ) * memex_mask
-            assert v.min() >= 0
-            assert v.max() < input.size(1)
-            u = u * (1 - memex_mask) + v * memex_mask
-
-            new_input = input[n, u]
-            assert input.max() < vocabulary_size
-            assert new_input.max() < vocabulary_size
-            limits = trigger.clone()
-            limits[:, memex_len - 1 :] += limits[:, : -(memex_len - 1)]
-            assert limits.min() == 0
-            assert limits.max() == 1
-            new_input = new_input * (1 - limits) + marker_token * limits
-            assert marker_token < vocabulary_size
-            assert new_input.max() < vocabulary_size