Skip to content

Commit

Permalink
Fix up some things
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego committed Mar 15, 2024
1 parent a89d44c commit afbea38
Show file tree
Hide file tree
Showing 6 changed files with 124 additions and 111 deletions.
8 changes: 4 additions & 4 deletions docs/user-guide/expressions/plugins.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,9 @@ from polars.type_aliases import IntoExpr
def pig_latinnify(expr: IntoExpr) -> pl.Expr:
"""Pig-latinnify expression."""
return register_plugin_function(
plugin_location=Path(__file__).parent,
plugin_path=Path(__file__).parent,
function_name="pig_latinnify",
inputs=expr,
args=expr,
is_elementwise=True,
)
```
Expand Down Expand Up @@ -186,9 +186,9 @@ def append_args(
This example shows how arguments other than `Series` can be used.
"""
return register_plugin_function(
plugin_location=Path(__file__).parent,
plugin_path=Path(__file__).parent,
function_name="append_kwargs",
inputs=expr,
args=expr,
kwargs={
"float_arg": float_arg,
"integer_arg": integer_arg,
Expand Down
20 changes: 10 additions & 10 deletions py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -9642,19 +9642,19 @@ def register_plugin(
.. deprecated:: 0.20.16
Use :func:`polars.plugins.register_plugin_function` instead.
.. warning::
This functionality is considered **unstable**. It may be changed
at any point without it being considered a breaking change.
See the `user guide <https://docs.pola.rs/user-guide/expressions/plugins/>`_
for more information about plugins.
.. warning::
This is highly unsafe as this will call the C function
loaded by `lib::symbol`.
Warnings
--------
This method is deprecated. Use the new `polars.plugins.register_plugin_function`
function instead.
The parameters you give dictate how polars will deal
with the function. Make sure they are correct!
This is highly unsafe as this will call the C function loaded by
`lib::symbol`.
See the `user guide <https://docs.pola.rs/user-guide/expressions/plugins/>`_
for more information about plugins.
The parameters you set dictate how Polars will handle the function.
Make sure they are correct!
Parameters
----------
Expand Down
131 changes: 65 additions & 66 deletions py-polars/polars/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,124 +5,123 @@
from typing import TYPE_CHECKING, Any, Iterable

from polars._utils.parse_expr_input import parse_as_list_of_expressions
from polars._utils.unstable import unstable
from polars._utils.wrap import wrap_expr

with contextlib.suppress(ImportError): # Module not available when building docs
import polars.polars as plr

if TYPE_CHECKING:
from polars import Expr
from polars.type_aliases import IntoExpr

with contextlib.suppress(ImportError): # Module not available when building docs
import polars.polars as plr

__all__ = ["register_plugin_function"]


@unstable()
def register_plugin_function(
*,
plugin_location: Path | str,
plugin_path: Path | str,
function_name: str,
inputs: IntoExpr | Iterable[IntoExpr],
args: IntoExpr | Iterable[IntoExpr],
kwargs: dict[str, Any] | None = None,
is_elementwise: bool = False,
input_wildcard_expansion: bool = False,
changes_length: bool = False,
returns_scalar: bool = False,
cast_to_supertypes: bool = False,
cast_to_supertype: bool = False,
input_wildcard_expansion: bool = False,
pass_name_to_apply: bool = False,
changes_length: bool = False,
) -> Expr:
"""
Register a plugin function.
.. warning::
This functionality is considered **unstable**. It may be changed
at any point without it being considered a breaking change.
.. warning::
This is highly unsafe as this will call the C function
loaded by `lib::symbol`.
The parameters you set dictate how Polars will deal with the function.
Make sure they are correct!
See the `user guide <https://docs.pola.rs/user-guide/expressions/plugins/>`_
for more information about plugins.
Parameters
----------
plugin_location
Path to the package where plugin is located. This should either be the dynamic
library file, or the directory containing it.
plugin_path
Path to the plugin package. Accepts either the file path to the dynamic library
file or the path to the directory containing it.
function_name
Name of the Rust function to register.
inputs
Arguments passed to this function. These get passed to the ``inputs``
argument on the Rust side, and have to be of type Expression (or be
convertible to expressions).
The name of the Rust function to register.
args
The arguments passed to this function. These get passed to the `input`
argument on the Rust side, and have to be expressions (or be convertible
to expressions).
kwargs
Non-expression arguments. They must be JSON serializable.
Non-expression arguments to the plugin function. These must be
JSON serializable.
is_elementwise
If the function only operates on scalars, this will potentially trigger fast
paths.
input_wildcard_expansion
Expand expressions as input of this function.
Indicate that the function operates on scalars only. This will potentially
trigger fast paths.
changes_length
Indicate that the function will change the length of the expression.
For example, a `unique` or `slice` operation.
returns_scalar
Automatically explode on unit length if it ran as final aggregation.
this is the case for aggregations like `sum`, `min`, `covariance` etc.
cast_to_supertypes
Cast the input datatypes to their supertype.
Automatically explode on unit length if the function ran as final aggregation.
This is the case for aggregations like `sum`, `min`, `covariance` etc.
cast_to_supertype
Cast the input expressions to their supertype.
input_wildcard_expansion
Expand wildcard expressions before executing the function.
pass_name_to_apply
If set to `True`, the `Series` passed to the function in the group_by operation
If set to `True`, the `Series` passed to the function in a group-by operation
will ensure the name is set. This is an extra heap allocation per group.
changes_length
For example a `unique` or a `slice`
Returns
-------
Expr
"""
pyexprs = parse_as_list_of_expressions(inputs)
if not pyexprs:
msg = "`inputs` must be non-empty"
raise TypeError(msg)
if kwargs is None:
serialized_kwargs = b""
else:
import pickle
# Choose the highest protocol supported by https://docs.rs/serde-pickle/latest/serde_pickle/
serialized_kwargs = pickle.dumps(kwargs, protocol=5)
Warnings
--------
This is highly unsafe as this will call the C function loaded by
`plugin::function_name`.
lib_location = _get_dynamic_lib_location(plugin_location)
The parameters you set dictate how Polars will handle the function.
Make sure they are correct!
"""
pyexprs = parse_as_list_of_expressions(args)
serialized_kwargs = _serialize_kwargs(kwargs)
plugin_path = _resolve_plugin_path(plugin_path)

return wrap_expr(
plr.register_plugin_function(
str(lib_location),
function_name,
pyexprs,
serialized_kwargs,
is_elementwise,
input_wildcard_expansion,
returns_scalar,
cast_to_supertypes,
pass_name_to_apply,
changes_length,
plugin_path=str(plugin_path),
function_name=function_name,
args=pyexprs,
kwargs=serialized_kwargs,
is_elementwise=is_elementwise,
input_wildcard_expansion=input_wildcard_expansion,
returns_scalar=returns_scalar,
cast_to_supertype=cast_to_supertype,
pass_name_to_apply=pass_name_to_apply,
changes_length=changes_length,
)
)


def _get_dynamic_lib_location(path: Path | str) -> Path:
def _serialize_kwargs(kwargs: dict[str, Any] | None) -> bytes:
"""Serialize the function's keyword arguments."""
if not kwargs:
return b""

import pickle

# Use the highest pickle protocol supported the serde-pickle crate:
# https://docs.rs/serde-pickle/latest/serde_pickle/
return pickle.dumps(kwargs, protocol=5)


def _resolve_plugin_path(path: Path | str) -> Path:
"""Get the file path of the dynamic library file."""
if not isinstance(path, Path):
path = Path(path)

if path.is_file():
return path
return path.resolve()

for p in path.iterdir():
if _is_dynamic_lib(p):
return p
return p.resolve()
else:
msg = f"no dynamic library found at path: {path}"
raise FileNotFoundError(msg)
Expand Down
20 changes: 10 additions & 10 deletions py-polars/polars/utils/udfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
import os
from typing import Any

from polars._utils.deprecation import issue_deprecation_warning

__all__ = ["_get_shared_lib_location"]


Expand All @@ -14,15 +12,17 @@ def _get_shared_lib_location(main_file: Any) -> str:
.. deprecated:: 0.20.16
Use :func:`polars.plugins.register_plugin_function` instead.
Warnings
--------
This function is deprecated and will be removed in the next breaking release.
The new `polars.plugins.register_plugin_function` function has this
functionality built in. Use `from polars.plugins import register_plugin_function`
to import that function.
Check the user guide for the recommended way to register a plugin:
https://docs.pola.rs/user-guide/expressions/plugins
"""
issue_deprecation_warning(
"`_get_shared_lib_location` is deprecated and will be removed in the next breaking release."
" The new `register_plugin_function` function has this functionality built in."
" Use `from polars.plugins import register_plugin_function` to import that function."
" Check the user guide for the currently-recommended way to register a plugin:"
" https://docs.pola.rs/user-guide/expressions/plugins",
version="0.20.16",
)
directory = os.path.dirname(main_file) # noqa: PTH120
return os.path.join( # noqa: PTH118
directory, next(filter(_is_shared_lib, os.listdir(directory)))
Expand Down
22 changes: 9 additions & 13 deletions py-polars/src/functions/misc.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use std::sync::Arc;

use polars_plan::prelude::*;
use pyo3::prelude::*;

use crate::conversion::Wrap;
use crate::expr::ToExprs;
use crate::prelude::DataType;
use crate::PyExpr;

Expand All @@ -15,41 +17,35 @@ pub fn dtype_str_repr(dtype: Wrap<DataType>) -> PyResult<String> {
#[cfg(feature = "ffi_plugin")]
#[pyfunction]
pub fn register_plugin_function(
lib: &str,
symbol: &str,
plugin_path: &str,
function_name: &str,
args: Vec<PyExpr>,
kwargs: Vec<u8>,
is_elementwise: bool,
input_wildcard_expansion: bool,
returns_scalar: bool,
cast_to_supertypes: bool,
cast_to_supertype: bool,
pass_name_to_apply: bool,
changes_length: bool,
) -> PyResult<PyExpr> {
use polars_plan::prelude::*;

let collect_groups = if is_elementwise {
ApplyOptions::ElementWise
} else {
ApplyOptions::GroupWise
};
let mut input = Vec::with_capacity(args.len());
for a in args {
input.push(a.inner)
}

Ok(Expr::Function {
input,
input: args.to_exprs(),
function: FunctionExpr::FfiPlugin {
lib: Arc::from(lib),
symbol: Arc::from(symbol),
lib: Arc::from(plugin_path),
symbol: Arc::from(function_name),
kwargs: Arc::from(kwargs),
},
options: FunctionOptions {
collect_groups,
input_wildcard_expansion,
returns_scalar,
cast_to_supertypes,
cast_to_supertypes: cast_to_supertype,
pass_name_to_apply,
changes_length,
..Default::default()
Expand Down
Loading

0 comments on commit afbea38

Please sign in to comment.