Skip to content

Commit 17accc9

Browse files
committed
Add de-/serializer func to PickleNode attributes.
1 parent 59eadcf commit 17accc9

File tree

2 files changed

+25
-7
lines changed

2 files changed

+25
-7
lines changed

src/_pytask/nodes.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from attrs import define
1616
from attrs import field
17+
from typing_extensions import deprecated
1718
from upath import UPath
1819
from upath._stat import UPathStatResult
1920

@@ -28,6 +29,9 @@
2829
from _pytask.typing import no_default
2930

3031
if TYPE_CHECKING:
32+
from io import BufferedReader
33+
from io import BufferedWriter
34+
3135
from _pytask.mark import Mark
3236
from _pytask.models import NodeInfo
3337
from _pytask.tree_util import PyTree
@@ -40,6 +44,7 @@
4044
"PythonNode",
4145
"Task",
4246
"TaskWithoutPath",
47+
"get_state_of_path",
4348
]
4449

4550

@@ -145,7 +150,7 @@ def signature(self) -> str:
145150

146151
def state(self) -> str | None:
147152
"""Return the state of the node."""
148-
return _get_state(self.path)
153+
return get_state_of_path(self.path)
149154

150155
def execute(self, **kwargs: Any) -> Any:
151156
"""Execute the task."""
@@ -188,7 +193,7 @@ def state(self) -> str | None:
188193
The state is given by the modification timestamp.
189194
190195
"""
191-
return _get_state(self.path)
196+
return get_state_of_path(self.path)
192197

193198
def load(self, is_product: bool = False) -> Path: # noqa: ARG002
194199
"""Load the value."""
@@ -316,6 +321,8 @@ class PickleNode(PPathNode):
316321
path: Path
317322
name: str = ""
318323
attributes: dict[Any, Any] = field(factory=dict)
324+
serializer: Callable[[Any, BufferedWriter], None] = field(default=pickle.dump)
325+
deserializer: Callable[[BufferedReader], Any] = field(default=pickle.load)
319326

320327
@property
321328
def signature(self) -> str:
@@ -332,17 +339,17 @@ def from_path(cls, path: Path) -> PickleNode:
332339
return cls(name=path.as_posix(), path=path)
333340

334341
def state(self) -> str | None:
335-
return _get_state(self.path)
342+
return get_state_of_path(self.path)
336343

337344
def load(self, is_product: bool = False) -> Any:
338345
if is_product:
339346
return self
340347
with self.path.open("rb") as f:
341-
return pickle.load(f) # noqa: S301
348+
return self.deserializer(f)
342349

343350
def save(self, value: Any) -> None:
344351
with self.path.open("wb") as f:
345-
pickle.dump(value, f)
352+
self.serializer(value, f)
346353

347354

348355
@define(kw_only=True)
@@ -387,7 +394,7 @@ def collect(self) -> list[Path]:
387394
return list(self.root_dir.glob(self.pattern)) # type: ignore[union-attr]
388395

389396

390-
def _get_state(path: Path) -> str | None:
397+
def get_state_of_path(path: Path) -> str | None:
391398
"""Get state of a path.
392399
393400
A simple function to handle local and remote files.
@@ -411,3 +418,13 @@ def _get_state(path: Path) -> str | None:
411418
return stat.as_info().get("ETag", "0")
412419
msg = "Unknown stat object."
413420
raise NotImplementedError(msg)
421+
422+
423+
@deprecated("Use 'pytask.get_state_of_path' instead.")
424+
def _get_state(path: Path) -> str | None:
425+
"""Get state of a path.
426+
427+
A simple function to handle local and remote files.
428+
429+
"""
430+
return get_state_of_path(path)

src/pytask/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
from _pytask.nodes import PickleNode
5454
from _pytask.nodes import PythonNode
5555
from _pytask.nodes import Task
56-
from _pytask.nodes import TaskWithoutPath
56+
from _pytask.nodes import TaskWithoutPath, get_state_of_path
5757
from _pytask.outcomes import CollectionOutcome
5858
from _pytask.outcomes import Exit
5959
from _pytask.outcomes import ExitCode
@@ -146,6 +146,7 @@
146146
"get_all_marks",
147147
"get_marks",
148148
"get_plugin_manager",
149+
"get_state_of_path",
149150
"has_mark",
150151
"hash_value",
151152
"hookimpl",

0 commit comments

Comments
 (0)