Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(python): improve support for user-defined functions that return scalars #16556

Merged
merged 13 commits into from
May 30, 2024
25 changes: 21 additions & 4 deletions crates/polars-plan/src/dsl/python_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,21 @@ pub struct PythonUdfExpression {
python_function: PyObject,
output_type: Option<DataType>,
is_elementwise: bool,
returns_scalar: bool,
}

impl PythonUdfExpression {
pub fn new(lambda: PyObject, output_type: Option<DataType>, is_elementwise: bool) -> Self {
pub fn new(
lambda: PyObject,
output_type: Option<DataType>,
is_elementwise: bool,
returns_scalar: bool,
) -> Self {
Self {
python_function: lambda,
output_type,
is_elementwise,
returns_scalar,
}
}

Expand All @@ -121,7 +128,7 @@ impl PythonUdfExpression {
// skip header
let buf = &buf[MAGIC_BYTE_MARK.len()..];
let mut reader = Cursor::new(buf);
let (output_type, is_elementwise): (Option<DataType>, bool) =
let (output_type, is_elementwise, returns_scalar): (Option<DataType>, bool, bool) =
ciborium::de::from_reader(&mut reader).map_err(map_err)?;

let remainder = &buf[reader.position() as usize..];
Expand All @@ -138,6 +145,7 @@ impl PythonUdfExpression {
python_function.into(),
output_type,
is_elementwise,
returns_scalar,
)) as Arc<dyn SeriesUdf>)
})
}
Expand Down Expand Up @@ -181,8 +189,15 @@ impl SeriesUdf for PythonUdfExpression {
#[cfg(feature = "serde")]
fn try_serialize(&self, buf: &mut Vec<u8>) -> PolarsResult<()> {
buf.extend_from_slice(MAGIC_BYTE_MARK);
ciborium::ser::into_writer(&(self.output_type.clone(), self.is_elementwise), &mut *buf)
.unwrap();
ciborium::ser::into_writer(
&(
self.output_type.clone(),
self.is_elementwise,
self.returns_scalar,
),
&mut *buf,
)
.unwrap();

Python::with_gil(|py| {
let pickle = PyModule::import_bound(py, "cloudpickle")
Expand Down Expand Up @@ -222,6 +237,7 @@ impl Expr {
(ApplyOptions::GroupWise, "python_udf")
};

let returns_scalar = func.returns_scalar;
let return_dtype = func.output_type.clone();
let output_type = GetOutput::map_field(move |fld| match return_dtype {
Some(ref dt) => Field::new(fld.name(), dt.clone()),
Expand All @@ -239,6 +255,7 @@ impl Expr {
options: FunctionOptions {
collect_groups,
fmt_str: name,
returns_scalar,
..Default::default()
},
}
Expand Down
56 changes: 44 additions & 12 deletions py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4524,21 +4524,22 @@ def map_batches(
*,
agg_list: bool = False,
is_elementwise: bool = False,
returns_scalar: bool = False,
) -> Self:
"""
Apply a custom python function to a whole Series or sequence of Series.

The output of this custom function must be a Series (or a NumPy array, in which
case it will be automatically converted into a Series). If you want to apply a
The output of this custom function is presumed to be either a Series,
or a NumPy array (in which case it will be automatically converted into
a Series), or a scalar that will be converted into a Series. If the
result is a scalar and you want it to stay as a scalar, pass in
``returns_scalar=True``. If you want to apply a
custom function elementwise over single values, see :func:`map_elements`.
A reasonable use case for `map` functions is transforming the values
represented by an expression using a third-party library.

.. warning::
If you are looking to map a function over a window function or group_by
context, refer to :func:`map_elements` instead.
Read more in `the book
<https://docs.pola.rs/user-guide/expressions/user-defined-functions>`_.
If your function returns a scalar, for example a float, use
:func:`map_to_scalar` instead.

Parameters
----------
Expand All @@ -4556,6 +4557,11 @@ def map_batches(
function. This parameter only works in a group-by context.
The function will be invoked only once on a list of groups, rather than
once per group.
returns_scalar
If the function returns a scalar, by default it will be wrapped in
a list in the output, since the assumption is that the function
always returns something Series-like. If you want to keep the
result as a scalar, set this argument to True.

Warnings
--------
Expand Down Expand Up @@ -4597,33 +4603,58 @@ def map_batches(
... }
... )
>>> df.group_by("a").agg(
... pl.col("b").map_batches(lambda x: x.max(), agg_list=False)
... pl.col("b").map_batches(lambda x: x + 2, agg_list=False)
... ) # doctest: +IGNORE_RESULT
shape: (2, 2)
┌─────┬───────────┐
│ a ┆ b │
│ --- ┆ --- │
│ i64 ┆ list[i64] │
╞═════╪═══════════╡
│ 1 ┆ [4]
│ 0 ┆ [3]
│ 1 ┆ [4, 6]
│ 0 ┆ [3, 5]
└─────┴───────────┘

Using `agg_list=True` would be more efficient. In this example, the input of
the function is a Series of type `List(Int64)`.

>>> df.group_by("a").agg(
... pl.col("b").map_batches(lambda x: x.list.max(), agg_list=True)
... pl.col("b").map_batches(
... lambda x: x.list.eval(pl.element() + 2), agg_list=True
... )
... ) # doctest: +IGNORE_RESULT
shape: (2, 2)
┌─────┬───────────┐
│ a ┆ b │
│ --- ┆ --- │
│ i64 ┆ list[i64] │
╞═════╪═══════════╡
│ 0 ┆ [3, 5] │
│ 1 ┆ [4, 6] │
└─────┴───────────┘

Here's an example of a function that returns a scalar, where we want it
to stay as a scalar:

>>> df = pl.DataFrame(
... {
... "a": [0, 1, 0, 1],
... "b": [1, 2, 3, 4],
... }
... )
>>> df.group_by("a").agg(
... pl.col("b").map_batches(lambda x: x.max(), returns_scalar=True)
... ) # doctest: +IGNORE_RESULT
shape: (2, 2)
┌─────┬─────┐
│ a ┆ b │
│ --- ┆ --- │
│ i64 ┆ i64 │
╞═════╪═════╡
│ 0 ┆ 3 │
│ 1 ┆ 4 │
│ 0 ┆ 3 │
└─────┴─────┘

"""
if return_dtype is not None:
return_dtype = py_type_to_dtype(return_dtype)
Expand All @@ -4634,6 +4665,7 @@ def map_batches(
return_dtype,
agg_list,
is_elementwise,
returns_scalar,
)
)

Expand Down
11 changes: 9 additions & 2 deletions py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1380,6 +1380,7 @@ def __array_ufunc__(

# Only generalized ufuncs have a signature set:
is_generalized_ufunc = bool(ufunc.signature)

if is_generalized_ufunc:
# Generalized ufuncs will operate on the whole array, so
# missing data can corrupt the results.
Expand All @@ -1392,7 +1393,13 @@ def __array_ufunc__(
# output size.
assert ufunc.signature is not None # pacify MyPy
ufunc_input, ufunc_output = ufunc.signature.split("->")
allocate_output = ufunc_input == ufunc_output
if ufunc_output == "()":
                # If the result is a scalar, just let the function do its
# thing, no need for any song and dance involving
# allocation:
return ufunc(*args, dtype=dtype_char, **kwargs)
else:
allocate_output = ufunc_input == ufunc_output
else:
allocate_output = True

Expand All @@ -1409,6 +1416,7 @@ def __array_ufunc__(
lambda out: ufunc(*args, out=out, dtype=dtype_char, **kwargs),
allocate_output,
)

result = self._from_pyseries(series)
if is_generalized_ufunc:
# In this case we've disallowed passing in missing data, so no
Expand All @@ -1426,7 +1434,6 @@ def __array_ufunc__(
.select(F.when(validity_mask).then(F.col(self.name)))
.to_series(0)
)

else:
msg = (
"only `__call__` is implemented for numpy ufuncs on a Series, got "
Expand Down
12 changes: 10 additions & 2 deletions py-polars/src/expr/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -755,15 +755,23 @@ impl PyExpr {
self.inner.clone().shrink_dtype().into()
}

#[pyo3(signature = (lambda, output_type, agg_list, is_elementwise))]
#[pyo3(signature = (lambda, output_type, agg_list, is_elementwise, returns_scalar))]
fn map_batches(
&self,
lambda: PyObject,
output_type: Option<Wrap<DataType>>,
agg_list: bool,
is_elementwise: bool,
returns_scalar: bool,
) -> Self {
map_single(self, lambda, output_type, agg_list, is_elementwise)
map_single(
self,
lambda,
output_type,
agg_list,
is_elementwise,
returns_scalar,
)
}

fn dot(&self, other: Self) -> Self {
Expand Down
4 changes: 3 additions & 1 deletion py-polars/src/map/lazy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,12 @@ pub fn map_single(
output_type: Option<Wrap<DataType>>,
agg_list: bool,
is_elementwise: bool,
returns_scalar: bool,
) -> PyExpr {
let output_type = output_type.map(|wrap| wrap.0);

let func = python_udf::PythonUdfExpression::new(lambda, output_type, is_elementwise);
let func =
python_udf::PythonUdfExpression::new(lambda, output_type, is_elementwise, returns_scalar);
pyexpr.inner.clone().map_python(func, agg_list).into()
}

Expand Down
38 changes: 38 additions & 0 deletions py-polars/tests/unit/interop/numpy/test_ufunc_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,44 @@ def test_grouped_ufunc() -> None:
df.group_by("id").agg(pl.col("values").log1p().sum().pipe(np.expm1))


def test_generalized_ufunc_scalar() -> None:
    """A generalized ufunc with a scalar ``()`` output signature works with
    ``map_batches``, with and without ``returns_scalar=True``."""
    numba = pytest.importorskip("numba")

    @numba.guvectorize([(numba.int64[:], numba.int64[:])], "(n)->()")  # type: ignore[misc]
    def my_custom_sum(arr, result) -> None:  # type: ignore[no-untyped-def]
        acc = 0
        for item in arr:
            acc += item
        result[0] = acc

    # Make type checkers happy:
    custom_sum = cast(Callable[[object], object], my_custom_sum)

    # NumPy itself defines the expected semantics:
    assert custom_sum(np.array([10, 2, 3], dtype=np.int64)) == 15

    # Calling the gufunc directly on a Series:
    df = pl.DataFrame({"values": [10, 2, 3]})
    assert custom_sum(df.get_column("values")) == 15

    # Calling the gufunc through map_batches, with and without scalar output:
    via_map = df.select(pl.col("values").map_batches(custom_sum, returns_scalar=True))
    assert_frame_equal(via_map, pl.DataFrame({"values": 15}))
    via_map = df.select(pl.col("values").map_batches(custom_sum, returns_scalar=False))
    assert_frame_equal(via_map, pl.DataFrame({"values": [15]}))

    # And inside a group_by() aggregation:
    df = pl.DataFrame({"labels": ["a", "b", "a", "b"], "values": [10, 2, 3, 30]})
    via_map = (
        df.group_by("labels")
        .agg(pl.col("values").map_batches(custom_sum, returns_scalar=True))
        .sort("labels")
    )
    assert_frame_equal(
        via_map, pl.DataFrame({"labels": ["a", "b"], "values": [13, 32]})
    )


def make_gufunc_mean() -> Callable[[pl.Series], pl.Series]:
numba = pytest.importorskip("numba")

Expand Down
14 changes: 6 additions & 8 deletions py-polars/tests/unit/operations/map/test_map_batches.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ def test_map_batches_group() -> None:
assert df.group_by("id").agg(pl.col("t").map_batches(lambda s: s.sum())).sort(
"id"
).to_dict(as_series=False) == {"id": [0, 1], "t": [[11], [35]]}
# If returns_scalar is True, the result won't be wrapped in a list:
assert df.group_by("id").agg(
pl.col("t").map_batches(lambda s: s.sum(), returns_scalar=True)
).sort("id").to_dict(as_series=False) == {"id": [0, 1], "t": [11, 35]}


def test_map_deprecated() -> None:
Expand All @@ -82,16 +86,10 @@ def test_map_deprecated() -> None:
def test_ufunc_args() -> None:
df = pl.DataFrame({"a": [1, 2, 3], "b": [2, 4, 6]})
result = df.select(
z=np.add( # type: ignore[call-overload]
pl.col("a"), pl.col("b")
)
z=np.add(pl.col("a"), pl.col("b")) # type: ignore[call-overload]
)
expected = pl.DataFrame({"z": [3, 6, 9]})
assert_frame_equal(result, expected)
result = df.select(
z=np.add( # type: ignore[call-overload]
2, pl.col("a")
)
)
result = df.select(z=np.add(2, pl.col("a"))) # type: ignore[call-overload]
expected = pl.DataFrame({"z": [3, 4, 5]})
assert_frame_equal(result, expected)