diff --git a/docs/user-guide/expressions/plugins.md b/docs/user-guide/expressions/plugins.md
index 1f075650e975..60c5aedfb7af 100644
--- a/docs/user-guide/expressions/plugins.md
+++ b/docs/user-guide/expressions/plugins.md
@@ -92,50 +92,24 @@ expression in batches. Whereas for other operations this would not be allowed, t
```python
# expression_lib/__init__.py
+from pathlib import Path
+from typing import TYPE_CHECKING
+
import polars as pl
+from polars.plugins import register_plugin_function
from polars.type_aliases import IntoExpr
-from polars.utils.udfs import _get_shared_lib_location
-
-from expression_lib.utils import parse_into_expr
-# Boilerplate needed to inform Polars of the location of binary wheel.
-lib = _get_shared_lib_location(__file__)
-def pig_latinnify(expr: IntoExpr, capitalize: bool = False) -> pl.Expr:
- expr = parse_into_expr(expr)
- return expr.register_plugin(
- lib=lib,
- symbol="pig_latinnify",
+def pig_latinnify(expr: IntoExpr) -> pl.Expr:
+ """Pig-latinnify expression."""
+ return register_plugin_function(
+ plugin_path=Path(__file__).parent,
+ function_name="pig_latinnify",
+ args=expr,
is_elementwise=True,
)
```
-```python
-# expression_lib/utils.py
-import polars as pl
-
-from polars.type_aliases import IntoExpr, PolarsDataType
-
-
-def parse_into_expr(
- expr: IntoExpr,
- *,
- str_as_lit: bool = False,
- list_as_lit: bool = True,
- dtype: PolarsDataType | None = None,
-) -> pl.Expr:
- """Parse a single input into an expression."""
- if isinstance(expr, pl.Expr):
- pass
- elif isinstance(expr, str) and not str_as_lit:
- expr = pl.col(expr)
- elif isinstance(expr, list) and not list_as_lit:
- expr = pl.lit(pl.Series(expr), dtype=dtype)
- else:
- expr = pl.lit(expr, dtype=dtype)
- return expr
-```
-
We can then compile this library in our environment by installing `maturin` and running `maturin develop --release`.
And that's it. Our expression is ready to use!
@@ -211,17 +185,16 @@ def append_args(
"""
This example shows how arguments other than `Series` can be used.
"""
- expr = parse_into_expr(expr)
- return expr.register_plugin(
- lib=lib,
- args=[],
+ return register_plugin_function(
+ plugin_path=Path(__file__).parent,
+ function_name="append_kwargs",
+ args=expr,
kwargs={
"float_arg": float_arg,
"integer_arg": integer_arg,
"string_arg": string_arg,
"boolean_arg": boolean_arg,
},
- symbol="append_kwargs",
is_elementwise=True,
)
```
diff --git a/py-polars/docs/source/reference/index.rst b/py-polars/docs/source/reference/index.rst
index e5bce90e43bb..70b48dc5399c 100644
--- a/py-polars/docs/source/reference/index.rst
+++ b/py-polars/docs/source/reference/index.rst
@@ -77,6 +77,7 @@ methods. All classes and functions exposed in the ``polars.*`` namespace are pub
:maxdepth: 2
api
+ plugins
.. grid::
diff --git a/py-polars/docs/source/reference/plugins.rst b/py-polars/docs/source/reference/plugins.rst
new file mode 100644
index 000000000000..e49d69f0a119
--- /dev/null
+++ b/py-polars/docs/source/reference/plugins.rst
@@ -0,0 +1,15 @@
+=======
+Plugins
+=======
+.. currentmodule:: polars
+
+Plugins allow for extending Polars' functionality. See the
+`user guide `_ for more information
+and resources.
+
+Available plugin utility functions are:
+
+.. automodule:: polars.plugins
+ :members:
+ :autosummary:
+ :autosummary-no-titles:
diff --git a/py-polars/polars/__init__.py b/py-polars/polars/__init__.py
index dc958f9975fe..7a898d0ebc3f 100644
--- a/py-polars/polars/__init__.py
+++ b/py-polars/polars/__init__.py
@@ -17,7 +17,7 @@
__register_startup_deps()
-from polars import api
+from polars import api, exceptions, plugins, selectors
from polars._utils.polars_version import get_polars_version as _get_polars_version
# TODO: remove need for importing wrap utils at top level
@@ -225,6 +225,7 @@
__all__ = [
"api",
"exceptions",
+ "plugins",
# exceptions/errors
"ArrowError",
"ColumnNotFoundError",
diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py
index 9c856438623b..53f202b52572 100644
--- a/py-polars/polars/expr/expr.py
+++ b/py-polars/polars/expr/expr.py
@@ -9622,6 +9622,9 @@ def shift_and_fill(
"""
return self.shift(n, fill_value=fill_value)
+ @deprecate_function(
+ "Use `polars.plugins.register_plugin_function` instead.", version="0.20.16"
+ )
def register_plugin(
self,
*,
@@ -9635,20 +9638,26 @@ def register_plugin(
cast_to_supertypes: bool = False,
pass_name_to_apply: bool = False,
changes_length: bool = False,
- ) -> Self:
+ ) -> Expr:
"""
- Register a shared library as a plugin.
+ Register a plugin function.
- .. warning::
- This is highly unsafe as this will call the C function
- loaded by `lib::symbol`.
+ .. deprecated:: 0.20.16
+ Use :func:`polars.plugins.register_plugin_function` instead.
- The parameters you give dictate how polars will deal
- with the function. Make sure they are correct!
+ See the `user guide `_
+ for more information about plugins.
- .. note::
- This functionality is unstable and may change without it
- being considered breaking.
+ Warnings
+ --------
+ This method is deprecated. Use the new `polars.plugins.register_plugin_function`
+ function instead.
+
+ This is highly unsafe as this will call the C function loaded by
+ `lib::symbol`.
+
+ The parameters you set dictate how Polars will handle the function.
+ Make sure they are correct!
Parameters
----------
@@ -9677,31 +9686,24 @@ def register_plugin(
changes_length
For example a `unique` or a `slice`
"""
+ from polars.plugins import register_plugin_function
+
if args is None:
- args = []
+ args = [self]
else:
- args = [parse_as_expression(a) for a in args]
- if kwargs is None:
- serialized_kwargs = b""
- else:
- import pickle
-
- # Choose the highest protocol supported by https://docs.rs/serde-pickle/latest/serde_pickle/
- serialized_kwargs = pickle.dumps(kwargs, protocol=5)
+ args = [self, *list(args)]
- return self._from_pyexpr(
- self._pyexpr.register_plugin(
- lib,
- symbol,
- args,
- serialized_kwargs,
- is_elementwise,
- input_wildcard_expansion,
- returns_scalar,
- cast_to_supertypes,
- pass_name_to_apply,
- changes_length,
- )
+ return register_plugin_function(
+ plugin_path=lib,
+ function_name=symbol,
+ args=args,
+ kwargs=kwargs,
+ is_elementwise=is_elementwise,
+ changes_length=changes_length,
+ returns_scalar=returns_scalar,
+ cast_to_supertype=cast_to_supertypes,
+ input_wildcard_expansion=input_wildcard_expansion,
+ pass_name_to_apply=pass_name_to_apply,
)
@deprecate_renamed_function("register_plugin", version="0.19.12")
@@ -9716,7 +9718,7 @@ def _register_plugin(
input_wildcard_expansion: bool = False,
auto_explode: bool = False,
cast_to_supertypes: bool = False,
- ) -> Self:
+ ) -> Expr:
return self.register_plugin(
lib=lib,
symbol=symbol,
diff --git a/py-polars/polars/plugins.py b/py-polars/polars/plugins.py
new file mode 100644
index 000000000000..44a5f1a3b7a5
--- /dev/null
+++ b/py-polars/polars/plugins.py
@@ -0,0 +1,131 @@
+from __future__ import annotations
+
+import contextlib
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Iterable
+
+from polars._utils.parse_expr_input import parse_as_list_of_expressions
+from polars._utils.wrap import wrap_expr
+
+with contextlib.suppress(ImportError): # Module not available when building docs
+ import polars.polars as plr
+
+if TYPE_CHECKING:
+ from polars import Expr
+ from polars.type_aliases import IntoExpr
+
+__all__ = ["register_plugin_function"]
+
+
+def register_plugin_function(
+ *,
+ plugin_path: Path | str,
+ function_name: str,
+ args: IntoExpr | Iterable[IntoExpr],
+ kwargs: dict[str, Any] | None = None,
+ is_elementwise: bool = False,
+ changes_length: bool = False,
+ returns_scalar: bool = False,
+ cast_to_supertype: bool = False,
+ input_wildcard_expansion: bool = False,
+ pass_name_to_apply: bool = False,
+) -> Expr:
+ """
+ Register a plugin function.
+
+ See the `user guide `_
+ for more information about plugins.
+
+ Parameters
+ ----------
+ plugin_path
+ Path to the plugin package. Accepts either the file path to the dynamic library
+ file or the path to the directory containing it.
+ function_name
+ The name of the Rust function to register.
+ args
+ The arguments passed to this function. These get passed to the `input`
+ argument on the Rust side, and have to be expressions (or be convertible
+ to expressions).
+ kwargs
+ Non-expression arguments to the plugin function. These must be
+ JSON serializable.
+ is_elementwise
+ Indicate that the function operates on scalars only. This will potentially
+ trigger fast paths.
+ changes_length
+ Indicate that the function will change the length of the expression.
+ For example, a `unique` or `slice` operation.
+ returns_scalar
+ Automatically explode on unit length if the function ran as final aggregation.
+ This is the case for aggregations like `sum`, `min`, `covariance` etc.
+ cast_to_supertype
+ Cast the input expressions to their supertype.
+ input_wildcard_expansion
+ Expand wildcard expressions before executing the function.
+ pass_name_to_apply
+ If set to `True`, the `Series` passed to the function in a group-by operation
+ will ensure the name is set. This is an extra heap allocation per group.
+
+ Returns
+ -------
+ Expr
+
+ Warnings
+ --------
+ This is highly unsafe as this will call the C function loaded by
+ `plugin::function_name`.
+
+ The parameters you set dictate how Polars will handle the function.
+ Make sure they are correct!
+ """
+ pyexprs = parse_as_list_of_expressions(args)
+ serialized_kwargs = _serialize_kwargs(kwargs)
+ plugin_path = _resolve_plugin_path(plugin_path)
+
+ return wrap_expr(
+ plr.register_plugin_function(
+ plugin_path=str(plugin_path),
+ function_name=function_name,
+ args=pyexprs,
+ kwargs=serialized_kwargs,
+ is_elementwise=is_elementwise,
+ input_wildcard_expansion=input_wildcard_expansion,
+ returns_scalar=returns_scalar,
+ cast_to_supertype=cast_to_supertype,
+ pass_name_to_apply=pass_name_to_apply,
+ changes_length=changes_length,
+ )
+ )
+
+
+def _serialize_kwargs(kwargs: dict[str, Any] | None) -> bytes:
+ """Serialize the function's keyword arguments."""
+ if not kwargs:
+ return b""
+
+ import pickle
+
+ # Use the highest pickle protocol supported the serde-pickle crate:
+ # https://docs.rs/serde-pickle/latest/serde_pickle/
+ return pickle.dumps(kwargs, protocol=5)
+
+
+def _resolve_plugin_path(path: Path | str) -> Path:
+ """Get the file path of the dynamic library file."""
+ if not isinstance(path, Path):
+ path = Path(path)
+
+ if path.is_file():
+ return path.resolve()
+
+ for p in path.iterdir():
+ if _is_dynamic_lib(p):
+ return p.resolve()
+ else:
+ msg = f"no dynamic library found at path: {path}"
+ raise FileNotFoundError(msg)
+
+
+def _is_dynamic_lib(path: Path) -> bool:
+ return path.is_file() and path.suffix in (".so", ".dll", ".pyd")
diff --git a/py-polars/polars/utils/udfs.py b/py-polars/polars/utils/udfs.py
index 398197ecef1a..f7ee74e0eece 100644
--- a/py-polars/polars/utils/udfs.py
+++ b/py-polars/polars/utils/udfs.py
@@ -3,10 +3,36 @@
import os
from typing import Any
+from polars._utils.deprecation import deprecate_function
+
__all__ = ["_get_shared_lib_location"]
+@deprecate_function(
+ "It will be removed in the next breaking release."
+ " The new `register_plugin_function` function has this functionality built in."
+ " Use `from polars.plugins import register_plugin_function` to import that function."
+ " Check the user guide for the currently-recommended way to register a plugin:"
+ " https://docs.pola.rs/user-guide/expressions/plugins",
+ version="0.20.16",
+)
def _get_shared_lib_location(main_file: Any) -> str:
+ """
+ Get the location of the dynamic library file.
+
+ .. deprecated:: 0.20.16
+ Use :func:`polars.plugins.register_plugin_function` instead.
+
+ Warnings
+ --------
+ This function is deprecated and will be removed in the next breaking release.
+ The new `polars.plugins.register_plugin_function` function has this
+ functionality built in. Use `from polars.plugins import register_plugin_function`
+ to import that function.
+
+ Check the user guide for the recommended way to register a plugin:
+ https://docs.pola.rs/user-guide/expressions/plugins
+ """
directory = os.path.dirname(main_file) # noqa: PTH120
return os.path.join( # noqa: PTH118
directory, next(filter(_is_shared_lib, os.listdir(directory)))
diff --git a/py-polars/src/expr/general.rs b/py-polars/src/expr/general.rs
index 67d156e9e1c7..269236b7b2d4 100644
--- a/py-polars/src/expr/general.rs
+++ b/py-polars/src/expr/general.rs
@@ -863,54 +863,6 @@ impl PyExpr {
.into()
}
- #[cfg(feature = "ffi_plugin")]
- fn register_plugin(
- &self,
- lib: &str,
- symbol: &str,
- args: Vec,
- kwargs: Vec,
- is_elementwise: bool,
- input_wildcard_expansion: bool,
- returns_scalar: bool,
- cast_to_supertypes: bool,
- pass_name_to_apply: bool,
- changes_length: bool,
- ) -> PyResult {
- use polars_plan::prelude::*;
- let inner = self.inner.clone();
-
- let collect_groups = if is_elementwise {
- ApplyOptions::ElementWise
- } else {
- ApplyOptions::GroupWise
- };
- let mut input = Vec::with_capacity(args.len() + 1);
- input.push(inner);
- for a in args {
- input.push(a.inner)
- }
-
- Ok(Expr::Function {
- input,
- function: FunctionExpr::FfiPlugin {
- lib: Arc::from(lib),
- symbol: Arc::from(symbol),
- kwargs: Arc::from(kwargs),
- },
- options: FunctionOptions {
- collect_groups,
- input_wildcard_expansion,
- returns_scalar,
- cast_to_supertypes,
- pass_name_to_apply,
- changes_length,
- ..Default::default()
- },
- }
- .into())
- }
-
#[cfg(feature = "hist")]
#[pyo3(signature = (bins, bin_count, include_category, include_breakpoint))]
fn hist(
diff --git a/py-polars/src/functions/misc.rs b/py-polars/src/functions/misc.rs
index 593244618f03..8c4116e98832 100644
--- a/py-polars/src/functions/misc.rs
+++ b/py-polars/src/functions/misc.rs
@@ -1,10 +1,55 @@
+use std::sync::Arc;
+
+use polars_plan::prelude::*;
use pyo3::prelude::*;
use crate::conversion::Wrap;
+use crate::expr::ToExprs;
use crate::prelude::DataType;
+use crate::PyExpr;
#[pyfunction]
pub fn dtype_str_repr(dtype: Wrap) -> PyResult {
let dtype = dtype.0;
Ok(dtype.to_string())
}
+
+#[cfg(feature = "ffi_plugin")]
+#[pyfunction]
+pub fn register_plugin_function(
+ plugin_path: &str,
+ function_name: &str,
+ args: Vec,
+ kwargs: Vec,
+ is_elementwise: bool,
+ input_wildcard_expansion: bool,
+ returns_scalar: bool,
+ cast_to_supertype: bool,
+ pass_name_to_apply: bool,
+ changes_length: bool,
+) -> PyResult {
+ let collect_groups = if is_elementwise {
+ ApplyOptions::ElementWise
+ } else {
+ ApplyOptions::GroupWise
+ };
+
+ Ok(Expr::Function {
+ input: args.to_exprs(),
+ function: FunctionExpr::FfiPlugin {
+ lib: Arc::from(plugin_path),
+ symbol: Arc::from(function_name),
+ kwargs: Arc::from(kwargs),
+ },
+ options: FunctionOptions {
+ collect_groups,
+ input_wildcard_expansion,
+ returns_scalar,
+ cast_to_supertypes: cast_to_supertype,
+ pass_name_to_apply,
+ changes_length,
+ ..Default::default()
+ },
+ }
+ .into())
+}
diff --git a/py-polars/src/lib.rs b/py-polars/src/lib.rs
index 1dcb20557e1a..cdd1725f6f9e 100644
--- a/py-polars/src/lib.rs
+++ b/py-polars/src/lib.rs
@@ -305,5 +305,9 @@ fn polars(py: Python, m: &PyModule) -> PyResult<()> {
pyo3_built!(py, build, "build", "time", "deps", "features", "host", "target", "git"),
)?;
+ // Plugins
+ m.add_wrapped(wrap_pyfunction!(functions::register_plugin_function))
+ .unwrap();
+
Ok(())
}
diff --git a/py-polars/tests/unit/test_plugins.py b/py-polars/tests/unit/test_plugins.py
new file mode 100644
index 000000000000..b983c9a9044f
--- /dev/null
+++ b/py-polars/tests/unit/test_plugins.py
@@ -0,0 +1,99 @@
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Any
+
+import pytest
+
+import polars as pl
+from polars.plugins import (
+ _is_dynamic_lib,
+ _resolve_plugin_path,
+ _serialize_kwargs,
+ register_plugin_function,
+)
+
+
+@pytest.mark.write_disk()
+def test_register_plugin_function_invalid_plugin_path(tmp_path: Path) -> None:
+ tmp_path.mkdir(exist_ok=True)
+ plugin_path = tmp_path / "lib.so"
+ plugin_path.touch()
+
+ expr = register_plugin_function(
+ plugin_path=plugin_path, function_name="hello", args=5
+ )
+
+ with pytest.raises(pl.ComputeError, match="error loading dynamic library"):
+ pl.select(expr)
+
+
+@pytest.mark.parametrize(
+ ("input", "expected"),
+ [
+ (None, b""),
+ ({}, b""),
+ (
+ {"hi": 0},
+ b"\x80\x05\x95\x0b\x00\x00\x00\x00\x00\x00\x00}\x94\x8c\x02hi\x94K\x00s.",
+ ),
+ ],
+)
+def test_serialize_kwargs(input: dict[str, Any] | None, expected: bytes) -> None:
+ assert _serialize_kwargs(input) == expected
+
+
+@pytest.mark.write_disk()
+def test_resolve_plugin_path(tmp_path: Path) -> None:
+ tmp_path.mkdir(exist_ok=True)
+
+ (tmp_path / "lib1.so").touch()
+ (tmp_path / "__init__.py").touch()
+
+ expected = tmp_path / "lib1.so"
+
+ result = _resolve_plugin_path(tmp_path)
+ assert result == expected
+ result = _resolve_plugin_path(tmp_path / "lib1.so")
+ assert result == expected
+ result = _resolve_plugin_path(str(tmp_path))
+ assert result == expected
+ result = _resolve_plugin_path(str(tmp_path / "lib1.so"))
+ assert result == expected
+
+
+@pytest.mark.write_disk()
+def test_resolve_plugin_path_raises(tmp_path: Path) -> None:
+ tmp_path.mkdir(exist_ok=True)
+ (tmp_path / "__init__.py").touch()
+
+ with pytest.raises(FileNotFoundError, match="no dynamic library found"):
+ _resolve_plugin_path(tmp_path)
+
+
+@pytest.mark.write_disk()
+@pytest.mark.parametrize(
+ ("path", "expected"),
+ [
+ (Path("lib.so"), True),
+ (Path("lib.pyd"), True),
+ (Path("lib.dll"), True),
+ (Path("lib.py"), False),
+ ],
+)
+def test_is_dynamic_lib(path: Path, expected: bool, tmp_path: Path) -> None:
+ tmp_path.mkdir(exist_ok=True)
+ full_path = tmp_path / path
+ full_path.touch()
+ assert _is_dynamic_lib(full_path) is expected
+
+
+@pytest.mark.write_disk()
+def test_is_dynamic_lib_dir(tmp_path: Path) -> None:
+ path = Path("lib.so")
+ full_path = tmp_path / path
+
+ full_path.mkdir(exist_ok=True)
+ (full_path / "hello.txt").touch()
+
+ assert _is_dynamic_lib(full_path) is False