Skip to content

Commit

Permalink
add test for get_dynamic_lib_location
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Mar 5, 2024
1 parent 74966bf commit fbbe41b
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 23 deletions.
19 changes: 18 additions & 1 deletion py-polars/polars/_utils/udfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -978,4 +978,21 @@ def warn_on_inefficient_map(
)


__all__ = ["BytecodeParser", "warn_on_inefficient_map"]
def get_dynamic_lib_location(package_init_path: str | Path) -> str:
"""Get location of dynamic library file."""
if Path(package_init_path).is_file():
package_dir = Path(package_init_path).parent
else:
package_dir = Path(package_init_path)
for path in package_dir.iterdir():
if _is_shared_lib(path):
return str(path)
msg = f"no dynamic library found in {package_dir}"
raise FileNotFoundError(msg)


def _is_shared_lib(file: Path) -> bool:
return file.name.endswith((".so", ".dll", ".pyd"))


__all__ = ["BytecodeParser", "warn_on_inefficient_map", "get_dynamic_lib_location"]
2 changes: 1 addition & 1 deletion py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -9638,7 +9638,7 @@ def register_plugin(
changes_length: bool = False,
) -> Self:
"""
Register a shared library as a plugin.
Register a dynamic library as a plugin.
.. warning::
This is highly unsafe as this will call the C function
Expand Down
25 changes: 5 additions & 20 deletions py-polars/polars/plugins.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,21 @@
from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Any

from polars._utils.parse_expr_input import parse_as_list_of_expressions
from polars._utils.udfs import get_dynamic_lib_location
from polars._utils.unstable import unstable
from polars._utils.wrap import wrap_expr

if TYPE_CHECKING:
from pathlib import Path

from polars import Expr
from polars.type_aliases import IntoExpr

__all__ = ["register_plugin"]


def _get_shared_lib_location(package_init_path: str | Path) -> str:
"""Get location of dynamic library file."""
if Path(package_init_path).is_file():
package_dir = Path(package_init_path).parent
else:
package_dir = Path(package_init_path)
for path in package_dir.iterdir():
if _is_shared_lib(path):
return str(path)
msg = f"No shared library found in {package_dir}"
raise FileNotFoundError(msg)


def _is_shared_lib(file: Path) -> bool:
return file.name.endswith((".so", ".dll", ".pyd"))


@unstable()
def register_plugin(
*args: IntoExpr | list[IntoExpr],
Expand All @@ -45,7 +30,7 @@ def register_plugin(
changes_length: bool = False,
) -> Expr:
"""
Register a shared library as a plugin.
Register a dynamic library as a plugin.
.. warning::
This is highly unsafe as this will call the C function
Expand Down Expand Up @@ -109,7 +94,7 @@ def register_plugin(
# Choose the highest protocol supported by https://docs.rs/serde-pickle/latest/serde_pickle/
serialized_kwargs = pickle.dumps(kwargs, protocol=5)

lib_location = _get_shared_lib_location(plugin_location)
lib_location = get_dynamic_lib_location(plugin_location)

return wrap_expr(
pyexprs[0].register_plugin(
Expand Down
5 changes: 4 additions & 1 deletion py-polars/polars/utils/udfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@
def _get_shared_lib_location(main_file: Any) -> str:
issue_deprecation_warning(
"polars.utils will be made private in a future release. Please use "
"`from polars.plugins import get_shared_lib_location` instead.",
"`from polars.plugins import register_plugin` instead. "
"Note that its interface has changed - check the user guide "
"(https://docs.pola.rs/user-guide/expressions/plugins) "
"for the currently-recommended way to register a plugin.",
version="0.20.14",
)
directory = os.path.dirname(main_file) # noqa: PTH120
Expand Down
24 changes: 24 additions & 0 deletions py-polars/tests/unit/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
time_to_int,
timedelta_to_int,
)
from polars._utils.udfs import get_dynamic_lib_location
from polars._utils.various import (
_in_notebook,
is_bool_sequence,
Expand All @@ -26,6 +27,8 @@
from polars.io._utils import _looks_like_url

if TYPE_CHECKING:
from pathlib import Path

from zoneinfo import ZoneInfo

from polars.type_aliases import TimeUnit
Expand Down Expand Up @@ -310,3 +313,24 @@ def test_is_str_sequence_check(
)
def test_looks_like_url(url: str, result: bool) -> None:
assert _looks_like_url(url) == result


def test_get_dynamic_lib_location(tmpdir: Path) -> None:
(tmpdir / "lib1.so").write_text("", encoding="utf-8")
(tmpdir / "__init__.py").write_text("", encoding="utf-8")
result = get_dynamic_lib_location(tmpdir)
assert result == str(tmpdir / "lib1.so")
result = get_dynamic_lib_location(tmpdir / "__init__.py")
assert result == str(tmpdir / "lib1.so")
result = get_dynamic_lib_location(str(tmpdir))
assert result == str(tmpdir / "lib1.so")
result = get_dynamic_lib_location(str(tmpdir / "__init__.py"))
assert result == str(tmpdir / "lib1.so")


def test_get_dynamic_lib_location_raises(tmpdir: Path) -> None:
(tmpdir / "__init__.py").write_text("", encoding="utf-8")
with pytest.raises(FileNotFoundError, match="no dynamic library found"):
get_dynamic_lib_location(tmpdir)
with pytest.raises(FileNotFoundError, match="no dynamic library found"):
get_dynamic_lib_location(tmpdir / "__init__.py")

0 comments on commit fbbe41b

Please sign in to comment.