Skip to content

Commit

Permalink
Minor refactors
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego committed Mar 14, 2024
1 parent b9e2f82 commit 14d4d7a
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 39 deletions.
38 changes: 20 additions & 18 deletions py-polars/polars/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

@unstable()
def register_plugin(
plugin_location: str | Path,
plugin_location: Path | str,
function_name: str,
inputs: IntoExpr | Iterable[IntoExpr],
kwargs: dict[str, Any] | None = None,
Expand All @@ -43,8 +43,8 @@ def register_plugin(
This is highly unsafe as this will call the C function
loaded by `lib::symbol`.
The parameters you give dictate how polars will deal
with the function. Make sure they are correct!
The parameters you set dictate how Polars will deal with the function.
Make sure they are correct!
See the `user guide <https://docs.pola.rs/user-guide/expressions/plugins/>`_
for more information about plugins.
Expand Down Expand Up @@ -98,7 +98,7 @@ def register_plugin(

return wrap_expr(
plr.register_plugin(
lib_location,
str(lib_location),
function_name,
pyexprs,
serialized_kwargs,
Expand All @@ -112,19 +112,21 @@ def register_plugin(
)


def _get_dynamic_lib_location(plugin_location: str | Path) -> str:
"""Get location of dynamic library file."""
if Path(plugin_location).is_file():
return str(plugin_location)
if not Path(plugin_location).is_dir():
msg = f"expected file or directory, got {plugin_location!r}"
raise TypeError(msg)
for path in Path(plugin_location).iterdir():
if _is_shared_lib(path):
return str(path)
msg = f"no dynamic library found in {plugin_location}"
raise FileNotFoundError(msg)
def _get_dynamic_lib_location(path: Path | str) -> Path:
"""Get the file path of the dynamic library file."""
if not isinstance(path, Path):
path = Path(path)

if path.is_file():
return path

for p in path.iterdir():
if _is_dynamic_lib(p):
return p
else:
msg = f"no dynamic library found at path: {path}"
raise FileNotFoundError(msg)


def _is_shared_lib(file: Path) -> bool:
return file.name.endswith((".so", ".dll", ".pyd"))
def _is_dynamic_lib(path: Path) -> bool:
return path.is_file() and path.suffix in (".so", ".dll", ".pyd")
18 changes: 12 additions & 6 deletions py-polars/polars/utils/udfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,19 @@


def _get_shared_lib_location(main_file: Any) -> str:
"""
Get the location of the dynamic library file.
.. deprecated:: 0.20.16
Use :func:`polars.plugins.register_plugin` instead.
"""
issue_deprecation_warning(
"_get_shared_lib_location is deprecated and will be removed in a future "
"version. Please use `from polars.plugins import register_plugin` instead. "
"Note that its interface has changed - check the user guide "
"(https://docs.pola.rs/user-guide/expressions/plugins) "
"for the currently-recommended way to register a plugin.",
version="0.20.15",
"`_get_shared_lib_location` is deprecated and will be removed in the next breaking release."
" The new `register_plugin` function has this functionality built in."
" Use `from polars.plugins import register_plugin` to import that function."
" Check the user guide for the currently-recommended way to register a plugin:"
" https://docs.pola.rs/user-guide/expressions/plugins",
version="0.20.16",
)
directory = os.path.dirname(main_file) # noqa: PTH120
return os.path.join( # noqa: PTH118
Expand Down
67 changes: 52 additions & 15 deletions py-polars/tests/unit/test_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,60 @@

import pytest

from polars.plugins import _get_dynamic_lib_location
from polars.plugins import _get_dynamic_lib_location, _is_dynamic_lib


def test_get_dynamic_lib_location(tmpdir: Path) -> None:
(tmpdir / "lib1.so").write_text("", encoding="utf-8")
(tmpdir / "__init__.py").write_text("", encoding="utf-8")
result = _get_dynamic_lib_location(tmpdir)
assert result == str(tmpdir / "lib1.so")
result = _get_dynamic_lib_location(tmpdir / "lib1.so")
assert result == str(tmpdir / "lib1.so")
result = _get_dynamic_lib_location(str(tmpdir))
assert result == str(tmpdir / "lib1.so")
result = _get_dynamic_lib_location(str(tmpdir / "lib1.so"))
assert result == str(tmpdir / "lib1.so")
@pytest.mark.write_disk()
def test_get_dynamic_lib_location(tmp_path: Path) -> None:
tmp_path.mkdir(exist_ok=True)

(tmp_path / "lib1.so").touch()
(tmp_path / "__init__.py").touch()

expected = tmp_path / "lib1.so"

result = _get_dynamic_lib_location(tmp_path)
assert result == expected
result = _get_dynamic_lib_location(tmp_path / "lib1.so")
assert result == expected
result = _get_dynamic_lib_location(str(tmp_path))
assert result == expected
result = _get_dynamic_lib_location(str(tmp_path / "lib1.so"))
assert result == expected


@pytest.mark.write_disk()
def test_get_dynamic_lib_location_raises(tmp_path: Path) -> None:
tmp_path.mkdir(exist_ok=True)
(tmp_path / "__init__.py").touch()

def test_get_dynamic_lib_location_raises(tmpdir: Path) -> None:
(tmpdir / "__init__.py").write_text("", encoding="utf-8")
with pytest.raises(FileNotFoundError, match="no dynamic library found"):
_get_dynamic_lib_location(tmpdir)
_get_dynamic_lib_location(tmp_path)


@pytest.mark.write_disk()
@pytest.mark.parametrize(
("path", "expected"),
[
(Path("lib.so"), True),
(Path("lib.pyd"), True),
(Path("lib.dll"), True),
(Path("lib.py"), False),
],
)
def test_is_dynamic_lib(path: Path, expected: bool, tmp_path: Path) -> None:
tmp_path.mkdir(exist_ok=True)
full_path = tmp_path / path
full_path.touch()
assert _is_dynamic_lib(full_path) is expected


@pytest.mark.write_disk()
def test_is_dynamic_lib_dir(tmp_path: Path) -> None:
path = Path("lib.so")
full_path = tmp_path / path

full_path.mkdir(exist_ok=True)
(full_path / "hello.txt").touch()

assert _is_dynamic_lib(full_path) is False

0 comments on commit 14d4d7a

Please sign in to comment.