Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement broadcast arrays #1410

Open
wants to merge 18 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Pint Changelog
0.24 (unreleased)
-----------------

- Implement numpy broadcast_array (Related to issue #981)
- NumPy 2.0 support
(PR #1985, #1971)
- Implement numpy roll (Related to issue #981)
Expand Down
19 changes: 17 additions & 2 deletions pint/facets/numpy/numpy_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,6 +741,23 @@ def _prod(a, *args, **kwargs):
implement_prod_func(name)


@implements("broadcast_arrays", "function")
def _broadcast_arrays(*args, **kwargs):
# Simply need to map input units to onto list of outputs
input_units = []
unitless_args = []
for x in args:
if hasattr(x, "units"):
input_units.append(x.units)
unitless_args.append(x.m)
else:
input_units.append(1)
unitless_args.append(x)

res = np.broadcast_arrays(*unitless_args, **kwargs)
return [out * unit for out, unit in zip(res, input_units)]


# Handle mutliplicative functions separately to deal with non-multiplicative units
def _base_unit_if_needed(a):
if a._is_multiplicative:
Expand Down Expand Up @@ -798,7 +815,6 @@ def implementation(a, b, **kwargs):
for func_str in ("cross", "dot"):
implement_mul_func(func_str)


# Implement simple matching-unit or stripped-unit functions based on signature


Expand Down Expand Up @@ -991,7 +1007,6 @@ def implementation(a, *args, **kwargs):
"vstack",
"dstack",
"column_stack",
"broadcast_arrays",
):
implement_func(
"function", func_str, input_units="all_consistent", output_unit="match_input"
Expand Down
5 changes: 5 additions & 0 deletions pint/testsuite/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,11 @@ def test_broadcast_arrays(self):
result = np.broadcast_arrays(x, y, subok=True)
helpers.assert_quantity_equal(result, expected)

helpers.assert_quantity_equal(
np.broadcast_arrays(self.q[:, 1], np.array([1, 2])),
[np.array([[2, 4], [2, 4]]) * self.ureg.m, np.array([1, 2])],
)

def test_roll(self):
helpers.assert_quantity_equal(
np.roll(self.q, 1), [[4, 1], [2, 3]] * self.ureg.m
Expand Down
Loading