Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(python): fix ufunc for unlimited column args #14328

Merged
merged 4 commits into from
Feb 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 24 additions & 15 deletions py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import os
import warnings
from datetime import timedelta
from functools import partial, reduce
from functools import reduce
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -66,7 +66,6 @@

with contextlib.suppress(ImportError): # Module not available when building docs
from polars.polars import arg_where as py_arg_where
from polars.polars import reduce as pyreduce

with contextlib.suppress(ImportError): # Module not available when building docs
from polars.polars import PyExpr
Expand Down Expand Up @@ -287,21 +286,29 @@ def __array_ufunc__(
self, ufunc: Callable[..., Any], method: str, *inputs: Any, **kwargs: Any
) -> Self:
"""Numpy universal functions."""
if method != "__call__":
msg = f"Only call is implemented not {method}"
raise NotImplementedError(msg)
is_custom_ufunc = ufunc.__class__ != np.ufunc
num_expr = sum(isinstance(inp, Expr) for inp in inputs)
if num_expr > 1:
if num_expr < len(inputs):
msg = (
"NumPy ufunc with more than one expression can only be used"
" if all non-expression inputs are provided as keyword arguments only"
)
raise ValueError(msg)

exprs = parse_as_list_of_expressions(inputs)
return self._from_pyexpr(pyreduce(partial(ufunc, **kwargs), exprs))
exprs = [
(inp, Expr, i) if isinstance(inp, Expr) else (inp, None, i)
for i, inp in enumerate(inputs)
]
if num_expr == 1:
root_expr = next(expr[0] for expr in exprs if expr[1] == Expr)
else:
root_expr = F.struct(expr[0] for expr in exprs if expr[1] == Expr)

def function(s: Series) -> Series: # pragma: no cover
args = [inp if not isinstance(inp, Expr) else s for inp in inputs]
args = []
for i, expr in enumerate(exprs):
if expr[1] == Expr and num_expr > 1:
args.append(s.struct[i])
elif expr[1] == Expr:
args.append(s)
else:
args.append(expr[0])
return ufunc(*args, **kwargs)

if is_custom_ufunc is True:
Expand All @@ -316,8 +323,10 @@ def function(s: Series) -> Series: # pragma: no cover
CustomUFuncWarning,
stacklevel=find_stacklevel(),
)
return self.map_batches(function, is_elementwise=False)
return self.map_batches(function, is_elementwise=True)
return root_expr.map_batches(
function, is_elementwise=False
).meta.undo_aliases()
return root_expr.map_batches(function, is_elementwise=True).meta.undo_aliases()

@classmethod
def from_json(cls, value: str) -> Self:
Expand Down
18 changes: 18 additions & 0 deletions py-polars/tests/unit/operations/map/test_map_batches.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,21 @@ def test_map_deprecated() -> None:
pl.col("a").map(lambda x: x)
with pytest.deprecated_call():
pl.LazyFrame({"a": [1, 2]}).map(lambda x: x)


def test_ufunc_args() -> None:
df = pl.DataFrame({"a": [1, 2, 3], "b": [2, 4, 6]})
result = df.select(
z=np.add( # type: ignore[call-overload]
pl.col("a"), pl.col("b")
)
)
expected = pl.DataFrame({"z": [3, 6, 9]})
assert_frame_equal(result, expected)
result = df.select(
z=np.add( # type: ignore[call-overload]
2, pl.col("a")
)
)
expected = pl.DataFrame({"z": [3, 4, 5]})
assert_frame_equal(result, expected)
Loading