diff --git a/docs/matmul.md b/docs/matmul.md index 6f7303e..531e326 100644 --- a/docs/matmul.md +++ b/docs/matmul.md @@ -3,7 +3,10 @@ Haliax has two ways to do matrix multiplication (and tensor contractions more generally): [haliax.dot][] and [haliax.einsum][]. [haliax.dot][] and [haliax.einsum][] can both express any tensor contraction, though in different situations one or the other may be -more suitable for expressing a particular contraction. +more suitable for expressing a particular contraction In general: + +- Use [haliax.dot][] when you want to express a simple matrix multiplication over one or a few axes. +- Use [haliax.einsum][] when you want to express a more complex tensor contraction. See also the API reference for [haliax.dot][] and [haliax.einsum][] and the [cheat sheet section](cheatsheet.md#matrix-multiplication). @@ -68,6 +71,9 @@ Haliax's version of `einsum` comes in three modes: "ordered", "unordered", and " These modes are all accessible through the same function without any flags: the syntax of the `einsum` string determines which mode is used. +The syntax for Haliax's `einsum` is similar to [`haliax.rearrange`](rearrange.md), which +is in turn similar to [einops.rearrange](https://einops.rocks/api/rearrange/). + #### Ordered Mode Haliax's `einsum` has an "ordered" mode that is similar to `einops.einsum`'s behavior. @@ -119,6 +125,22 @@ y = hax.einsum("{H ...} -> ...", x) # shape is (W, D) This mode is most similar to [haliax.dot][]'s behavior, though it's a bit more expressive. +You can also use axis aliases in the `einsum` string, which can be useful for expressing contractions +in library code or just for shortening the string: + +```python +Height = hax.Axis("Height", 3) +Width = hax.Axis("Width", 4) +Depth = hax.Axis("Depth", 5) + +x = hax.ones((Height, Width, Depth)) +w = hax.ones((Depth,)) + +y = hax.einsum("{H W D} -> H W", x, H=Height, W=Width, D=Depth) # shape is (Height, Width) +y = hax.einsum("{D} -> ", w, D=Depth) # shape is (Height, Width) +``` + + #### Output Axes Mode In "output axes" mode, you only specify the axes that should be in the output. All other @@ -142,3 +164,16 @@ y = hax.einsum("-> D", w) # shape is (D,) We don't recommend using this mode except in cases when you're sure of the full shape of the input arrays or you are sure you don't want to let users implicitly batch over any axes. + +Output axes mode also supports axis aliases: + +```python +Height = hax.Axis("Height", 3) +Width = hax.Axis("Width", 4) +Depth = hax.Axis("Depth", 5) + +x = hax.ones((Height, Width, Depth)) +w = hax.ones((Depth,)) +y = hax.einsum("-> Height Width", x, Height=Height, Width=Width, Depth=Depth) # shape is (Height, Width) +y = hax.einsum("-> Depth", w, Depth=Depth) # shape is (Depth,) +``` diff --git a/src/haliax/_src/einsum.py b/src/haliax/_src/einsum.py index 4889d10..6a65f44 100644 --- a/src/haliax/_src/einsum.py +++ b/src/haliax/_src/einsum.py @@ -1,13 +1,12 @@ import functools from types import EllipsisType -from typing import Optional, Sequence, Tuple, Union +from typing import Optional, Tuple import jax.lax -import jax.numpy as jnp import haliax -from ..axis import AxisSelector, axis_name, eliminate_axes, rearrange_for_partial_order, union_axes +from ..axis import Axis, AxisSelector, axis_name, eliminate_axes, rearrange_for_partial_order, union_axes from ..core import NamedArray from ..jax_utils import _jittable_dg_einsum from ..types import DTypeLike, PrecisionLike @@ -21,6 +20,7 @@ def einsum( precision: PrecisionLike = None, preferred_element_type: Optional[DTypeLike] = None, _dot_general=jax.lax.dot_general, + **axis_aliases, ) -> NamedArray: """Compute the tensor contraction of the input arrays according to Haliax's named variant of the Einstein summation convention. @@ -38,12 +38,16 @@ def einsum( >>> hax.einsum("{H W D} -> H W", a, b) >>> hax.einsum("{D} -> ", a, b) # same as the previous example >>> hax.einsum("-> H W", a, b) # same as the first example + >>> # axis aliases, useful for generic code + >>> hax.einsum("{x y} -> y", a, b, x=H, y=W) Args: equation: The einsum equation. arrays: The input arrays. precision: The precision of the computation. preferred_element_type: The preferred element type of the computation. + _dot_general: The dot_general function to use. + axis_aliases: The axis aliases to use. Returns: The result of the einsum. @@ -59,17 +63,17 @@ def einsum( # NB: we're using JAX's einsum which only supports one letter names for dims if len(lhses) == 1 and len(lhses[0].captures) == 0 and lhses[0].is_ordered: # case 3: get the output axes, contract the others - spec, out_axes = _output_only_named_einsum(equation, arrays, rhs) + spec, out_axes = _output_only_named_einsum(equation, arrays, rhs, axis_aliases) elif len(lhses) == 1 and not lhses[0].is_ordered: # case 2: some axes are named. Those named only on the lhs are contracted, the others are kept # subcase: if there's an ellipsis on the lhs, we contract all the axes that are not named on the rhs - spec, out_axes = _unordered_einsum(arrays, equation, lhses, rhs) + spec, out_axes = _unordered_einsum(arrays, equation, lhses[0], rhs, axis_aliases) else: # general case: we have a normal einsum. we don't allow unordered axes here if any(not lhs.is_ordered for lhs in lhses): raise_parse_error("Cannot have multiple unordered axes in an einsum", equation, None) - spec, out_axes = _positional_einsum_spec(equation, arrays, lhses, rhs) + spec, out_axes = _positional_einsum_spec(equation, arrays, lhses, rhs, axis_aliases) out_raw = _jittable_dg_einsum( spec, @@ -83,12 +87,14 @@ def einsum( return haliax.auto_sharded(out) -def _unordered_einsum(arrays, equation, lhses, rhs): - used_letters: set[str] = set() - name_mappings_for_einsum: dict[str, str] = {} - lhs = lhses[0] - candidate_axes, has_ellipsis_lhs = _captures_to_axis_names(equation, lhs) - rhs_axes, has_ellipsis_rhs = _captures_to_axis_names(equation, rhs) +def _unordered_einsum(arrays, equation, lhs, rhs, axis_aliases): + candidate_axes, has_ellipsis_lhs, covered_lhs = _captures_to_axis_names(equation, lhs, axis_aliases) + rhs_axes, has_ellipsis_rhs, covered_rhs = _captures_to_axis_names(equation, rhs, axis_aliases) + + for alias_name in axis_aliases: + if alias_name not in covered_lhs and alias_name not in covered_rhs: + raise_parse_error(f"Axis alias {alias_name} not used in the einsum", equation, None) + all_input_axes = _all_input_axes(arrays) if has_ellipsis_rhs: out_axes = rearrange_for_partial_order(rhs_axes, all_input_axes) @@ -105,15 +111,17 @@ def _unordered_einsum(arrays, equation, lhses, rhs): # what people expect rhs_axes = [Ellipsis] + rhs_axes # type: ignore out_axes = rearrange_for_partial_order(rhs_axes, almost_out_axes) - spec = _make_einsum_spec(name_mappings_for_einsum, used_letters, arrays, out_axes) + spec = _make_einsum_spec(arrays, out_axes) return spec, out_axes -def _output_only_named_einsum(equation, arrays, rhs): - used_letters: set[str] = set() - name_mappings_for_einsum: dict[str, str] = {} - +def _output_only_named_einsum(equation, arrays, rhs, axis_aliases): out_axes = [] + used_axes = set() + used_aliases = set() + + input_axis_names = set(ax.name for ax in _all_input_axes(arrays)) + for capture in rhs.captures: if capture is Ellipsis: raise_parse_error("Can't use ellipsis on the rhs of an einsum without an lhs", equation, None) @@ -125,23 +133,51 @@ def _output_only_named_einsum(equation, arrays, rhs): ) else: name = capture.binding + used_aliases.add(name) + + if name in axis_aliases: + # this could be axis or a name. if an axis, need to assert the size + axis = axis_aliases[name] + if isinstance(axis, Axis): + _check_axis_size_consistency(arrays, axis, name) + ax_name = axis_name(axis) + + if ax_name in used_axes: + raise_parse_error( + f"Axis {name} occurs multiple times on the rhs. Probably because of multiple aliasing?", + equation, + capture.char_range, + ) + + name = ax_name + used_axes.add(name) if name in out_axes: - raise_parse_error(f"Axis name {name} occurs multiple times on the rhs", equation, capture.char_range) + raise_parse_error( + f"Axis capture {name} occurs multiple times on the rhs", equation, capture.char_range + ) + + if name not in input_axis_names: + raise_parse_error(f"Axis {name} not found in any of the input arrays", equation, capture.char_range) out_axes.append(name) - spec = _make_einsum_spec(name_mappings_for_einsum, used_letters, arrays, out_axes) + _check_for_unused_aliases(axis_aliases, used_aliases, equation) + + spec = _make_einsum_spec(arrays, out_axes) return spec, out_axes -def _positional_einsum_spec(equation, arrays, lhses, rhs): +def _positional_einsum_spec(equation, arrays, lhses, rhs, axis_aliases): used_letters: set[str] = set() name_mappings_for_einsum: dict[str, str] = {} + used_aliases = set() if len(lhses) != len(arrays): raise ValueError(f"Number of lhses ({len(lhses)}) does not match number of arrays ({len(arrays)})") - table = AliasTable() + + # For this function, axis_aliases exists entirely for checking axis sizes against what's in the arrays + table = AliasTable(axis_aliases) # ok, we're going to lead pretty heavily on einsum here. We just need to figure out the names of the axes # and do any error checking (that there are no mismatched names) # once we do that, we can pass a slightly modified spec to einsum (namely that we shorten the names of the axes) @@ -164,6 +200,9 @@ def _positional_einsum_spec(equation, arrays, lhses, rhs): raise_parse_error("Parenthesized axes are not currently supported", equation, capture.char_range) else: name = capture.binding + if name in axis_aliases: + used_aliases.add(name) + if axis_off >= len(a.axes): raise ValueError("Mismatched number of axes in einsum") table.bind_alias(name, a.axes[axis_off], equation, capture.char_range) @@ -184,6 +223,9 @@ def _positional_einsum_spec(equation, arrays, lhses, rhs): break else: name = capture.binding + if name in axis_aliases: + used_aliases.add(name) + if axis_off < final_lhs_axis_off: raise ValueError("Mismatched number of axes in einsum") table.bind_alias(name, a.axes[axis_off], equation, capture.char_range) @@ -230,6 +272,8 @@ def _positional_einsum_spec(equation, arrays, lhses, rhs): spec += letter out_axes.append(axis) + _check_for_unused_aliases(axis_aliases, used_aliases, equation) + if has_ellipsis_rhs: all_input_axes = _all_input_axes(arrays) # eliminate the axes that are contracted @@ -244,7 +288,8 @@ def _all_input_axes(arrays): return ensure_tuple(functools.reduce(union_axes, (a.axes for a in arrays), ())) # type: ignore -def _captures_to_axis_names(equation, lhs) -> Tuple[list[str | EllipsisType], bool]: +def _captures_to_axis_names(equation, lhs, aliases) -> Tuple[list[str | EllipsisType], bool, set[str]]: + covered_aliases = set() candidate_axes: list[str | EllipsisType] = [] has_ellipsis = False for capture in lhs.captures: @@ -255,11 +300,17 @@ def _captures_to_axis_names(equation, lhs) -> Tuple[list[str | EllipsisType], bo raise_parse_error("Parenthesized axes are not currently supported", equation, capture.char_range) else: name = capture.binding + if name in aliases: + covered_aliases.add(name) + axis = aliases[name] + name = axis_name(axis) candidate_axes.append(name) - return candidate_axes, has_ellipsis + return candidate_axes, has_ellipsis, covered_aliases -def _make_einsum_spec(name_mappings_for_einsum, used_letters, arrays, out_axes): +def _make_einsum_spec(arrays, out_axes): + name_mappings_for_einsum: dict[str, str] = {} + used_letters: set[str] = set() spec = "" for operand in arrays: if len(spec): @@ -289,3 +340,30 @@ def _assign_letter_to_name(name, name_mappings_for_einsum, used_letters): name_mappings_for_einsum[name] = letter used_letters.add(letter) return letter + + +def _check_axis_size_consistency(arrays, axis, name_in_spec): + # ensure the size is correct and the axis is present + found = False + ax_name = axis_name(axis) + for array_index, array in enumerate(arrays): + try: + resolved = array.resolve_axis(ax_name) + except ValueError: + pass + else: + found = True + if resolved.size != axis.size: + raise ValueError( + f"Size mismatch for axis {ax_name}. In array {array_index}," + f" {axis} has size {resolved.size} but expected {axis.size}," + f"because of the alias {name_in_spec}={axis}" + ) + if not found: + raise ValueError(f"Axis {ax_name} not found in any of the input arrays") + + +def _check_for_unused_aliases(axis_aliases, used_aliases, equation): + if any(alias not in used_aliases for alias in axis_aliases): + unused_aliases_str = ", ".join([alias for alias in axis_aliases if alias not in used_aliases]) + raise_parse_error(f"Unused aliases from kwargs: {unused_aliases_str}", equation, None) diff --git a/src/haliax/_src/parsing.py b/src/haliax/_src/parsing.py index 98c0267..4bce2d3 100644 --- a/src/haliax/_src/parsing.py +++ b/src/haliax/_src/parsing.py @@ -225,8 +225,9 @@ class AliasTable: def __init__(self, bindings=None): if bindings is None: - bindings = {} - self.bindings = bindings + self.bindings = {} + else: + self.bindings = {**bindings} def dealias_binding(self, binding: str) -> Optional[AxisSelector]: return self.bindings.get(binding, None) @@ -235,6 +236,15 @@ def bind_alias(self, alias: str, axis: Axis, expr, char_range): if axis.name in self.bindings: if self.bindings[alias] != axis: raise_parse_error(f"Alias {alias} is assigned to more than one axis", expr, char_range) + elif alias in self.bindings: + current = self.bindings[alias] + if isinstance(current, Axis): + if current != axis: + raise_parse_error(f"Alias {alias} is assigned to more than one axis", expr, char_range) + elif current != axis.name: + raise_parse_error(f"Alias {alias} is assigned to more than one axis", expr, char_range) + else: + self.bindings[alias] = axis else: self.bindings[alias] = axis diff --git a/src/haliax/core.py b/src/haliax/core.py index c08bba6..c61c572 100644 --- a/src/haliax/core.py +++ b/src/haliax/core.py @@ -184,6 +184,8 @@ def resolve_axis(self, axes: AxisSelection) -> AxisSpec: # type: ignore """ Returns the axes corresponding to the given axis selection. That is, it return the [haliax.Axis][] values themselves, not just their names. + + Raises a ValueError if any of the axes are not found. """ indices = self._lookup_indices(axes) if isinstance(indices, int): diff --git a/tests/test_einsum.py b/tests/test_einsum.py index e73031f..992bf6c 100644 --- a/tests/test_einsum.py +++ b/tests/test_einsum.py @@ -37,6 +37,28 @@ def test_einsum_basic_positional(): ) +def test_einsum_positional_aliases(): + Height = Axis("Height", 2) + Width = Axis("Width", 3) + Depth = Axis("Depth", 4) + + m1 = NamedArray(jnp.ones((Height.size, Width.size, Depth.size)), (Height, Width, Depth)) + m2 = NamedArray(jnp.ones((Depth.size, Width.size, Height.size)), (Depth, Width, Height)) + + assert jnp.all( + jnp.equal(einsum("i j k,k j i-> j k", m1, m2, i=Height).array, jnp.einsum("ijk,kji->jk", m1.array, m2.array)) + ) + + with pytest.raises(ValueError): + einsum("i j k,k j i-> j k", m1, m2, i=Width) + + with pytest.raises(ValueError): + einsum("i j k,q j i-> j k", m1, m2, i=Height, q=Height) + + with pytest.raises(ValueError): + einsum("i j k,k j i-> j k", m1, m2, i=Height, q=Height) + + def test_einsum_basic_named(): Height = Axis("Height", 2) Width = Axis("Width", 3) @@ -148,6 +170,32 @@ def test_einsum_unordered_ellipses(): ) +def test_einsum_unordered_aliases(): + Height = Axis("Height", 2) + Width = Axis("Width", 3) + Depth = Axis("Depth", 4) + + m1 = hax.ones((Height, Width, Depth)) + m2 = hax.ones((Depth, Width, Height)) + + assert jnp.all( + jnp.equal( + einsum("{h w d} -> h w", m1, m2, h=Height, w=Width, d=Depth).array, + jnp.einsum("ijk,kji->ij", m1.array, m2.array), + ) + ) + + # test error cases: + + # Missing alias + with pytest.raises(ValueError, match="Axis d not present"): + einsum("{h w d} -> h w", m1, m2, h=Height, w=Width) + + # Extra alias + with pytest.raises(ValueError, match="Axis alias d not used"): + einsum("{h w} -> h w", m1, m2, h=Height, w=Width, d=Depth) + + def test_einsum_ordered_ellipsis(): Height = Axis("Height", 2) Width = Axis("Width", 3) @@ -272,3 +320,35 @@ def test_einsum_examples(): hax_out = hax.einsum("{...} -> ", hax_im, hax_w2) jnp_out = jnp.einsum("bhwc,ce -> ", im, w2) assert jnp.all(jnp.equal(hax_out.array, jnp_out)) + + +def test_einsum_output_only_mode(): + # tests "-> out axes" + Height = Axis("Height", 2) + Width = Axis("Width", 3) + Depth = Axis("Depth", 4) + + m1 = hax.ones((Height, Width, Depth)) + m2 = hax.ones((Depth, Width, Height)) + m3 = hax.ones((Height, Depth)) + + assert jnp.all(jnp.equal(einsum("-> Height Width", m1, m2).array, jnp.einsum("ijk,kji->ij", m1.array, m2.array))) + assert jnp.all(jnp.equal(einsum("-> Height", m1).array, jnp.einsum("ijk->i", m1.array))) + + with pytest.raises(ValueError): + einsum("-> Q Width", m1) + + with pytest.raises(ValueError, match=".*Unused aliases from kwargs: Q$"): + einsum("-> Height Width", m1, m2, Q=Axis("Q", 2)) + + assert jnp.all(jnp.equal(einsum("-> h w", m1, h=Height, w=Width).array, jnp.einsum("ijk->ij", m1.array))) + + assert jnp.all( + jnp.equal(einsum("-> h w", m1, m3, h=Height, w=Width).array, jnp.einsum("ijk,ik->ij", m1.array, m3.array)) + ) + + with pytest.raises(ValueError, match=".*Size mismatch.*"): + einsum("-> h w", m1, h=Height.resize(4), w=Width) + + with pytest.raises(ValueError, match=".*not found in any of the input arrays.*"): + einsum("-> h w", m3, h=Height, w=Width.resize(4))