diff --git a/docs/source/changes.md b/docs/source/changes.md index e92eaa96..a8d93f32 100644 --- a/docs/source/changes.md +++ b/docs/source/changes.md @@ -13,6 +13,8 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and nodes in v0.6.0. - {pull}`662` adds the `.pixi` folder to be ignored by default during the collection. - {pull}`671` enhances the documentation on complex repetitions. Closes {issue}`670`. +- {pull}`673` adds de-/serializer function attributes to the `PickleNode`. Closes + {issue}`669`. ## 0.5.2 - 2024-12-19 diff --git a/docs/source/how_to_guides/writing_custom_nodes.md b/docs/source/how_to_guides/writing_custom_nodes.md index 2204c0c9..68bf0f6a 100644 --- a/docs/source/how_to_guides/writing_custom_nodes.md +++ b/docs/source/how_to_guides/writing_custom_nodes.md @@ -111,7 +111,7 @@ Here are some explanations. signature is a hash and a unique identifier for the node. For most nodes it will be a hash of the path or the name. -- The {func}`classmethod` {meth}`~pytask.PickleNode.from_path` is a convenient method to +- The classmethod {meth}`~pytask.PickleNode.from_path` is a convenient method to instantiate the class. - The method {meth}`~pytask.PickleNode.state` yields a value that signals the node's @@ -129,6 +129,13 @@ Here are some explanations. - {meth}`~pytask.PickleNode.save` is called when a task function returns and allows to save the return values. +## Improvements + +Usually, you would like your custom node to work with {class}`pathlib.Path` objects and +{class}`upath.UPath` objects allowing to work with remote filesystems. To simplify +getting the state of the node, you can use the {class}`pytask.get_state_of_path` +function. + ## Conclusion Nodes are an important in concept pytask. They allow to pytask to build a DAG and diff --git a/docs_src/how_to_guides/writing_custom_nodes_example_3_py310.py b/docs_src/how_to_guides/writing_custom_nodes_example_3_py310.py index e4d00b2e..87e5f6f0 100644 --- a/docs_src/how_to_guides/writing_custom_nodes_example_3_py310.py +++ b/docs_src/how_to_guides/writing_custom_nodes_example_3_py310.py @@ -28,7 +28,7 @@ def __init__( ) -> None: self.name = name self.path = path - self.attributes = attributes or {} + self.attributes = attributes if attributes is not None else {} @property def signature(self) -> str: diff --git a/docs_src/how_to_guides/writing_custom_nodes_example_3_py38.py b/docs_src/how_to_guides/writing_custom_nodes_example_3_py38.py index d6499a64..583307d1 100644 --- a/docs_src/how_to_guides/writing_custom_nodes_example_3_py38.py +++ b/docs_src/how_to_guides/writing_custom_nodes_example_3_py38.py @@ -29,7 +29,7 @@ def __init__( ) -> None: self.name = name self.path = path - self.attributes = attributes or {} + self.attributes = attributes if attributes is not None else {} @property def signature(self) -> str: diff --git a/pyproject.toml b/pyproject.toml index 0a3d83c8..41a3d5fe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,7 @@ test = [ "syrupy", "aiohttp", # For HTTPPath tests. "coiled", + "cloudpickle", ] typing = ["mypy>=1.9.0,<1.11", "nbqa>=1.8.5"] @@ -85,9 +86,7 @@ Tracker = "https://github.com/pytask-dev/pytask/issues" pytask = "pytask:cli" [tool.uv] -dev-dependencies = [ - "tox-uv>=1.7.0", "pygraphviz;platform_system=='Linux'", -] +dev-dependencies = ["tox-uv>=1.7.0", "pygraphviz;platform_system=='Linux'"] [build-system] requires = ["hatchling", "hatch_vcs"] diff --git a/src/_pytask/nodes.py b/src/_pytask/nodes.py index 75b9a0ee..999e1e9d 100644 --- a/src/_pytask/nodes.py +++ b/src/_pytask/nodes.py @@ -14,6 +14,7 @@ from attrs import define from attrs import field +from typing_extensions import deprecated from upath import UPath from upath._stat import UPathStatResult @@ -28,6 +29,9 @@ from _pytask.typing import no_default if TYPE_CHECKING: + from io import BufferedReader + from io import BufferedWriter + from _pytask.mark import Mark from _pytask.models import NodeInfo from _pytask.tree_util import PyTree @@ -40,6 +44,7 @@ "PythonNode", "Task", "TaskWithoutPath", + "get_state_of_path", ] @@ -145,7 +150,7 @@ def signature(self) -> str: def state(self) -> str | None: """Return the state of the node.""" - return _get_state(self.path) + return get_state_of_path(self.path) def execute(self, **kwargs: Any) -> Any: """Execute the task.""" @@ -188,7 +193,7 @@ def state(self) -> str | None: The state is given by the modification timestamp. """ - return _get_state(self.path) + return get_state_of_path(self.path) def load(self, is_product: bool = False) -> Path: # noqa: ARG002 """Load the value.""" @@ -310,12 +315,18 @@ class PickleNode(PPathNode): The path to the file. attributes: dict[Any, Any] A dictionary to store additional information of the task. + serializer + A function to serialize the object. Defaults to :func:`pickle.dump`. + deserializer + A function to deserialize the object. Defaults to :func:`pickle.load`. """ path: Path name: str = "" attributes: dict[Any, Any] = field(factory=dict) + serializer: Callable[[Any, BufferedWriter], None] = field(default=pickle.dump) + deserializer: Callable[[BufferedReader], Any] = field(default=pickle.load) @property def signature(self) -> str: @@ -332,17 +343,17 @@ def from_path(cls, path: Path) -> PickleNode: return cls(name=path.as_posix(), path=path) def state(self) -> str | None: - return _get_state(self.path) + return get_state_of_path(self.path) def load(self, is_product: bool = False) -> Any: if is_product: return self with self.path.open("rb") as f: - return pickle.load(f) # noqa: S301 + return self.deserializer(f) def save(self, value: Any) -> None: with self.path.open("wb") as f: - pickle.dump(value, f) + self.serializer(value, f) @define(kw_only=True) @@ -387,7 +398,7 @@ def collect(self) -> list[Path]: return list(self.root_dir.glob(self.pattern)) # type: ignore[union-attr] -def _get_state(path: Path) -> str | None: +def get_state_of_path(path: Path) -> str | None: """Get state of a path. A simple function to handle local and remote files. @@ -411,3 +422,13 @@ def _get_state(path: Path) -> str | None: return stat.as_info().get("ETag", "0") msg = "Unknown stat object." raise NotImplementedError(msg) + + +@deprecated("Use 'pytask.get_state_of_path' instead.") +def _get_state(path: Path) -> str | None: + """Get state of a path. + + A simple function to handle local and remote files. + + """ + return get_state_of_path(path) diff --git a/src/pytask/__init__.py b/src/pytask/__init__.py index 4caec161..a71d6646 100644 --- a/src/pytask/__init__.py +++ b/src/pytask/__init__.py @@ -53,7 +53,7 @@ from _pytask.nodes import PickleNode from _pytask.nodes import PythonNode from _pytask.nodes import Task -from _pytask.nodes import TaskWithoutPath +from _pytask.nodes import TaskWithoutPath, get_state_of_path from _pytask.outcomes import CollectionOutcome from _pytask.outcomes import Exit from _pytask.outcomes import ExitCode @@ -146,6 +146,7 @@ "get_all_marks", "get_marks", "get_plugin_manager", + "get_state_of_path", "has_mark", "hash_value", "hookimpl", diff --git a/tests/test_execute.py b/tests/test_execute.py index c8cf4cc5..be3bd73e 100644 --- a/tests/test_execute.py +++ b/tests/test_execute.py @@ -985,8 +985,7 @@ def test_download_file(runner, tmp_path): from upath import UPath url = UPath( - "https://gist.githubusercontent.com/tobiasraabe/64c24426d5398cac4b9d37b85ebfaf" - "7c/raw/50c61fa9a5aa0b7d3a7582c4c260b43dabfea720/gistfile1.txt" + "https://gist.githubusercontent.com/tobiasraabe/64c24426d5398cac4b9d37b85ebfaf7c/raw/50c61fa9a5aa0b7d3a7582c4c260b43dabfea720/gistfile1.txt" ) def task_download_file(path: UPath = url) -> Annotated[str, Path("data.csv")]: diff --git a/tests/test_nodes.py b/tests/test_nodes.py index aa20a4e8..4a75b2c8 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -3,6 +3,7 @@ import pickle from pathlib import Path +import cloudpickle import pytest from pytask import NodeInfo @@ -126,3 +127,43 @@ def test_hash_of_pickle_node(tmp_path, value, exists, expected): ) def test_comply_with_protocol(node, protocol, expected): assert isinstance(node, protocol) is expected + + +@pytest.mark.unit +def test_custom_serializer_deserializer_pickle_node(tmp_path): + """Test that PickleNode correctly uses cloudpickle for de-/serialization.""" + + # Define custom serializer and deserializer using cloudpickle + def custom_serializer(obj, file): + # Custom serialization logic that adds a wrapper around the data + cloudpickle.dump({"custom_prefix": obj}, file) + + def custom_deserializer(file): + # Custom deserialization logic that unwraps the data + data = cloudpickle.load(file) + return data["custom_prefix"] + + # Create test data and path + test_data = {"key": "value"} + path = tmp_path.joinpath("custom.pkl") + + # Create PickleNode with custom serializer and deserializer + node = PickleNode( + name="test", + path=path, + serializer=custom_serializer, + deserializer=custom_deserializer, + ) + + # Test saving with custom serializer + node.save(test_data) + + # Verify custom serialization was used by directly reading the file + with path.open("rb") as f: + raw_data = cloudpickle.load(f) + assert "custom_prefix" in raw_data + assert raw_data["custom_prefix"] == test_data + + # Test loading with custom deserializer + loaded_data = node.load(is_product=False) + assert loaded_data == test_data