Skip to content

Commit

Permalink
feat: Make register_plugin a standalone function and include shared…
Browse files Browse the repository at this point in the history
… lib discovery (#14804)

Co-authored-by: Stijn de Gooijer <[email protected]>
  • Loading branch information
MarcoGorelli and stinodego authored Mar 15, 2024
1 parent 8c34bcc commit 71a6563
Show file tree
Hide file tree
Showing 11 changed files with 372 additions and 123 deletions.
55 changes: 14 additions & 41 deletions docs/user-guide/expressions/plugins.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,50 +92,24 @@ expression in batches. Whereas for other operations this would not be allowed, t

```python
# expression_lib/__init__.py
from pathlib import Path
from typing import TYPE_CHECKING

import polars as pl
from polars.plugins import register_plugin_function
from polars.type_aliases import IntoExpr
from polars.utils.udfs import _get_shared_lib_location

from expression_lib.utils import parse_into_expr

# Boilerplate needed to inform Polars of the location of binary wheel.
lib = _get_shared_lib_location(__file__)

def pig_latinnify(expr: IntoExpr, capitalize: bool = False) -> pl.Expr:
expr = parse_into_expr(expr)
return expr.register_plugin(
lib=lib,
symbol="pig_latinnify",
def pig_latinnify(expr: IntoExpr) -> pl.Expr:
"""Pig-latinnify expression."""
return register_plugin_function(
plugin_path=Path(__file__).parent,
function_name="pig_latinnify",
args=expr,
is_elementwise=True,
)
```

```python
# expression_lib/utils.py
import polars as pl

from polars.type_aliases import IntoExpr, PolarsDataType


def parse_into_expr(
expr: IntoExpr,
*,
str_as_lit: bool = False,
list_as_lit: bool = True,
dtype: PolarsDataType | None = None,
) -> pl.Expr:
"""Parse a single input into an expression."""
if isinstance(expr, pl.Expr):
pass
elif isinstance(expr, str) and not str_as_lit:
expr = pl.col(expr)
elif isinstance(expr, list) and not list_as_lit:
expr = pl.lit(pl.Series(expr), dtype=dtype)
else:
expr = pl.lit(expr, dtype=dtype)
return expr
```

We can then compile this library in our environment by installing `maturin` and running `maturin develop --release`.

And that's it. Our expression is ready to use!
Expand Down Expand Up @@ -211,17 +185,16 @@ def append_args(
"""
This example shows how arguments other than `Series` can be used.
"""
expr = parse_into_expr(expr)
return expr.register_plugin(
lib=lib,
args=[],
return register_plugin_function(
plugin_path=Path(__file__).parent,
function_name="append_kwargs",
args=expr,
kwargs={
"float_arg": float_arg,
"integer_arg": integer_arg,
"string_arg": string_arg,
"boolean_arg": boolean_arg,
},
symbol="append_kwargs",
is_elementwise=True,
)
```
Expand Down
1 change: 1 addition & 0 deletions py-polars/docs/source/reference/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ methods. All classes and functions exposed in the ``polars.*`` namespace are pub
:maxdepth: 2

api
plugins


.. grid::
Expand Down
15 changes: 15 additions & 0 deletions py-polars/docs/source/reference/plugins.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
=======
Plugins
=======
.. currentmodule:: polars

Plugins allow for extending Polars' functionality. See the
`user guide <https://docs.pola.rs/user-guide/expressions/plugins/>`_ for more information
and resources.

Available plugin utility functions are:

.. automodule:: polars.plugins
:members:
:autosummary:
:autosummary-no-titles:
3 changes: 2 additions & 1 deletion py-polars/polars/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

__register_startup_deps()

from polars import api
from polars import api, exceptions, plugins, selectors
from polars._utils.polars_version import get_polars_version as _get_polars_version

# TODO: remove need for importing wrap utils at top level
Expand Down Expand Up @@ -225,6 +225,7 @@
__all__ = [
"api",
"exceptions",
"plugins",
# exceptions/errors
"ArrowError",
"ColumnNotFoundError",
Expand Down
68 changes: 35 additions & 33 deletions py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -9622,6 +9622,9 @@ def shift_and_fill(
"""
return self.shift(n, fill_value=fill_value)

@deprecate_function(
"Use `polars.plugins.register_plugin_function` instead.", version="0.20.16"
)
def register_plugin(
self,
*,
Expand All @@ -9635,20 +9638,26 @@ def register_plugin(
cast_to_supertypes: bool = False,
pass_name_to_apply: bool = False,
changes_length: bool = False,
) -> Self:
) -> Expr:
"""
Register a shared library as a plugin.
Register a plugin function.
.. warning::
This is highly unsafe as this will call the C function
loaded by `lib::symbol`.
.. deprecated:: 0.20.16
Use :func:`polars.plugins.register_plugin_function` instead.
The parameters you give dictate how polars will deal
with the function. Make sure they are correct!
See the `user guide <https://docs.pola.rs/user-guide/expressions/plugins/>`_
for more information about plugins.
.. note::
This functionality is unstable and may change without it
being considered breaking.
Warnings
--------
This method is deprecated. Use the new `polars.plugins.register_plugin_function`
function instead.
This is highly unsafe as this will call the C function loaded by
`lib::symbol`.
The parameters you set dictate how Polars will handle the function.
Make sure they are correct!
Parameters
----------
Expand Down Expand Up @@ -9677,31 +9686,24 @@ def register_plugin(
changes_length
For example a `unique` or a `slice`
"""
from polars.plugins import register_plugin_function

if args is None:
args = []
args = [self]
else:
args = [parse_as_expression(a) for a in args]
if kwargs is None:
serialized_kwargs = b""
else:
import pickle

# Choose the highest protocol supported by https://docs.rs/serde-pickle/latest/serde_pickle/
serialized_kwargs = pickle.dumps(kwargs, protocol=5)
args = [self, *list(args)]

return self._from_pyexpr(
self._pyexpr.register_plugin(
lib,
symbol,
args,
serialized_kwargs,
is_elementwise,
input_wildcard_expansion,
returns_scalar,
cast_to_supertypes,
pass_name_to_apply,
changes_length,
)
return register_plugin_function(
plugin_path=lib,
function_name=symbol,
args=args,
kwargs=kwargs,
is_elementwise=is_elementwise,
changes_length=changes_length,
returns_scalar=returns_scalar,
cast_to_supertype=cast_to_supertypes,
input_wildcard_expansion=input_wildcard_expansion,
pass_name_to_apply=pass_name_to_apply,
)

@deprecate_renamed_function("register_plugin", version="0.19.12")
Expand All @@ -9716,7 +9718,7 @@ def _register_plugin(
input_wildcard_expansion: bool = False,
auto_explode: bool = False,
cast_to_supertypes: bool = False,
) -> Self:
) -> Expr:
return self.register_plugin(
lib=lib,
symbol=symbol,
Expand Down
131 changes: 131 additions & 0 deletions py-polars/polars/plugins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
from __future__ import annotations

import contextlib
from pathlib import Path
from typing import TYPE_CHECKING, Any, Iterable

from polars._utils.parse_expr_input import parse_as_list_of_expressions
from polars._utils.wrap import wrap_expr

with contextlib.suppress(ImportError): # Module not available when building docs
import polars.polars as plr

if TYPE_CHECKING:
from polars import Expr
from polars.type_aliases import IntoExpr

__all__ = ["register_plugin_function"]


def register_plugin_function(
*,
plugin_path: Path | str,
function_name: str,
args: IntoExpr | Iterable[IntoExpr],
kwargs: dict[str, Any] | None = None,
is_elementwise: bool = False,
changes_length: bool = False,
returns_scalar: bool = False,
cast_to_supertype: bool = False,
input_wildcard_expansion: bool = False,
pass_name_to_apply: bool = False,
) -> Expr:
"""
Register a plugin function.
See the `user guide <https://docs.pola.rs/user-guide/expressions/plugins/>`_
for more information about plugins.
Parameters
----------
plugin_path
Path to the plugin package. Accepts either the file path to the dynamic library
file or the path to the directory containing it.
function_name
The name of the Rust function to register.
args
The arguments passed to this function. These get passed to the `input`
argument on the Rust side, and have to be expressions (or be convertible
to expressions).
kwargs
Non-expression arguments to the plugin function. These must be
JSON serializable.
is_elementwise
Indicate that the function operates on scalars only. This will potentially
trigger fast paths.
changes_length
Indicate that the function will change the length of the expression.
For example, a `unique` or `slice` operation.
returns_scalar
Automatically explode on unit length if the function ran as final aggregation.
This is the case for aggregations like `sum`, `min`, `covariance` etc.
cast_to_supertype
Cast the input expressions to their supertype.
input_wildcard_expansion
Expand wildcard expressions before executing the function.
pass_name_to_apply
If set to `True`, the `Series` passed to the function in a group-by operation
will ensure the name is set. This is an extra heap allocation per group.
Returns
-------
Expr
Warnings
--------
This is highly unsafe as this will call the C function loaded by
`plugin::function_name`.
The parameters you set dictate how Polars will handle the function.
Make sure they are correct!
"""
pyexprs = parse_as_list_of_expressions(args)
serialized_kwargs = _serialize_kwargs(kwargs)
plugin_path = _resolve_plugin_path(plugin_path)

return wrap_expr(
plr.register_plugin_function(
plugin_path=str(plugin_path),
function_name=function_name,
args=pyexprs,
kwargs=serialized_kwargs,
is_elementwise=is_elementwise,
input_wildcard_expansion=input_wildcard_expansion,
returns_scalar=returns_scalar,
cast_to_supertype=cast_to_supertype,
pass_name_to_apply=pass_name_to_apply,
changes_length=changes_length,
)
)


def _serialize_kwargs(kwargs: dict[str, Any] | None) -> bytes:
"""Serialize the function's keyword arguments."""
if not kwargs:
return b""

import pickle

# Use the highest pickle protocol supported the serde-pickle crate:
# https://docs.rs/serde-pickle/latest/serde_pickle/
return pickle.dumps(kwargs, protocol=5)


def _resolve_plugin_path(path: Path | str) -> Path:
"""Get the file path of the dynamic library file."""
if not isinstance(path, Path):
path = Path(path)

if path.is_file():
return path.resolve()

for p in path.iterdir():
if _is_dynamic_lib(p):
return p.resolve()
else:
msg = f"no dynamic library found at path: {path}"
raise FileNotFoundError(msg)


def _is_dynamic_lib(path: Path) -> bool:
return path.is_file() and path.suffix in (".so", ".dll", ".pyd")
26 changes: 26 additions & 0 deletions py-polars/polars/utils/udfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,36 @@
import os
from typing import Any

from polars._utils.deprecation import deprecate_function

__all__ = ["_get_shared_lib_location"]


@deprecate_function(
"It will be removed in the next breaking release."
" The new `register_plugin_function` function has this functionality built in."
" Use `from polars.plugins import register_plugin_function` to import that function."
" Check the user guide for the currently-recommended way to register a plugin:"
" https://docs.pola.rs/user-guide/expressions/plugins",
version="0.20.16",
)
def _get_shared_lib_location(main_file: Any) -> str:
"""
Get the location of the dynamic library file.
.. deprecated:: 0.20.16
Use :func:`polars.plugins.register_plugin_function` instead.
Warnings
--------
This function is deprecated and will be removed in the next breaking release.
The new `polars.plugins.register_plugin_function` function has this
functionality built in. Use `from polars.plugins import register_plugin_function`
to import that function.
Check the user guide for the recommended way to register a plugin:
https://docs.pola.rs/user-guide/expressions/plugins
"""
directory = os.path.dirname(main_file) # noqa: PTH120
return os.path.join( # noqa: PTH118
directory, next(filter(_is_shared_lib, os.listdir(directory)))
Expand Down
Loading

0 comments on commit 71a6563

Please sign in to comment.