From 863d0e7e4c2626fe1d4f931515ccce61788a24c8 Mon Sep 17 00:00:00 2001 From: Dean MacGregor Date: Tue, 6 Feb 2024 09:46:50 -0500 Subject: [PATCH 1/4] multi_arg_ufunc --- py-polars/polars/expr/expr.py | 37 +++++++++++++++++++++++------------ 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index f1e9f0c89e7e..3dd613576c45 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -287,21 +287,34 @@ def __array_ufunc__( self, ufunc: Callable[..., Any], method: str, *inputs: Any, **kwargs: Any ) -> Self: """Numpy universal functions.""" + if method != "__call__": + msg = f"Only call is implemented not {method}" + raise NotImplementedError(msg) is_custom_ufunc = ufunc.__class__ != np.ufunc num_expr = sum(isinstance(inp, Expr) for inp in inputs) - if num_expr > 1: - if num_expr < len(inputs): - msg = ( - "NumPy ufunc with more than one expression can only be used" - " if all non-expression inputs are provided as keyword arguments only" - ) - raise ValueError(msg) + exprs = [ + (inp, Expr, i) if isinstance(inp, Expr) else (inp, None, i) + for i, inp in enumerate(inputs) + ] + if num_expr == 1: + root_expr = next(expr[0] for expr in exprs if expr[1] == Expr) - exprs = parse_as_list_of_expressions(inputs) - return self._from_pyexpr(pyreduce(partial(ufunc, **kwargs), exprs)) + # def function(s: Series) -> Series: + # return ufunc(s, **kwargs) + else: + root_expr = F.struct( + expr[0].alias(f"__arg{expr[2]}") for expr in exprs if expr[1] == Expr + ) def function(s: Series) -> Series: # pragma: no cover - args = [inp if not isinstance(inp, Expr) else s for inp in inputs] + args = [] + for expr in exprs: + if expr[1] == Expr and num_expr > 1: + args.append(s.struct.field(f"__arg{expr[2]}")) + elif expr[1] == Expr: + args.append(s) + else: + args.append(expr[0]) return ufunc(*args, **kwargs) if is_custom_ufunc is True: @@ -316,8 +329,8 @@ def function(s: Series) -> Series: # pragma: no cover CustomUFuncWarning, stacklevel=find_stacklevel(), ) - return self.map_batches(function, is_elementwise=False) - return self.map_batches(function, is_elementwise=True) + return root_expr.map_batches(function, is_elementwise=False) + return root_expr.map_batches(function, is_elementwise=True) @classmethod def from_json(cls, value: str) -> Self: From 7708c9c97f6a160c66c49bb6d48efec37b216904 Mon Sep 17 00:00:00 2001 From: Dean MacGregor Date: Tue, 6 Feb 2024 17:33:33 -0500 Subject: [PATCH 2/4] works --- py-polars/polars/expr/expr.py | 20 ++++++++----------- .../unit/operations/map/test_map_batches.py | 10 ++++++++++ 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index 3dd613576c45..760e860ff114 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -6,7 +6,7 @@ import os import warnings from datetime import timedelta -from functools import partial, reduce +from functools import reduce from typing import ( TYPE_CHECKING, Any, @@ -66,7 +66,6 @@ with contextlib.suppress(ImportError): # Module not available when building docs from polars.polars import arg_where as py_arg_where - from polars.polars import reduce as pyreduce with contextlib.suppress(ImportError): # Module not available when building docs from polars.polars import PyExpr @@ -298,19 +297,14 @@ def __array_ufunc__( ] if num_expr == 1: root_expr = next(expr[0] for expr in exprs if expr[1] == Expr) - - # def function(s: Series) -> Series: - # return ufunc(s, **kwargs) else: - root_expr = F.struct( - expr[0].alias(f"__arg{expr[2]}") for expr in exprs if expr[1] == Expr - ) + root_expr = F.struct(expr[0] for expr in exprs if expr[1] == Expr) def function(s: Series) -> Series: # pragma: no cover args = [] - for expr in exprs: + for i, expr in enumerate(exprs): if expr[1] == Expr and num_expr > 1: - args.append(s.struct.field(f"__arg{expr[2]}")) + args.append(s.struct[i]) elif expr[1] == Expr: args.append(s) else: @@ -329,8 +323,10 @@ def function(s: Series) -> Series: # pragma: no cover CustomUFuncWarning, stacklevel=find_stacklevel(), ) - return root_expr.map_batches(function, is_elementwise=False) - return root_expr.map_batches(function, is_elementwise=True) + return root_expr.map_batches( + function, is_elementwise=False + ).meta.undo_aliases() + return root_expr.map_batches(function, is_elementwise=True).meta.undo_aliases() @classmethod def from_json(cls, value: str) -> Self: diff --git a/py-polars/tests/unit/operations/map/test_map_batches.py b/py-polars/tests/unit/operations/map/test_map_batches.py index 2cde056cc652..a85c06e1e8c4 100644 --- a/py-polars/tests/unit/operations/map/test_map_batches.py +++ b/py-polars/tests/unit/operations/map/test_map_batches.py @@ -77,3 +77,13 @@ def test_map_deprecated() -> None: pl.col("a").map(lambda x: x) with pytest.deprecated_call(): pl.LazyFrame({"a": [1, 2]}).map(lambda x: x) + + +def test_ufunc_args() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": [2, 4, 6]}) + result = df.select(z=np.add(pl.col("a"), pl.col("b"))) + expected = pl.DataFrame({"z": [3, 6, 9]}) + assert_frame_equal(result, expected) + result = df.select(z=np.add(2, pl.col("a"))) + expected = pl.DataFrame({"z": [3, 4, 5]}) + assert_frame_equal(result, expected) From 2e698f4992056fad670e6eb78d6774eaaa113d4b Mon Sep 17 00:00:00 2001 From: deanm0000 <37878412+deanm0000@users.noreply.github.com> Date: Tue, 6 Feb 2024 20:50:40 -0500 Subject: [PATCH 3/4] Update test_map_batches.py --- py-polars/tests/unit/operations/map/test_map_batches.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/py-polars/tests/unit/operations/map/test_map_batches.py b/py-polars/tests/unit/operations/map/test_map_batches.py index a85c06e1e8c4..f966bcd5b1ee 100644 --- a/py-polars/tests/unit/operations/map/test_map_batches.py +++ b/py-polars/tests/unit/operations/map/test_map_batches.py @@ -81,9 +81,9 @@ def test_map_deprecated() -> None: def test_ufunc_args() -> None: df = pl.DataFrame({"a": [1, 2, 3], "b": [2, 4, 6]}) - result = df.select(z=np.add(pl.col("a"), pl.col("b"))) + result = df.select(z=np.add(pl.col("a"), pl.col("b"))) # type: ignore[call-overload] expected = pl.DataFrame({"z": [3, 6, 9]}) assert_frame_equal(result, expected) - result = df.select(z=np.add(2, pl.col("a"))) + result = df.select(z=np.add(2, pl.col("a"))) # type: ignore[call-overload] expected = pl.DataFrame({"z": [3, 4, 5]}) assert_frame_equal(result, expected) From 90386fbebf0034f0abbd0d40ea5fc7cec32aed0e Mon Sep 17 00:00:00 2001 From: Dean MacGregor Date: Wed, 7 Feb 2024 07:48:36 -0500 Subject: [PATCH 4/4] ruff --- .../tests/unit/operations/map/test_map_batches.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/py-polars/tests/unit/operations/map/test_map_batches.py b/py-polars/tests/unit/operations/map/test_map_batches.py index f966bcd5b1ee..457df189fa00 100644 --- a/py-polars/tests/unit/operations/map/test_map_batches.py +++ b/py-polars/tests/unit/operations/map/test_map_batches.py @@ -81,9 +81,17 @@ def test_map_deprecated() -> None: def test_ufunc_args() -> None: df = pl.DataFrame({"a": [1, 2, 3], "b": [2, 4, 6]}) - result = df.select(z=np.add(pl.col("a"), pl.col("b"))) # type: ignore[call-overload] + result = df.select( + z=np.add( # type: ignore[call-overload] + pl.col("a"), pl.col("b") + ) + ) expected = pl.DataFrame({"z": [3, 6, 9]}) assert_frame_equal(result, expected) - result = df.select(z=np.add(2, pl.col("a"))) # type: ignore[call-overload] + result = df.select( + z=np.add( # type: ignore[call-overload] + 2, pl.col("a") + ) + ) expected = pl.DataFrame({"z": [3, 4, 5]}) assert_frame_equal(result, expected)