Skip to content

Add de-/serializer func to PickleNode attributes. #673

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
May 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/changes.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and
nodes in v0.6.0.
- {pull}`662` adds the `.pixi` folder to be ignored by default during the collection.
- {pull}`671` enhances the documentation on complex repetitions. Closes {issue}`670`.
- {pull}`673` adds de-/serializer function attributes to the `PickleNode`. Closes
{issue}`669`.

## 0.5.2 - 2024-12-19

Expand Down
9 changes: 8 additions & 1 deletion docs/source/how_to_guides/writing_custom_nodes.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ Here are some explanations.
signature is a hash and a unique identifier for the node. For most nodes it will be a
hash of the path or the name.

- The {func}`classmethod` {meth}`~pytask.PickleNode.from_path` is a convenient method to
- The classmethod {meth}`~pytask.PickleNode.from_path` is a convenient method to
instantiate the class.

- The method {meth}`~pytask.PickleNode.state` yields a value that signals the node's
Expand All @@ -129,6 +129,13 @@ Here are some explanations.
- {meth}`~pytask.PickleNode.save` is called when a task function returns and allows to
save the return values.

## Improvements

Usually, you would like your custom node to work with {class}`pathlib.Path` objects and
{class}`upath.UPath` objects allowing to work with remote filesystems. To simplify
getting the state of the node, you can use the {class}`pytask.get_state_of_path`
function.

## Conclusion

Nodes are an important in concept pytask. They allow to pytask to build a DAG and
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(
) -> None:
self.name = name
self.path = path
self.attributes = attributes or {}
self.attributes = attributes if attributes is not None else {}

@property
def signature(self) -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(
) -> None:
self.name = name
self.path = path
self.attributes = attributes or {}
self.attributes = attributes if attributes is not None else {}

@property
def signature(self) -> str:
Expand Down
5 changes: 2 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ test = [
"syrupy",
"aiohttp", # For HTTPPath tests.
"coiled",
"cloudpickle",
]
typing = ["mypy>=1.9.0,<1.11", "nbqa>=1.8.5"]

Expand All @@ -85,9 +86,7 @@ Tracker = "https://github.com/pytask-dev/pytask/issues"
pytask = "pytask:cli"

[tool.uv]
dev-dependencies = [
"tox-uv>=1.7.0", "pygraphviz;platform_system=='Linux'",
]
dev-dependencies = ["tox-uv>=1.7.0", "pygraphviz;platform_system=='Linux'"]

[build-system]
requires = ["hatchling", "hatch_vcs"]
Expand Down
33 changes: 27 additions & 6 deletions src/_pytask/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from attrs import define
from attrs import field
from typing_extensions import deprecated
from upath import UPath
from upath._stat import UPathStatResult

Expand All @@ -28,6 +29,9 @@
from _pytask.typing import no_default

if TYPE_CHECKING:
from io import BufferedReader
from io import BufferedWriter

from _pytask.mark import Mark
from _pytask.models import NodeInfo
from _pytask.tree_util import PyTree
Expand All @@ -40,6 +44,7 @@
"PythonNode",
"Task",
"TaskWithoutPath",
"get_state_of_path",
]


Expand Down Expand Up @@ -145,7 +150,7 @@ def signature(self) -> str:

def state(self) -> str | None:
"""Return the state of the node."""
return _get_state(self.path)
return get_state_of_path(self.path)

def execute(self, **kwargs: Any) -> Any:
"""Execute the task."""
Expand Down Expand Up @@ -188,7 +193,7 @@ def state(self) -> str | None:
The state is given by the modification timestamp.

"""
return _get_state(self.path)
return get_state_of_path(self.path)

def load(self, is_product: bool = False) -> Path: # noqa: ARG002
"""Load the value."""
Expand Down Expand Up @@ -310,12 +315,18 @@ class PickleNode(PPathNode):
The path to the file.
attributes: dict[Any, Any]
A dictionary to store additional information of the task.
serializer
A function to serialize the object. Defaults to :func:`pickle.dump`.
deserializer
A function to deserialize the object. Defaults to :func:`pickle.load`.

"""

path: Path
name: str = ""
attributes: dict[Any, Any] = field(factory=dict)
serializer: Callable[[Any, BufferedWriter], None] = field(default=pickle.dump)
deserializer: Callable[[BufferedReader], Any] = field(default=pickle.load)

@property
def signature(self) -> str:
Expand All @@ -332,17 +343,17 @@ def from_path(cls, path: Path) -> PickleNode:
return cls(name=path.as_posix(), path=path)

def state(self) -> str | None:
return _get_state(self.path)
return get_state_of_path(self.path)

def load(self, is_product: bool = False) -> Any:
if is_product:
return self
with self.path.open("rb") as f:
return pickle.load(f) # noqa: S301
return self.deserializer(f)

def save(self, value: Any) -> None:
with self.path.open("wb") as f:
pickle.dump(value, f)
self.serializer(value, f)


@define(kw_only=True)
Expand Down Expand Up @@ -387,7 +398,7 @@ def collect(self) -> list[Path]:
return list(self.root_dir.glob(self.pattern)) # type: ignore[union-attr]


def _get_state(path: Path) -> str | None:
def get_state_of_path(path: Path) -> str | None:
"""Get state of a path.

A simple function to handle local and remote files.
Expand All @@ -411,3 +422,13 @@ def _get_state(path: Path) -> str | None:
return stat.as_info().get("ETag", "0")
msg = "Unknown stat object."
raise NotImplementedError(msg)


@deprecated("Use 'pytask.get_state_of_path' instead.")
def _get_state(path: Path) -> str | None:
"""Get state of a path.

A simple function to handle local and remote files.

"""
return get_state_of_path(path)
3 changes: 2 additions & 1 deletion src/pytask/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
from _pytask.nodes import PickleNode
from _pytask.nodes import PythonNode
from _pytask.nodes import Task
from _pytask.nodes import TaskWithoutPath
from _pytask.nodes import TaskWithoutPath, get_state_of_path
from _pytask.outcomes import CollectionOutcome
from _pytask.outcomes import Exit
from _pytask.outcomes import ExitCode
Expand Down Expand Up @@ -146,6 +146,7 @@
"get_all_marks",
"get_marks",
"get_plugin_manager",
"get_state_of_path",
"has_mark",
"hash_value",
"hookimpl",
Expand Down
3 changes: 1 addition & 2 deletions tests/test_execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -985,8 +985,7 @@ def test_download_file(runner, tmp_path):
from upath import UPath

url = UPath(
"https://gist.githubusercontent.com/tobiasraabe/64c24426d5398cac4b9d37b85ebfaf"
"7c/raw/50c61fa9a5aa0b7d3a7582c4c260b43dabfea720/gistfile1.txt"
"https://gist.githubusercontent.com/tobiasraabe/64c24426d5398cac4b9d37b85ebfaf7c/raw/50c61fa9a5aa0b7d3a7582c4c260b43dabfea720/gistfile1.txt"
)

def task_download_file(path: UPath = url) -> Annotated[str, Path("data.csv")]:
Expand Down
41 changes: 41 additions & 0 deletions tests/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pickle
from pathlib import Path

import cloudpickle
import pytest

from pytask import NodeInfo
Expand Down Expand Up @@ -126,3 +127,43 @@ def test_hash_of_pickle_node(tmp_path, value, exists, expected):
)
def test_comply_with_protocol(node, protocol, expected):
assert isinstance(node, protocol) is expected


@pytest.mark.unit
def test_custom_serializer_deserializer_pickle_node(tmp_path):
"""Test that PickleNode correctly uses cloudpickle for de-/serialization."""

# Define custom serializer and deserializer using cloudpickle
def custom_serializer(obj, file):
# Custom serialization logic that adds a wrapper around the data
cloudpickle.dump({"custom_prefix": obj}, file)

def custom_deserializer(file):
# Custom deserialization logic that unwraps the data
data = cloudpickle.load(file)
return data["custom_prefix"]

# Create test data and path
test_data = {"key": "value"}
path = tmp_path.joinpath("custom.pkl")

# Create PickleNode with custom serializer and deserializer
node = PickleNode(
name="test",
path=path,
serializer=custom_serializer,
deserializer=custom_deserializer,
)

# Test saving with custom serializer
node.save(test_data)

# Verify custom serialization was used by directly reading the file
with path.open("rb") as f:
raw_data = cloudpickle.load(f)
assert "custom_prefix" in raw_data
assert raw_data["custom_prefix"] == test_data

# Test loading with custom deserializer
loaded_data = node.load(is_product=False)
assert loaded_data == test_data