From 30e3217fce7284c7d3392dcfae8723c6f2d88ef2 Mon Sep 17 00:00:00 2001 From: David Hall Date: Sat, 14 Sep 2024 22:39:14 -0700 Subject: [PATCH] allow ellipses in output-only einsum --- src/haliax/_src/einsum.py | 16 ++++++++++++++-- tests/test_einsum.py | 6 ++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/src/haliax/_src/einsum.py b/src/haliax/_src/einsum.py index 6a65f44..326b1d7 100644 --- a/src/haliax/_src/einsum.py +++ b/src/haliax/_src/einsum.py @@ -121,10 +121,13 @@ def _output_only_named_einsum(equation, arrays, rhs, axis_aliases): used_aliases = set() input_axis_names = set(ax.name for ax in _all_input_axes(arrays)) + has_ellipsis = False for capture in rhs.captures: if capture is Ellipsis: - raise_parse_error("Can't use ellipsis on the rhs of an einsum without an lhs", equation, None) + # raise_parse_error("Can't use ellipsis on the rhs of an einsum without an lhs", equation, None) + out_axes.append(Ellipsis) + has_ellipsis = True elif capture.binding is None or len(capture.axes) > 1: raise_parse_error( "Parenthesized axes are not currently supported in the output of an einsum", @@ -150,7 +153,6 @@ def _output_only_named_einsum(equation, arrays, rhs, axis_aliases): ) name = ax_name - used_axes.add(name) if name in out_axes: raise_parse_error( @@ -160,8 +162,18 @@ def _output_only_named_einsum(equation, arrays, rhs, axis_aliases): if name not in input_axis_names: raise_parse_error(f"Axis {name} not found in any of the input arrays", equation, capture.char_range) + used_axes.add(name) out_axes.append(name) + # if there's an ellipsis, put all unused axes in the ellipsis + if has_ellipsis: + all_input_axes = _all_input_axes(arrays) + unmentioned = [ax.name for ax in all_input_axes if ax.name not in used_axes] + ellipsis_index = out_axes.index(Ellipsis) + out_axes = out_axes[:ellipsis_index] + unmentioned + out_axes[ellipsis_index + 1 :] + + used_axes = set(out_axes) + _check_for_unused_aliases(axis_aliases, used_aliases, equation) spec = _make_einsum_spec(arrays, out_axes) diff --git a/tests/test_einsum.py b/tests/test_einsum.py index ca909bb..6790fee 100644 --- a/tests/test_einsum.py +++ b/tests/test_einsum.py @@ -330,6 +330,12 @@ def test_einsum_output_only_mode(): assert jnp.all(jnp.equal(einsum("-> Height Width", m1, m2).array, jnp.einsum("ijk,kji->ij", m1.array, m2.array))) assert jnp.all(jnp.equal(einsum("-> Height", m1).array, jnp.einsum("ijk->i", m1.array))) + assert jnp.all(jnp.equal(einsum("-> ...", m1, m2).array, jnp.einsum("ijk,kji->ijk", m1.array, m2.array))) + assert jnp.all(jnp.equal(einsum("-> ... Width", m1, m2).array, jnp.einsum("ijk,kji->ikj", m1.array, m2.array))) + assert jnp.all( + jnp.equal(einsum("-> Depth ... Width", m1, m2).array, jnp.einsum("ijk,kji->kij", m1.array, m2.array)) + ) + with pytest.raises(ValueError): einsum("-> Q Width", m1)