diff --git a/tests/files/test_dependencies.py b/tests/files/test_dependencies.py index bd09d68c..ac4ee56d 100644 --- a/tests/files/test_dependencies.py +++ b/tests/files/test_dependencies.py @@ -79,6 +79,7 @@ def test_deps(proj_path): _ = NodeC(deps=a.metrics) _ = NodeC(deps=a.plots) _ = NodeC(deps=a.outs_path) + # TODO: do we want to allow `x_path` as a `deps` or should it go into `deps_path`? _ = NodeC(deps=a.metrics_paths) _ = NodeC(deps=a.plots_path) diff --git a/tests/integration/test_outs_path_deps_path.py b/tests/integration/test_outs_path_deps_path.py new file mode 100644 index 00000000..01c6a7b5 --- /dev/null +++ b/tests/integration/test_outs_path_deps_path.py @@ -0,0 +1,18 @@ +from pathlib import Path + +import zntrack.examples + + +def test_outs_path_to_deps_path(proj_path): + with zntrack.Project() as proj: + a = zntrack.examples.WriteDVCOuts(params=10) + # assert a.outs == Path("nodes/WriteDVCOuts/output.txt") # uses znflow.resolve + b = zntrack.examples.ReadFile(path=a.outs) + # b = zntrack.examples.ReadFile(path=znflow.resolve(a.outs)) + # b = zntrack.examples.ReadFile(path=Path("nodes/WriteDVCOuts/output.txt")) # works + + proj.repro() + + assert a.outs == Path("nodes/WriteDVCOuts/output.txt") + assert b.path == Path("nodes/WriteDVCOuts/output.txt") + assert b.content == "10" diff --git a/zntrack/fields/x_path.py b/zntrack/fields/x_path.py index 97a73cc0..eac1f5a2 100644 --- a/zntrack/fields/x_path.py +++ b/zntrack/fields/x_path.py @@ -23,7 +23,26 @@ from zntrack.utils.node_wd import NWDReplaceHandler -def _paths_getter(self: Node, name: str): +def _paths_getter_input(self: Node, name: str): + """Resolve the paths for data the Node consumes.""" + if name in self.__dict__ and self.__dict__[name] is not ZNTRACK_LAZY_VALUE: + return self.__dict__[name] + try: + with self.state.fs.open(ZNTRACK_FILE_PATH) as f: + content = json.load(f)[self.name][name] + content = znjson.loads(json.dumps(content)) + + if self.state.tmp_path is not None: + loader = 
TempPathLoader() + loader(content, instance=self) + + return content + except FileNotFoundError: + return NOT_AVAILABLE + + +def _paths_getter_output(self: Node, name: str): + """Resolve the paths for data the Node produces.""" # TODO: if self._external_: try looking into # external/self.uuid/... # this works for everything except node-meta.json because that @@ -59,15 +78,14 @@ def outs_path( kwargs["metadata"][ZNTRACK_OPTION] = ZnTrackOptionEnum.OUTS_PATH kwargs["metadata"][ZNTRACK_CACHE] = cache kwargs["metadata"][ZNTRACK_INDEPENDENT_OUTPUT_TYPE] = independent - kwargs["metadata"][ZNTRACK_FIELD_LOAD] = _paths_getter + kwargs["metadata"][ZNTRACK_FIELD_LOAD] = _paths_getter_output return znfields.field(default=default, getter=plugin_getter, **kwargs) -def params_path(default=dataclasses.MISSING, *, cache: bool = True, **kwargs): +def params_path(default=dataclasses.MISSING, **kwargs): kwargs["metadata"] = kwargs.get("metadata", {}) kwargs["metadata"][ZNTRACK_OPTION] = ZnTrackOptionEnum.PARAMS_PATH - kwargs["metadata"][ZNTRACK_CACHE] = cache - kwargs["metadata"][ZNTRACK_FIELD_LOAD] = _paths_getter + kwargs["metadata"][ZNTRACK_FIELD_LOAD] = _paths_getter_input return znfields.field(default=default, getter=plugin_getter, **kwargs) @@ -82,7 +100,7 @@ def plots_path( kwargs["metadata"][ZNTRACK_OPTION] = ZnTrackOptionEnum.PLOTS_PATH kwargs["metadata"][ZNTRACK_CACHE] = cache kwargs["metadata"][ZNTRACK_INDEPENDENT_OUTPUT_TYPE] = independent - kwargs["metadata"][ZNTRACK_FIELD_LOAD] = _paths_getter + kwargs["metadata"][ZNTRACK_FIELD_LOAD] = _paths_getter_output return znfields.field(default=default, getter=plugin_getter, **kwargs) @@ -99,13 +117,12 @@ def metrics_path( kwargs["metadata"][ZNTRACK_OPTION] = ZnTrackOptionEnum.METRICS_PATH kwargs["metadata"][ZNTRACK_CACHE] = cache kwargs["metadata"][ZNTRACK_INDEPENDENT_OUTPUT_TYPE] = independent - kwargs["metadata"][ZNTRACK_FIELD_LOAD] = _paths_getter + kwargs["metadata"][ZNTRACK_FIELD_LOAD] = _paths_getter_output return 
znfields.field(default=default, getter=plugin_getter, **kwargs) -def deps_path(default=dataclasses.MISSING, *, cache: bool = True, **kwargs): +def deps_path(default=dataclasses.MISSING, **kwargs): kwargs["metadata"] = kwargs.get("metadata", {}) kwargs["metadata"][ZNTRACK_OPTION] = ZnTrackOptionEnum.DEPS_PATH - kwargs["metadata"][ZNTRACK_CACHE] = cache - kwargs["metadata"][ZNTRACK_FIELD_LOAD] = _paths_getter + kwargs["metadata"][ZNTRACK_FIELD_LOAD] = _paths_getter_input return znfields.field(default=default, getter=plugin_getter, **kwargs) diff --git a/zntrack/node.py b/zntrack/node.py index a120d840..ed1e8487 100644 --- a/zntrack/node.py +++ b/zntrack/node.py @@ -16,7 +16,13 @@ from zntrack.state import NodeStatus from zntrack.utils.misc import get_plugins_from_env -from .config import NOT_AVAILABLE, ZNTRACK_LAZY_VALUE, NodeStatusEnum +from .config import ( + NOT_AVAILABLE, + ZNTRACK_LAZY_VALUE, + ZNTRACK_OPTION, + NodeStatusEnum, + ZnTrackOptionEnum, +) try: from typing import dataclass_transform @@ -78,6 +84,17 @@ def __post_init__(self): log.warning( "Node name should not contain '_'. This character is used for defining groups." ) + for field in dataclasses.fields(self): + # `x_path` fields should be resolved instead of being passed + # as a connection. Their values are known at runtime. + if field.metadata.get(ZNTRACK_OPTION, None) in [ + ZnTrackOptionEnum.PARAMS_PATH, + ZnTrackOptionEnum.DEPS_PATH, + ZnTrackOptionEnum.OUTS_PATH, + ZnTrackOptionEnum.PLOTS_PATH, + ZnTrackOptionEnum.METRICS_PATH, + ]: + self._protected_.append(field.name) def _post_load_(self): """Called after `from_rev` is called."""