From 71a656357acb449967f96ba12afc225cf499af12 Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Fri, 15 Mar 2024 15:17:31 +0000 Subject: [PATCH] feat: Make `register_plugin` a standalone function and include shared lib discovery (#14804) Co-authored-by: Stijn de Gooijer --- docs/user-guide/expressions/plugins.md | 55 +++----- py-polars/docs/source/reference/index.rst | 1 + py-polars/docs/source/reference/plugins.rst | 15 +++ py-polars/polars/__init__.py | 3 +- py-polars/polars/expr/expr.py | 68 +++++----- py-polars/polars/plugins.py | 131 ++++++++++++++++++++ py-polars/polars/utils/udfs.py | 26 ++++ py-polars/src/expr/general.rs | 48 ------- py-polars/src/functions/misc.rs | 45 +++++++ py-polars/src/lib.rs | 4 + py-polars/tests/unit/test_plugins.py | 99 +++++++++++++++ 11 files changed, 372 insertions(+), 123 deletions(-) create mode 100644 py-polars/docs/source/reference/plugins.rst create mode 100644 py-polars/polars/plugins.py create mode 100644 py-polars/tests/unit/test_plugins.py diff --git a/docs/user-guide/expressions/plugins.md b/docs/user-guide/expressions/plugins.md index 1f075650e975..60c5aedfb7af 100644 --- a/docs/user-guide/expressions/plugins.md +++ b/docs/user-guide/expressions/plugins.md @@ -92,50 +92,24 @@ expression in batches. Whereas for other operations this would not be allowed, t ```python # expression_lib/__init__.py +from pathlib import Path +from typing import TYPE_CHECKING + import polars as pl +from polars.plugins import register_plugin_function from polars.type_aliases import IntoExpr -from polars.utils.udfs import _get_shared_lib_location - -from expression_lib.utils import parse_into_expr -# Boilerplate needed to inform Polars of the location of binary wheel. -lib = _get_shared_lib_location(__file__) -def pig_latinnify(expr: IntoExpr, capitalize: bool = False) -> pl.Expr: - expr = parse_into_expr(expr) - return expr.register_plugin( - lib=lib, - symbol="pig_latinnify", +def pig_latinnify(expr: IntoExpr) -> pl.Expr: + """Pig-latinnify expression.""" + return register_plugin_function( + plugin_path=Path(__file__).parent, + function_name="pig_latinnify", + args=expr, is_elementwise=True, ) ``` -```python -# expression_lib/utils.py -import polars as pl - -from polars.type_aliases import IntoExpr, PolarsDataType - - -def parse_into_expr( - expr: IntoExpr, - *, - str_as_lit: bool = False, - list_as_lit: bool = True, - dtype: PolarsDataType | None = None, -) -> pl.Expr: - """Parse a single input into an expression.""" - if isinstance(expr, pl.Expr): - pass - elif isinstance(expr, str) and not str_as_lit: - expr = pl.col(expr) - elif isinstance(expr, list) and not list_as_lit: - expr = pl.lit(pl.Series(expr), dtype=dtype) - else: - expr = pl.lit(expr, dtype=dtype) - return expr -``` - We can then compile this library in our environment by installing `maturin` and running `maturin develop --release`. And that's it. Our expression is ready to use! @@ -211,17 +185,16 @@ def append_args( """ This example shows how arguments other than `Series` can be used. """ - expr = parse_into_expr(expr) - return expr.register_plugin( - lib=lib, - args=[], + return register_plugin_function( + plugin_path=Path(__file__).parent, + function_name="append_kwargs", + args=expr, kwargs={ "float_arg": float_arg, "integer_arg": integer_arg, "string_arg": string_arg, "boolean_arg": boolean_arg, }, - symbol="append_kwargs", is_elementwise=True, ) ``` diff --git a/py-polars/docs/source/reference/index.rst b/py-polars/docs/source/reference/index.rst index e5bce90e43bb..70b48dc5399c 100644 --- a/py-polars/docs/source/reference/index.rst +++ b/py-polars/docs/source/reference/index.rst @@ -77,6 +77,7 @@ methods. All classes and functions exposed in the ``polars.*`` namespace are pub :maxdepth: 2 api + plugins .. grid:: diff --git a/py-polars/docs/source/reference/plugins.rst b/py-polars/docs/source/reference/plugins.rst new file mode 100644 index 000000000000..e49d69f0a119 --- /dev/null +++ b/py-polars/docs/source/reference/plugins.rst @@ -0,0 +1,15 @@ +======= +Plugins +======= +.. currentmodule:: polars + +Plugins allow for extending Polars' functionality. See the +`user guide `_ for more information +and resources. + +Available plugin utility functions are: + +.. automodule:: polars.plugins + :members: + :autosummary: + :autosummary-no-titles: diff --git a/py-polars/polars/__init__.py b/py-polars/polars/__init__.py index dc958f9975fe..7a898d0ebc3f 100644 --- a/py-polars/polars/__init__.py +++ b/py-polars/polars/__init__.py @@ -17,7 +17,7 @@ __register_startup_deps() -from polars import api +from polars import api, exceptions, plugins, selectors from polars._utils.polars_version import get_polars_version as _get_polars_version # TODO: remove need for importing wrap utils at top level @@ -225,6 +225,7 @@ __all__ = [ "api", "exceptions", + "plugins", # exceptions/errors "ArrowError", "ColumnNotFoundError", diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index 9c856438623b..53f202b52572 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -9622,6 +9622,9 @@ def shift_and_fill( """ return self.shift(n, fill_value=fill_value) + @deprecate_function( + "Use `polars.plugins.register_plugin_function` instead.", version="0.20.16" + ) def register_plugin( self, *, @@ -9635,20 +9638,26 @@ def register_plugin( cast_to_supertypes: bool = False, pass_name_to_apply: bool = False, changes_length: bool = False, - ) -> Self: + ) -> Expr: """ - Register a shared library as a plugin. + Register a plugin function. - .. warning:: - This is highly unsafe as this will call the C function - loaded by `lib::symbol`. + .. deprecated:: 0.20.16 + Use :func:`polars.plugins.register_plugin_function` instead. - The parameters you give dictate how polars will deal - with the function. Make sure they are correct! + See the `user guide `_ + for more information about plugins. - .. note:: - This functionality is unstable and may change without it - being considered breaking. + Warnings + -------- + This method is deprecated. Use the new `polars.plugins.register_plugin_function` + function instead. + + This is highly unsafe as this will call the C function loaded by + `lib::symbol`. + + The parameters you set dictate how Polars will handle the function. + Make sure they are correct! Parameters ---------- @@ -9677,31 +9686,24 @@ def register_plugin( changes_length For example a `unique` or a `slice` """ + from polars.plugins import register_plugin_function + if args is None: - args = [] + args = [self] else: - args = [parse_as_expression(a) for a in args] - if kwargs is None: - serialized_kwargs = b"" - else: - import pickle - - # Choose the highest protocol supported by https://docs.rs/serde-pickle/latest/serde_pickle/ - serialized_kwargs = pickle.dumps(kwargs, protocol=5) + args = [self, *list(args)] - return self._from_pyexpr( - self._pyexpr.register_plugin( - lib, - symbol, - args, - serialized_kwargs, - is_elementwise, - input_wildcard_expansion, - returns_scalar, - cast_to_supertypes, - pass_name_to_apply, - changes_length, - ) + return register_plugin_function( + plugin_path=lib, + function_name=symbol, + args=args, + kwargs=kwargs, + is_elementwise=is_elementwise, + changes_length=changes_length, + returns_scalar=returns_scalar, + cast_to_supertype=cast_to_supertypes, + input_wildcard_expansion=input_wildcard_expansion, + pass_name_to_apply=pass_name_to_apply, ) @deprecate_renamed_function("register_plugin", version="0.19.12") @@ -9716,7 +9718,7 @@ def _register_plugin( input_wildcard_expansion: bool = False, auto_explode: bool = False, cast_to_supertypes: bool = False, - ) -> Self: + ) -> Expr: return self.register_plugin( lib=lib, symbol=symbol, diff --git a/py-polars/polars/plugins.py b/py-polars/polars/plugins.py new file mode 100644 index 000000000000..44a5f1a3b7a5 --- /dev/null +++ b/py-polars/polars/plugins.py @@ -0,0 +1,131 @@ +from __future__ import annotations + +import contextlib +from pathlib import Path +from typing import TYPE_CHECKING, Any, Iterable + +from polars._utils.parse_expr_input import parse_as_list_of_expressions +from polars._utils.wrap import wrap_expr + +with contextlib.suppress(ImportError): # Module not available when building docs + import polars.polars as plr + +if TYPE_CHECKING: + from polars import Expr + from polars.type_aliases import IntoExpr + +__all__ = ["register_plugin_function"] + + +def register_plugin_function( + *, + plugin_path: Path | str, + function_name: str, + args: IntoExpr | Iterable[IntoExpr], + kwargs: dict[str, Any] | None = None, + is_elementwise: bool = False, + changes_length: bool = False, + returns_scalar: bool = False, + cast_to_supertype: bool = False, + input_wildcard_expansion: bool = False, + pass_name_to_apply: bool = False, +) -> Expr: + """ + Register a plugin function. + + See the `user guide `_ + for more information about plugins. + + Parameters + ---------- + plugin_path + Path to the plugin package. Accepts either the file path to the dynamic library + file or the path to the directory containing it. + function_name + The name of the Rust function to register. + args + The arguments passed to this function. These get passed to the `input` + argument on the Rust side, and have to be expressions (or be convertible + to expressions). + kwargs + Non-expression arguments to the plugin function. These must be + JSON serializable. + is_elementwise + Indicate that the function operates on scalars only. This will potentially + trigger fast paths. + changes_length + Indicate that the function will change the length of the expression. + For example, a `unique` or `slice` operation. + returns_scalar + Automatically explode on unit length if the function ran as final aggregation. + This is the case for aggregations like `sum`, `min`, `covariance` etc. + cast_to_supertype + Cast the input expressions to their supertype. + input_wildcard_expansion + Expand wildcard expressions before executing the function. + pass_name_to_apply + If set to `True`, the `Series` passed to the function in a group-by operation + will ensure the name is set. This is an extra heap allocation per group. + + Returns + ------- + Expr + + Warnings + -------- + This is highly unsafe as this will call the C function loaded by + `plugin::function_name`. + + The parameters you set dictate how Polars will handle the function. + Make sure they are correct! + """ + pyexprs = parse_as_list_of_expressions(args) + serialized_kwargs = _serialize_kwargs(kwargs) + plugin_path = _resolve_plugin_path(plugin_path) + + return wrap_expr( + plr.register_plugin_function( + plugin_path=str(plugin_path), + function_name=function_name, + args=pyexprs, + kwargs=serialized_kwargs, + is_elementwise=is_elementwise, + input_wildcard_expansion=input_wildcard_expansion, + returns_scalar=returns_scalar, + cast_to_supertype=cast_to_supertype, + pass_name_to_apply=pass_name_to_apply, + changes_length=changes_length, + ) + ) + + +def _serialize_kwargs(kwargs: dict[str, Any] | None) -> bytes: + """Serialize the function's keyword arguments.""" + if not kwargs: + return b"" + + import pickle + + # Use the highest pickle protocol supported the serde-pickle crate: + # https://docs.rs/serde-pickle/latest/serde_pickle/ + return pickle.dumps(kwargs, protocol=5) + + +def _resolve_plugin_path(path: Path | str) -> Path: + """Get the file path of the dynamic library file.""" + if not isinstance(path, Path): + path = Path(path) + + if path.is_file(): + return path.resolve() + + for p in path.iterdir(): + if _is_dynamic_lib(p): + return p.resolve() + else: + msg = f"no dynamic library found at path: {path}" + raise FileNotFoundError(msg) + + +def _is_dynamic_lib(path: Path) -> bool: + return path.is_file() and path.suffix in (".so", ".dll", ".pyd") diff --git a/py-polars/polars/utils/udfs.py b/py-polars/polars/utils/udfs.py index 398197ecef1a..f7ee74e0eece 100644 --- a/py-polars/polars/utils/udfs.py +++ b/py-polars/polars/utils/udfs.py @@ -3,10 +3,36 @@ import os from typing import Any +from polars._utils.deprecation import deprecate_function + __all__ = ["_get_shared_lib_location"] +@deprecate_function( + "It will be removed in the next breaking release." + " The new `register_plugin_function` function has this functionality built in." + " Use `from polars.plugins import register_plugin_function` to import that function." + " Check the user guide for the currently-recommended way to register a plugin:" + " https://docs.pola.rs/user-guide/expressions/plugins", + version="0.20.16", +) def _get_shared_lib_location(main_file: Any) -> str: + """ + Get the location of the dynamic library file. + + .. deprecated:: 0.20.16 + Use :func:`polars.plugins.register_plugin_function` instead. + + Warnings + -------- + This function is deprecated and will be removed in the next breaking release. + The new `polars.plugins.register_plugin_function` function has this + functionality built in. Use `from polars.plugins import register_plugin_function` + to import that function. + + Check the user guide for the recommended way to register a plugin: + https://docs.pola.rs/user-guide/expressions/plugins + """ directory = os.path.dirname(main_file) # noqa: PTH120 return os.path.join( # noqa: PTH118 directory, next(filter(_is_shared_lib, os.listdir(directory))) diff --git a/py-polars/src/expr/general.rs b/py-polars/src/expr/general.rs index 67d156e9e1c7..269236b7b2d4 100644 --- a/py-polars/src/expr/general.rs +++ b/py-polars/src/expr/general.rs @@ -863,54 +863,6 @@ impl PyExpr { .into() } - #[cfg(feature = "ffi_plugin")] - fn register_plugin( - &self, - lib: &str, - symbol: &str, - args: Vec, - kwargs: Vec, - is_elementwise: bool, - input_wildcard_expansion: bool, - returns_scalar: bool, - cast_to_supertypes: bool, - pass_name_to_apply: bool, - changes_length: bool, - ) -> PyResult { - use polars_plan::prelude::*; - let inner = self.inner.clone(); - - let collect_groups = if is_elementwise { - ApplyOptions::ElementWise - } else { - ApplyOptions::GroupWise - }; - let mut input = Vec::with_capacity(args.len() + 1); - input.push(inner); - for a in args { - input.push(a.inner) - } - - Ok(Expr::Function { - input, - function: FunctionExpr::FfiPlugin { - lib: Arc::from(lib), - symbol: Arc::from(symbol), - kwargs: Arc::from(kwargs), - }, - options: FunctionOptions { - collect_groups, - input_wildcard_expansion, - returns_scalar, - cast_to_supertypes, - pass_name_to_apply, - changes_length, - ..Default::default() - }, - } - .into()) - } - #[cfg(feature = "hist")] #[pyo3(signature = (bins, bin_count, include_category, include_breakpoint))] fn hist( diff --git a/py-polars/src/functions/misc.rs b/py-polars/src/functions/misc.rs index 593244618f03..8c4116e98832 100644 --- a/py-polars/src/functions/misc.rs +++ b/py-polars/src/functions/misc.rs @@ -1,10 +1,55 @@ +use std::sync::Arc; + +use polars_plan::prelude::*; use pyo3::prelude::*; use crate::conversion::Wrap; +use crate::expr::ToExprs; use crate::prelude::DataType; +use crate::PyExpr; #[pyfunction] pub fn dtype_str_repr(dtype: Wrap) -> PyResult { let dtype = dtype.0; Ok(dtype.to_string()) } + +#[cfg(feature = "ffi_plugin")] +#[pyfunction] +pub fn register_plugin_function( + plugin_path: &str, + function_name: &str, + args: Vec, + kwargs: Vec, + is_elementwise: bool, + input_wildcard_expansion: bool, + returns_scalar: bool, + cast_to_supertype: bool, + pass_name_to_apply: bool, + changes_length: bool, +) -> PyResult { + let collect_groups = if is_elementwise { + ApplyOptions::ElementWise + } else { + ApplyOptions::GroupWise + }; + + Ok(Expr::Function { + input: args.to_exprs(), + function: FunctionExpr::FfiPlugin { + lib: Arc::from(plugin_path), + symbol: Arc::from(function_name), + kwargs: Arc::from(kwargs), + }, + options: FunctionOptions { + collect_groups, + input_wildcard_expansion, + returns_scalar, + cast_to_supertypes: cast_to_supertype, + pass_name_to_apply, + changes_length, + ..Default::default() + }, + } + .into()) +} diff --git a/py-polars/src/lib.rs b/py-polars/src/lib.rs index 1dcb20557e1a..cdd1725f6f9e 100644 --- a/py-polars/src/lib.rs +++ b/py-polars/src/lib.rs @@ -305,5 +305,9 @@ fn polars(py: Python, m: &PyModule) -> PyResult<()> { pyo3_built!(py, build, "build", "time", "deps", "features", "host", "target", "git"), )?; + // Plugins + m.add_wrapped(wrap_pyfunction!(functions::register_plugin_function)) + .unwrap(); + Ok(()) } diff --git a/py-polars/tests/unit/test_plugins.py b/py-polars/tests/unit/test_plugins.py new file mode 100644 index 000000000000..b983c9a9044f --- /dev/null +++ b/py-polars/tests/unit/test_plugins.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Any + +import pytest + +import polars as pl +from polars.plugins import ( + _is_dynamic_lib, + _resolve_plugin_path, + _serialize_kwargs, + register_plugin_function, +) + + +@pytest.mark.write_disk() +def test_register_plugin_function_invalid_plugin_path(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + plugin_path = tmp_path / "lib.so" + plugin_path.touch() + + expr = register_plugin_function( + plugin_path=plugin_path, function_name="hello", args=5 + ) + + with pytest.raises(pl.ComputeError, match="error loading dynamic library"): + pl.select(expr) + + +@pytest.mark.parametrize( + ("input", "expected"), + [ + (None, b""), + ({}, b""), + ( + {"hi": 0}, + b"\x80\x05\x95\x0b\x00\x00\x00\x00\x00\x00\x00}\x94\x8c\x02hi\x94K\x00s.", + ), + ], +) +def test_serialize_kwargs(input: dict[str, Any] | None, expected: bytes) -> None: + assert _serialize_kwargs(input) == expected + + +@pytest.mark.write_disk() +def test_resolve_plugin_path(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + + (tmp_path / "lib1.so").touch() + (tmp_path / "__init__.py").touch() + + expected = tmp_path / "lib1.so" + + result = _resolve_plugin_path(tmp_path) + assert result == expected + result = _resolve_plugin_path(tmp_path / "lib1.so") + assert result == expected + result = _resolve_plugin_path(str(tmp_path)) + assert result == expected + result = _resolve_plugin_path(str(tmp_path / "lib1.so")) + assert result == expected + + +@pytest.mark.write_disk() +def test_resolve_plugin_path_raises(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + (tmp_path / "__init__.py").touch() + + with pytest.raises(FileNotFoundError, match="no dynamic library found"): + _resolve_plugin_path(tmp_path) + + +@pytest.mark.write_disk() +@pytest.mark.parametrize( + ("path", "expected"), + [ + (Path("lib.so"), True), + (Path("lib.pyd"), True), + (Path("lib.dll"), True), + (Path("lib.py"), False), + ], +) +def test_is_dynamic_lib(path: Path, expected: bool, tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + full_path = tmp_path / path + full_path.touch() + assert _is_dynamic_lib(full_path) is expected + + +@pytest.mark.write_disk() +def test_is_dynamic_lib_dir(tmp_path: Path) -> None: + path = Path("lib.so") + full_path = tmp_path / path + + full_path.mkdir(exist_ok=True) + (full_path / "hello.txt").touch() + + assert _is_dynamic_lib(full_path) is False