diff --git a/docs/api.md b/docs/api.md index 743a97a..2f47106 100644 --- a/docs/api.md +++ b/docs/api.md @@ -31,10 +31,11 @@ Occasionally, an axis size can be inferred in some circumstances but not others. ::: haliax.axis.axis_name ::: haliax.axis.concat_axes ::: haliax.axis.union_axes +::: haliax.axis.intersect_axes ::: haliax.axis.eliminate_axes ::: haliax.axis.without_axes -::: haliax.axis.overlapping_axes ::: haliax.axis.selects_axis +::: haliax.axis.overlapping_axes ::: haliax.axis.is_axis_compatible diff --git a/src/haliax/axis.py b/src/haliax/axis.py index 49463c9..1268659 100644 --- a/src/haliax/axis.py +++ b/src/haliax/axis.py @@ -297,21 +297,21 @@ def replace_axis(axis_spec: AxisSelection, old: AxisSelector, new: AxisSelection @overload -def overlapping_axes(ax1: AxisSpec, ax2: AxisSelection) -> Tuple[Axis, ...]: +def intersect_axes(ax1: AxisSpec, ax2: AxisSelection) -> Tuple[Axis, ...]: ... @overload -def overlapping_axes(ax1: AxisSelection, ax2: AxisSpec) -> Tuple[Axis, ...]: +def intersect_axes(ax1: AxisSelection, ax2: AxisSpec) -> Tuple[Axis, ...]: ... @overload -def overlapping_axes(ax1: AxisSelection, ax2: AxisSelection) -> Tuple[AxisSelector, ...]: +def intersect_axes(ax1: AxisSelection, ax2: AxisSelection) -> Tuple[AxisSelector, ...]: ... -def overlapping_axes(ax1: AxisSelection, ax2: AxisSelection) -> Tuple[AxisSelector, ...]: +def intersect_axes(ax1: AxisSelection, ax2: AxisSelection) -> Tuple[AxisSelector, ...]: """Returns a tuple of axes that are present in both ax1 and ax2. The returned order is the same as ax1. """ @@ -339,6 +339,23 @@ def overlapping_axes(ax1: AxisSelection, ax2: AxisSelection) -> Tuple[AxisSelect return tuple(out) +def overlapping_axes(ax1: AxisSelection, ax2: AxisSelection) -> Tuple[str, ...]: + """ + Like intersect_axes, but returns the names instead of the axes themselves. + Unlike intersect_axes, it does not throw an error if the sizes of a common axis are + different. + + The returned order is the same as in ax1. + """ + ax1 = ensure_tuple(ax1) + ax2 = ensure_tuple(ax2) + ax1_names = map(axis_name, ax1) + ax2_names = set(map(axis_name, ax2)) + + out = tuple(name for name in ax1_names if name in ax2_names) + return out + + @overload def axis_name(ax: AxisSelector) -> str: # type: ignore ... @@ -555,6 +572,7 @@ def replace_missing_with_ellipsis(ax1: AxisSelection, ax2: AxisSelection) -> Par "dslice", "dblock", "eliminate_axes", + "intersect_axes", "is_axis_compatible", "overlapping_axes", "replace_axis", diff --git a/src/haliax/core.py b/src/haliax/core.py index 264097d..536c8d4 100644 --- a/src/haliax/core.py +++ b/src/haliax/core.py @@ -689,9 +689,9 @@ def take(array: NamedArray, axis: AxisSelector, index: Union[int, NamedArray]) - remaining_axes = eliminate_axes(array.axes, axis) # axis order is generally [array.axes[:axis_index], index.axes, array.axes[axis_index + 1 :]] # except that index.axes may overlap with array.axes - overlapping_axes: AxisSpec = haliax.axis.overlapping_axes(remaining_axes, index.axes) + intersecting_axes: AxisSpec = haliax.axis.intersect_axes(remaining_axes, index.axes) - if overlapping_axes: + if intersecting_axes: # if the eliminated axis is also in the index, we rename it to a dummy axis that we can broadcast over it need_to_use_dummy_axis = index._lookup_indices(axis.name) is not None if need_to_use_dummy_axis: diff --git a/tests/test_axis.py b/tests/test_axis.py index 4d9eb0f..d7a9def 100644 --- a/tests/test_axis.py +++ b/tests/test_axis.py @@ -1,6 +1,6 @@ import pytest -from haliax.axis import Axis, eliminate_axes, make_axes, rearrange_for_partial_order +from haliax.axis import Axis, eliminate_axes, make_axes, overlapping_axes, rearrange_for_partial_order def test_eliminate_axes(): @@ -133,3 +133,17 @@ def test_duplicate_elements_errors(): with pytest.raises(ValueError): rearrange_for_partial_order(partial_order, candidates) + + +def test_overlapping_axes_with_different_sizes(): + A1 = Axis("A", 10) + A2 = Axis("A", 12) + B = Axis("B", 14) + C = Axis("C", 16) + D = Axis("D", 18) + + ax1 = (A1, B, C) + ax2 = (A2, C, D) + + overlapping_names = overlapping_axes(ax1, ax2) # Should not error + assert overlapping_names == ("A", "C")