Skip to content

Commit

Permalink
allow ellipses in output-only einsum
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh committed Sep 15, 2024
1 parent 11689bd commit 30e3217
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 2 deletions.
16 changes: 14 additions & 2 deletions src/haliax/_src/einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,13 @@ def _output_only_named_einsum(equation, arrays, rhs, axis_aliases):
used_aliases = set()

input_axis_names = set(ax.name for ax in _all_input_axes(arrays))
has_ellipsis = False

for capture in rhs.captures:
if capture is Ellipsis:
raise_parse_error("Can't use ellipsis on the rhs of an einsum without an lhs", equation, None)
# raise_parse_error("Can't use ellipsis on the rhs of an einsum without an lhs", equation, None)
out_axes.append(Ellipsis)
has_ellipsis = True
elif capture.binding is None or len(capture.axes) > 1:
raise_parse_error(
"Parenthesized axes are not currently supported in the output of an einsum",
Expand All @@ -150,7 +153,6 @@ def _output_only_named_einsum(equation, arrays, rhs, axis_aliases):
)

name = ax_name
used_axes.add(name)

if name in out_axes:
raise_parse_error(
Expand All @@ -160,8 +162,18 @@ def _output_only_named_einsum(equation, arrays, rhs, axis_aliases):
if name not in input_axis_names:
raise_parse_error(f"Axis {name} not found in any of the input arrays", equation, capture.char_range)

used_axes.add(name)
out_axes.append(name)

# if there's an ellipsis, put all unused axes in the ellipsis
if has_ellipsis:
all_input_axes = _all_input_axes(arrays)
unmentioned = [ax.name for ax in all_input_axes if ax.name not in used_axes]
ellipsis_index = out_axes.index(Ellipsis)
out_axes = out_axes[:ellipsis_index] + unmentioned + out_axes[ellipsis_index + 1 :]

used_axes = set(out_axes)

_check_for_unused_aliases(axis_aliases, used_aliases, equation)

spec = _make_einsum_spec(arrays, out_axes)
Expand Down
6 changes: 6 additions & 0 deletions tests/test_einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,12 @@ def test_einsum_output_only_mode():
assert jnp.all(jnp.equal(einsum("-> Height Width", m1, m2).array, jnp.einsum("ijk,kji->ij", m1.array, m2.array)))
assert jnp.all(jnp.equal(einsum("-> Height", m1).array, jnp.einsum("ijk->i", m1.array)))

assert jnp.all(jnp.equal(einsum("-> ...", m1, m2).array, jnp.einsum("ijk,kji->ijk", m1.array, m2.array)))
assert jnp.all(jnp.equal(einsum("-> ... Width", m1, m2).array, jnp.einsum("ijk,kji->ikj", m1.array, m2.array)))
assert jnp.all(
jnp.equal(einsum("-> Depth ... Width", m1, m2).array, jnp.einsum("ijk,kji->kij", m1.array, m2.array))
)

with pytest.raises(ValueError):
einsum("-> Q Width", m1)

Expand Down

0 comments on commit 30e3217

Please sign in to comment.