Skip to content

Commit

Permalink
allow einsum to accept aliases similar to rearrange (#84)
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh authored Apr 14, 2024
1 parent dd92ee8 commit 5599b78
Show file tree
Hide file tree
Showing 5 changed files with 232 additions and 27 deletions.
37 changes: 36 additions & 1 deletion docs/matmul.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
Haliax has two ways to do matrix multiplication (and tensor contractions more generally):
[haliax.dot][] and [haliax.einsum][]. [haliax.dot][] and [haliax.einsum][]
can both express any tensor contraction, though in different situations one or the other may be
more suitable for expressing a particular contraction.
more suitable for expressing a particular contraction. In general:

- Use [haliax.dot][] when you want to express a simple matrix multiplication over one or a few axes.
- Use [haliax.einsum][] when you want to express a more complex tensor contraction.

See also the API reference for [haliax.dot][] and [haliax.einsum][] and the
[cheat sheet section](cheatsheet.md#matrix-multiplication).
Expand Down Expand Up @@ -68,6 +71,9 @@ Haliax's version of `einsum` comes in three modes: "ordered", "unordered", and "
These modes are all accessible through the same function without any flags: the syntax
of the `einsum` string determines which mode is used.

The syntax for Haliax's `einsum` is similar to [`haliax.rearrange`](rearrange.md), which
is in turn similar to [einops.rearrange](https://einops.rocks/api/rearrange/).

#### Ordered Mode

Haliax's `einsum` has an "ordered" mode that is similar to `einops.einsum`'s behavior.
Expand Down Expand Up @@ -119,6 +125,22 @@ y = hax.einsum("{H ...} -> ...", x) # shape is (W, D)

This mode is most similar to [haliax.dot][]'s behavior, though it's a bit more expressive.

You can also use axis aliases in the `einsum` string, which can be useful for expressing contractions
in library code or just for shortening the string:

```python
Height = hax.Axis("Height", 3)
Width = hax.Axis("Width", 4)
Depth = hax.Axis("Depth", 5)

x = hax.ones((Height, Width, Depth))
w = hax.ones((Depth,))

y = hax.einsum("{H W D} -> H W", x, H=Height, W=Width, D=Depth) # shape is (Height, Width)
y = hax.einsum("{D} -> ", w, D=Depth) # shape is () (a scalar): w's only axis, Depth, is contracted
```


#### Output Axes Mode

In "output axes" mode, you only specify the axes that should be in the output. All other
Expand All @@ -142,3 +164,16 @@ y = hax.einsum("-> D", w) # shape is (D,)

We don't recommend using this mode except in cases when you're sure of the full shape of the input arrays
or you are sure you don't want to let users implicitly batch over any axes.

Output axes mode also supports axis aliases:

```python
Height = hax.Axis("Height", 3)
Width = hax.Axis("Width", 4)
Depth = hax.Axis("Depth", 5)

x = hax.ones((Height, Width, Depth))
w = hax.ones((Depth,))
# Every alias passed as a keyword must appear in the equation; unused aliases are rejected.
y = hax.einsum("-> Height Width", x, Height=Height, Width=Width) # shape is (Height, Width)
y = hax.einsum("-> Depth", w, Depth=Depth) # shape is (Depth,)
```
126 changes: 102 additions & 24 deletions src/haliax/_src/einsum.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import functools
from types import EllipsisType
from typing import Optional, Sequence, Tuple, Union
from typing import Optional, Tuple

import jax.lax
import jax.numpy as jnp

import haliax

from ..axis import AxisSelector, axis_name, eliminate_axes, rearrange_for_partial_order, union_axes
from ..axis import Axis, AxisSelector, axis_name, eliminate_axes, rearrange_for_partial_order, union_axes
from ..core import NamedArray
from ..jax_utils import _jittable_dg_einsum
from ..types import DTypeLike, PrecisionLike
Expand All @@ -21,6 +20,7 @@ def einsum(
precision: PrecisionLike = None,
preferred_element_type: Optional[DTypeLike] = None,
_dot_general=jax.lax.dot_general,
**axis_aliases,
) -> NamedArray:
"""Compute the tensor contraction of the input arrays according to Haliax's named variant of the Einstein summation
convention.
Expand All @@ -38,12 +38,16 @@ def einsum(
>>> hax.einsum("{H W D} -> H W", a, b)
>>> hax.einsum("{D} -> ", a, b) # same as the previous example
>>> hax.einsum("-> H W", a, b) # same as the first example
>>> # axis aliases, useful for generic code
>>> hax.einsum("{x y} -> y", a, b, x=H, y=W)
Args:
equation: The einsum equation.
arrays: The input arrays.
precision: The precision of the computation.
preferred_element_type: The preferred element type of the computation.
_dot_general: The dot_general function to use.
axis_aliases: The axis aliases to use.
Returns:
The result of the einsum.
Expand All @@ -59,17 +63,17 @@ def einsum(
# NB: we're using JAX's einsum which only supports one letter names for dims
if len(lhses) == 1 and len(lhses[0].captures) == 0 and lhses[0].is_ordered:
# case 3: get the output axes, contract the others
spec, out_axes = _output_only_named_einsum(equation, arrays, rhs)
spec, out_axes = _output_only_named_einsum(equation, arrays, rhs, axis_aliases)
elif len(lhses) == 1 and not lhses[0].is_ordered:
# case 2: some axes are named. Those named only on the lhs are contracted, the others are kept
# subcase: if there's an ellipsis on the lhs, we contract all the axes that are not named on the rhs
spec, out_axes = _unordered_einsum(arrays, equation, lhses, rhs)
spec, out_axes = _unordered_einsum(arrays, equation, lhses[0], rhs, axis_aliases)
else:
# general case: we have a normal einsum. we don't allow unordered axes here
if any(not lhs.is_ordered for lhs in lhses):
raise_parse_error("Cannot have multiple unordered axes in an einsum", equation, None)

spec, out_axes = _positional_einsum_spec(equation, arrays, lhses, rhs)
spec, out_axes = _positional_einsum_spec(equation, arrays, lhses, rhs, axis_aliases)

out_raw = _jittable_dg_einsum(
spec,
Expand All @@ -83,12 +87,14 @@ def einsum(
return haliax.auto_sharded(out)


def _unordered_einsum(arrays, equation, lhses, rhs):
used_letters: set[str] = set()
name_mappings_for_einsum: dict[str, str] = {}
lhs = lhses[0]
candidate_axes, has_ellipsis_lhs = _captures_to_axis_names(equation, lhs)
rhs_axes, has_ellipsis_rhs = _captures_to_axis_names(equation, rhs)
def _unordered_einsum(arrays, equation, lhs, rhs, axis_aliases):
candidate_axes, has_ellipsis_lhs, covered_lhs = _captures_to_axis_names(equation, lhs, axis_aliases)
rhs_axes, has_ellipsis_rhs, covered_rhs = _captures_to_axis_names(equation, rhs, axis_aliases)

for alias_name in axis_aliases:
if alias_name not in covered_lhs and alias_name not in covered_rhs:
raise_parse_error(f"Axis alias {alias_name} not used in the einsum", equation, None)

all_input_axes = _all_input_axes(arrays)
if has_ellipsis_rhs:
out_axes = rearrange_for_partial_order(rhs_axes, all_input_axes)
Expand All @@ -105,15 +111,17 @@ def _unordered_einsum(arrays, equation, lhses, rhs):
# what people expect
rhs_axes = [Ellipsis] + rhs_axes # type: ignore
out_axes = rearrange_for_partial_order(rhs_axes, almost_out_axes)
spec = _make_einsum_spec(name_mappings_for_einsum, used_letters, arrays, out_axes)
spec = _make_einsum_spec(arrays, out_axes)
return spec, out_axes


def _output_only_named_einsum(equation, arrays, rhs):
used_letters: set[str] = set()
name_mappings_for_einsum: dict[str, str] = {}

def _output_only_named_einsum(equation, arrays, rhs, axis_aliases):
out_axes = []
used_axes = set()
used_aliases = set()

input_axis_names = set(ax.name for ax in _all_input_axes(arrays))

for capture in rhs.captures:
if capture is Ellipsis:
raise_parse_error("Can't use ellipsis on the rhs of an einsum without an lhs", equation, None)
Expand All @@ -125,23 +133,51 @@ def _output_only_named_einsum(equation, arrays, rhs):
)
else:
name = capture.binding
used_aliases.add(name)

if name in axis_aliases:
# this could be axis or a name. if an axis, need to assert the size
axis = axis_aliases[name]
if isinstance(axis, Axis):
_check_axis_size_consistency(arrays, axis, name)
ax_name = axis_name(axis)

if ax_name in used_axes:
raise_parse_error(
f"Axis {name} occurs multiple times on the rhs. Probably because of multiple aliasing?",
equation,
capture.char_range,
)

name = ax_name
used_axes.add(name)

if name in out_axes:
raise_parse_error(f"Axis name {name} occurs multiple times on the rhs", equation, capture.char_range)
raise_parse_error(
f"Axis capture {name} occurs multiple times on the rhs", equation, capture.char_range
)

if name not in input_axis_names:
raise_parse_error(f"Axis {name} not found in any of the input arrays", equation, capture.char_range)

out_axes.append(name)

spec = _make_einsum_spec(name_mappings_for_einsum, used_letters, arrays, out_axes)
_check_for_unused_aliases(axis_aliases, used_aliases, equation)

spec = _make_einsum_spec(arrays, out_axes)
return spec, out_axes


def _positional_einsum_spec(equation, arrays, lhses, rhs):
def _positional_einsum_spec(equation, arrays, lhses, rhs, axis_aliases):
used_letters: set[str] = set()
name_mappings_for_einsum: dict[str, str] = {}
used_aliases = set()

if len(lhses) != len(arrays):
raise ValueError(f"Number of lhses ({len(lhses)}) does not match number of arrays ({len(arrays)})")
table = AliasTable()

# For this function, axis_aliases exists entirely for checking axis sizes against what's in the arrays
table = AliasTable(axis_aliases)
# ok, we're going to lean pretty heavily on einsum here. We just need to figure out the names of the axes
# and do any error checking (that there are no mismatched names)
# once we do that, we can pass a slightly modified spec to einsum (namely that we shorten the names of the axes)
Expand All @@ -164,6 +200,9 @@ def _positional_einsum_spec(equation, arrays, lhses, rhs):
raise_parse_error("Parenthesized axes are not currently supported", equation, capture.char_range)
else:
name = capture.binding
if name in axis_aliases:
used_aliases.add(name)

if axis_off >= len(a.axes):
raise ValueError("Mismatched number of axes in einsum")
table.bind_alias(name, a.axes[axis_off], equation, capture.char_range)
Expand All @@ -184,6 +223,9 @@ def _positional_einsum_spec(equation, arrays, lhses, rhs):
break
else:
name = capture.binding
if name in axis_aliases:
used_aliases.add(name)

if axis_off < final_lhs_axis_off:
raise ValueError("Mismatched number of axes in einsum")
table.bind_alias(name, a.axes[axis_off], equation, capture.char_range)
Expand Down Expand Up @@ -230,6 +272,8 @@ def _positional_einsum_spec(equation, arrays, lhses, rhs):
spec += letter
out_axes.append(axis)

_check_for_unused_aliases(axis_aliases, used_aliases, equation)

if has_ellipsis_rhs:
all_input_axes = _all_input_axes(arrays)
# eliminate the axes that are contracted
Expand All @@ -244,7 +288,8 @@ def _all_input_axes(arrays):
return ensure_tuple(functools.reduce(union_axes, (a.axes for a in arrays), ())) # type: ignore


def _captures_to_axis_names(equation, lhs) -> Tuple[list[str | EllipsisType], bool]:
def _captures_to_axis_names(equation, lhs, aliases) -> Tuple[list[str | EllipsisType], bool, set[str]]:
covered_aliases = set()
candidate_axes: list[str | EllipsisType] = []
has_ellipsis = False
for capture in lhs.captures:
Expand All @@ -255,11 +300,17 @@ def _captures_to_axis_names(equation, lhs) -> Tuple[list[str | EllipsisType], bo
raise_parse_error("Parenthesized axes are not currently supported", equation, capture.char_range)
else:
name = capture.binding
if name in aliases:
covered_aliases.add(name)
axis = aliases[name]
name = axis_name(axis)
candidate_axes.append(name)
return candidate_axes, has_ellipsis
return candidate_axes, has_ellipsis, covered_aliases


def _make_einsum_spec(name_mappings_for_einsum, used_letters, arrays, out_axes):
def _make_einsum_spec(arrays, out_axes):
name_mappings_for_einsum: dict[str, str] = {}
used_letters: set[str] = set()
spec = ""
for operand in arrays:
if len(spec):
Expand Down Expand Up @@ -289,3 +340,30 @@ def _assign_letter_to_name(name, name_mappings_for_einsum, used_letters):
name_mappings_for_einsum[name] = letter
used_letters.add(letter)
return letter


def _check_axis_size_consistency(arrays, axis, name_in_spec):
    """Check that an aliased axis appears in the inputs with the expected size.

    Args:
        arrays: The input arrays. Each one is queried via ``resolve_axis``.
        axis: The axis supplied for the alias. Its name (via ``axis_name``) is
            looked up in every array and its ``size`` is compared against the
            size of the axis found there.
        name_in_spec: The alias as written in the einsum equation. Only used
            to build the error message.

    Raises:
        ValueError: If some array has an axis of that name with a different
            size, or if no array has an axis of that name at all.
    """
    found = False
    ax_name = axis_name(axis)
    for array_index, array in enumerate(arrays):
        try:
            resolved = array.resolve_axis(ax_name)
        except ValueError:
            # resolve_axis signals a missing axis with ValueError; it is fine
            # for an individual array not to have it, so keep scanning.
            pass
        else:
            found = True
            if resolved.size != axis.size:
                # NB: the pieces below are adjacent f-strings, so each one must
                # carry its own leading space to keep the message readable.
                raise ValueError(
                    f"Size mismatch for axis {ax_name}. In array {array_index},"
                    f" {axis} has size {resolved.size} but expected {axis.size},"
                    f" because of the alias {name_in_spec}={axis}"
                )
    if not found:
        raise ValueError(f"Axis {ax_name} not found in any of the input arrays")


def _check_for_unused_aliases(axis_aliases, used_aliases, equation):
if any(alias not in used_aliases for alias in axis_aliases):
unused_aliases_str = ", ".join([alias for alias in axis_aliases if alias not in used_aliases])
raise_parse_error(f"Unused aliases from kwargs: {unused_aliases_str}", equation, None)
14 changes: 12 additions & 2 deletions src/haliax/_src/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,9 @@ class AliasTable:

def __init__(self, bindings=None):
if bindings is None:
bindings = {}
self.bindings = bindings
self.bindings = {}
else:
self.bindings = {**bindings}

def dealias_binding(self, binding: str) -> Optional[AxisSelector]:
return self.bindings.get(binding, None)
Expand All @@ -235,6 +236,15 @@ def bind_alias(self, alias: str, axis: Axis, expr, char_range):
if axis.name in self.bindings:
if self.bindings[alias] != axis:
raise_parse_error(f"Alias {alias} is assigned to more than one axis", expr, char_range)
elif alias in self.bindings:
current = self.bindings[alias]
if isinstance(current, Axis):
if current != axis:
raise_parse_error(f"Alias {alias} is assigned to more than one axis", expr, char_range)
elif current != axis.name:
raise_parse_error(f"Alias {alias} is assigned to more than one axis", expr, char_range)
else:
self.bindings[alias] = axis
else:
self.bindings[alias] = axis

Expand Down
2 changes: 2 additions & 0 deletions src/haliax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,8 @@ def resolve_axis(self, axes: AxisSelection) -> AxisSpec: # type: ignore
"""
Returns the axes corresponding to the given axis selection.
That is, it returns the [haliax.Axis][] values themselves, not just their names.
Raises a ValueError if any of the axes are not found.
"""
indices = self._lookup_indices(axes)
if isinstance(indices, int):
Expand Down
Loading

0 comments on commit 5599b78

Please sign in to comment.