Skip to content

Commit

Permalink
PUBLIC: Add predecessor_pointers_to_permutation_matrix and `permuta…
Browse files Browse the repository at this point in the history
…tion_matrix_to_predecessor_pointers` methods to probing.

PiperOrigin-RevId: 638431530
  • Loading branch information
CLRSDev authored and copybara-github committed May 30, 2024
1 parent 86bf117 commit 4b1c035
Showing 1 changed file with 89 additions and 1 deletion.
90 changes: 89 additions & 1 deletion clrs/_src/probing.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
_Type = specs.Type
_OutputClass = specs.OutputClass

_Array = np.ndarray
_Array = np.ndarray | jax.Array
_Data = Union[_Array, List[_Array]]
_DataOrType = Union[_Data, str]

Expand Down Expand Up @@ -312,6 +312,94 @@ def strings_pred(T_pos: np.ndarray, P_pos: np.ndarray) -> np.ndarray:
return probe


@functools.partial(jnp.vectorize, signature='(n)->(n,n)')
def predecessor_pointers_to_permutation_matrix(
pointers: jnp.ndarray) -> jnp.ndarray:
"""Converts predecessor pointers to a permutation matrix.
This function assumes that the pointers represent a linear order of the nodes
(akin to a linked list), where each node points to its predecessor and the
first node points to itself. It returns a permutation matrix `P` that sorts
the nodes into the order implied by the pointers.
Example:
```
pointers = [2, 1, 1]
P = [[0, 1, 0],
[0, 0, 1],
[1, 0, 0]]
```
Args:
pointers: array of shape [N] containing pointers. The pointers are assumed
to describe a linear order such that `pointers[i]` is the predecessor
of node `i`.
Returns:
Permutation matrix `P` of shape [N, N]. Given node features `x` of shape
[N, F], `P @ x` returns sorted node features.
"""
# Find the index of the last node: it's the node that no other node points to.
nb_nodes = pointers.shape[-1]
pointers_one_hot = jax.nn.one_hot(pointers, nb_nodes)
last = pointers_one_hot.sum(-2).argmin()

# Initialize permutation matrix with zeros.
perm = jnp.zeros([nb_nodes, nb_nodes])

for i in range(nb_nodes - 1, -1, -1):
# perm[i, last] = 1
perm += (
jax.nn.one_hot(i, nb_nodes)[..., None] * jax.nn.one_hot(last, nb_nodes))
last = pointers[last]

return perm


@functools.partial(jnp.vectorize, signature='(n,n)->(n)')
def permutation_matrix_to_predecessor_pointers(
perm: jnp.ndarray) -> jnp.ndarray:
"""Converts a permutation matrix to predecessor pointers.
Given an [N, N] permutation matrix `P` that sorts a list of nodes, this
function returns predecessor pointers that encode the sorted order.
Example:
```
P = [[0, 1, 0],
[0, 0, 1],
[1, 0, 0]]
pointers = [2, 1, 1]
```
Args:
perm: permutation matrix of shape [N, N].
Returns:
An array of shape [N] containing predecessor pointers.
"""
nb_nodes = perm.shape[-1]

# Initialize pointers to zeros.
pointers = jnp.zeros([nb_nodes], dtype=int)

# idx[i] is the index of the node at position i in the sorted order
idx = perm.argmax(-1)

# pointers[idx[0]] = idx[0]
pointers += idx[0] * jax.nn.one_hot(idx[0], nb_nodes)

for i in range(1, nb_nodes):
# pointers[idx[i]] = idx[i - 1]
pointers += idx[i - 1] * jax.nn.one_hot(idx[i], nb_nodes)

# Ensure that the pointers are in the valid range even if the input is badly
# formatted. This has no effect for well-formatted input.
pointers = jnp.minimum(pointers, nb_nodes - 1)

return pointers


@functools.partial(jnp.vectorize, signature='(n)->(n,n),(n)')
def predecessor_to_cyclic_predecessor_and_first(
pointers: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
Expand Down

0 comments on commit 4b1c035

Please sign in to comment.