Skip to content

Commit

Permalink
improve nwd and name loading (#893)
Browse files Browse the repository at this point in the history
* improve nwd and name loading

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* bugfix

* fix unit tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove print statements

* bugfix for external nodes

* update required znflow version

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove old code, remove comments

* move function to utils

* access NWD directly for better performance

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* use project attribute to keep track of NWD for fast access

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix `attr_name`

* remove znflow source

* update lock

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
PythonFZ and pre-commit-ci[bot] authored Feb 20, 2025
1 parent 0be558e commit 703d14a
Show file tree
Hide file tree
Showing 7 changed files with 452 additions and 98 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ dependencies = [
"tqdm>=4.67.1",
"typer>=0.15.1",
"znfields>=0.1.2",
"znflow>=0.2.4",
"znflow>=0.2.5",
"znjson>=0.2.6",
]

Expand All @@ -33,6 +33,7 @@ Discord = "https://discord.gg/7ncfwhsnm4"
dev = [
"dvc-s3>=3.2.0",
"h5py>=3.12.1",
"ipykernel>=6.29.5",
"mlflow>=2.20.0",
"pre-commit>=4.1.0",
"pytest>=8.3.4",
Expand Down
3 changes: 3 additions & 0 deletions tests/integration/test_from_rev.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,11 @@ def test_two_nodes_connect_external(proj_path):
remote="https://github.com/PythonFZ/zntrack-examples",
rev="de82dc7104ac3",
)
assert node_a.name == "NumericOuts"
assert node_a.__dict__["nwd"].as_posix() == "nodes/NumericOuts"

with zntrack.Project() as project:
assert node_a.name == "NumericOuts"
node1 = zntrack.examples.AddOne(number=node_a.outs)
node2 = zntrack.examples.AddOne(number=node_a.outs)

Expand Down
345 changes: 341 additions & 4 deletions uv.lock

Large diffs are not rendered by default.

62 changes: 46 additions & 16 deletions zntrack/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@

from zntrack.group import Group
from zntrack.state import NodeStatus
from zntrack.utils.misc import get_plugins_from_env
from zntrack.utils.misc import get_plugins_from_env, nwd_to_name

from .config import NOT_AVAILABLE, ZNTRACK_LAZY_VALUE, NodeStatusEnum
from .config import NOT_AVAILABLE, NWD_PATH, ZNTRACK_LAZY_VALUE, NodeStatusEnum

try:
from typing import dataclass_transform
Expand All @@ -32,6 +32,11 @@

def _name_setter(self, attr_name: str, value: str) -> None:
"""Check if the node name is valid."""
if attr_name in self.__dict__:
raise AttributeError("Node name cannot be changed.")

if value is None:
return

if value is not None and not is_valid_name(value):
raise InvalidStageName
Expand All @@ -41,8 +46,26 @@ def _name_setter(self, attr_name: str, value: str) -> None:
"Node name should not contain '_'."
" This character is used for defining groups."
)
self.__dict__[attr_name] = value # only used to check if the name has been set once

graph = znflow.get_graph()
nwd = NWD_PATH / value # TODO: bad default value, will be wrong in `__post_init__`
if graph is not znflow.empty_graph:
graph.all_nwds.remove(self.__dict__["nwd"]) # remove the current nwd

if graph.active_group is None:
nwd = NWD_PATH / value
else:
nwd = NWD_PATH / "/".join(graph.active_group.names) / value

self.__dict__[attr_name] = value
if nwd in graph.all_nwds:
if graph.active_group is None:
name = value
else:
name = "_".join(graph.active_group.names) + "_" + value
raise ValueError(f"A node with the name '{name}' already exists.")
graph.all_nwds.add(nwd)
self.__dict__["nwd"] = nwd


def _name_getter(self, attr_name: str) -> str:
Expand All @@ -57,20 +80,12 @@ def _name_getter(self, attr_name: str) -> str:
str: The resolved node name.
"""
value = self.__dict__.get(attr_name) # Safer lookup with .get()
graph = znflow.get_graph()

# If value exists and the graph is either empty or not inside a group, return it
if value is not None:
if graph is znflow.empty_graph or graph.active_group is None:
return str(value)

# If no graph is active, return the class name as the default
if graph is znflow.empty_graph:
return str(self.__class__.__name__)

# Compute name based on project-wide node names
return str(graph.compute_all_node_names()[self.uuid])
if self.__dict__.get("nwd") is not None:
# can not use self.nwd in case of `tmp_path`
return nwd_to_name(self.__dict__["nwd"])
else:
return self.__class__.__name__


@dataclass_transform()
Expand Down Expand Up @@ -151,6 +166,21 @@ def from_rev(
lazy_values["name"] = name
lazy_values["always_changed"] = None # TODO: read the state from dvc.yaml
instance = cls(**lazy_values)
if remote is not None or rev is not None:
import dvc.api

with dvc.api.open("zntrack.json", repo=remote, rev=rev) as f:
conf = json.loads(f.read())
nwd = pathlib.Path(conf[name]["nwd"]["value"])
else:
try:
with open("zntrack.json") as f:
conf = json.load(f)
nwd = pathlib.Path(conf[name]["nwd"]["value"])
except FileNotFoundError:
# from_rev is called before a graph is built
nwd = NWD_PATH / name
instance.__dict__["nwd"] = nwd

# TODO: check if the node is finished or not.
instance.__dict__["state"] = NodeStatus(
Expand Down
81 changes: 37 additions & 44 deletions zntrack/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
import logging
import os
import subprocess
import uuid

import tqdm
import yaml
import znflow

from zntrack import utils
from zntrack.config import NWD_PATH
from zntrack.group import Group
from zntrack.state import PLUGIN_LIST
from zntrack.utils.finalize import make_commit
Expand Down Expand Up @@ -50,50 +50,49 @@ def __init__(
deployment=deployment,
**kwargs,
)
self.all_nwds = set()
# Keep track of all nwd paths. They must be unique; this bookkeeping is
# needed until https://github.com/zincware/ZnFlow/issues/132 can be used
# to set nwd directly as pk.

def compute_all_node_names(self) -> dict[uuid.UUID, str]:
"""Compute the Node name based on existing nodes on the graph."""
all_nodes = [self.nodes[uuid]["value"] for uuid in self.nodes]
node_names = {}
for node in all_nodes:
custom_name = node.__dict__.get("name")
node_name = custom_name or node.__class__.__name__
if custom_name is not None and isinstance(custom_name, _FinalNodeNameString):
node_names[node.uuid] = custom_name
continue

if node.state.group is None:
if self.active_group is not None:
node_name = f"{'_'.join(self.active_group.names)}_{node_name}"
else:
node_name = f"{node_name}"
else:
node_name = f"{'_'.join(node.state.group.names)}_{node_name}"

if node_name in node_names.values():
if custom_name:
raise ValueError(
f"A node with the name '{node_name}' already exists."
)
i = 0
while True:
i += 1
if f"{node_name}_{i}" not in node_names.values():
node_name = f"{node_name}_{i}"
break
node_names[node.uuid] = _FinalNodeNameString(node_name)

return node_names

def add_node(self, node_for_adding, **attr):
def add_znflow_node(self, node_for_adding, **attr):
from zntrack import Node

if not isinstance(node_for_adding, Node):
raise ValueError(
f"Node must be an instance of 'zntrack.Node', not {type(node_for_adding)}"
"Node must be an instance of 'zntrack.Node',"
f" not {type(node_for_adding)}."
)
if node_for_adding._external_:
return super().add_znflow_node(node_for_adding)
# here we finalize the node name!
# It can only be updated once more via `MyNode(name=...)`
if self.active_group is None:
nwd = NWD_PATH / node_for_adding.__class__.__name__
else:
nwd = (
NWD_PATH
/ "/".join(self.active_group.names)
/ node_for_adding.__class__.__name__
)
if nwd in self.all_nwds:
postfix = 1
while True:
if self.active_group is None:
nwd = NWD_PATH / f"{node_for_adding.__class__.__name__}_{postfix}"
else:
nwd = (
NWD_PATH
/ "/".join(self.active_group.names)
/ f"{node_for_adding.__class__.__name__}_{postfix}"
)
if nwd not in self.all_nwds:
break
postfix += 1
self.all_nwds.add(nwd)
node_for_adding.__dict__["nwd"] = nwd

return super().add_node(node_for_adding, **attr)
return super().add_znflow_node(node_for_adding)

def __exit__(self, exc_type, exc_val, exc_tb):
try:
Expand All @@ -105,12 +104,6 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self.nodes[node_uuid]["value"].__dict__["state"]["group"] = (
Group.from_znflow_group(group)
)

all_node_names = self.compute_all_node_names()
for node_uuid in self.nodes:
self.nodes[node_uuid]["value"].__dict__["name"] = all_node_names[
node_uuid
]
finally:
super().__exit__(exc_type, exc_val, exc_tb)

Expand Down
16 changes: 15 additions & 1 deletion zntrack/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from zntrack.add import DVCImportPath
from zntrack.utils.import_handler import import_handler

from ..config import ENV_FILE_PATH
from ..config import ENV_FILE_PATH, NWD_PATH


class RunDVCImportPathHandler(znflow.utils.IterableHandler):
Expand Down Expand Up @@ -133,3 +133,17 @@ def sort_and_deduplicate(data: list[str | dict[str, dict]]):
new_data.append(key)

return new_data


def nwd_to_name(nwd: pathlib.Path) -> str:
    """Convert a node working directory into the corresponding node name.

    The leading ``NWD_PATH`` prefix is removed from ``nwd`` and the
    remaining path components are joined with ``"_"``.

    Parameters
    ----------
    nwd : pathlib.Path
        Node working directory located below ``NWD_PATH``.

    Returns
    -------
    str
        The path components below ``NWD_PATH`` joined by underscores.
        An empty string if ``nwd`` is equal to ``NWD_PATH``.

    Raises
    ------
    ValueError
        If ``nwd`` does not start with ``NWD_PATH``.
    """
    try:
        # Strip the common base path; pathlib raises ValueError when
        # ``nwd`` is not located below ``NWD_PATH``.
        rel_path = nwd.relative_to(NWD_PATH)
    except ValueError as err:
        # Re-raise with the explicit message naming both paths.
        raise ValueError(f"{nwd} does not start with {NWD_PATH}") from err

    return "_".join(rel_path.parts)
40 changes: 8 additions & 32 deletions zntrack/utils/node_wd.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
"""Helpers for the Node Working Directory (NWD)."""

import json
import logging
import os
import pathlib
import shutil
import typing as t
import warnings

import znflow.utils
import znjson

from zntrack.add import DVCImportPath
from zntrack.config import NWD_PATH, ZNTRACK_FILE_PATH, NodeStatusEnum
from zntrack.config import NWD_PATH

if t.TYPE_CHECKING:
from zntrack import Node
Expand Down Expand Up @@ -41,42 +40,19 @@ def move_nwd(target: pathlib.Path, destination: pathlib.Path) -> None:
def get_nwd(node: "Node") -> pathlib.Path:
"""Get the node working directory.
This is used instead of `node.nwd` because it allows
for parameters to define if the nwd should be created.
Arguments:
---------
node: Node
The node instance for which the nwd should be returned.
"""
try:
nwd = node.__dict__["nwd"]
return node.__dict__["nwd"]
except KeyError:
if node.name is None:
raise ValueError("Unable to determine node name.")
if (
node.state.remote is None
and node.state.rev is None
and node.state.state == NodeStatusEnum.FINISHED
):
nwd = pathlib.Path(NWD_PATH, node.name)
else:
try:
with node.state.fs.open(ZNTRACK_FILE_PATH) as f:
zntrack_config = json.load(f)
nwd = zntrack_config[node.name]["nwd"]
nwd = json.loads(json.dumps(nwd), cls=znjson.ZnDecoder)
except (FileNotFoundError, KeyError):
nwd = pathlib.Path(NWD_PATH, node.name)

if node.state.group is not None:
# strip the groups from node_name
to_replace = "_".join(node.state.group.names) + "_"
replacement = "/".join(node.state.group.names) + "/"
nwd = pathlib.Path(str(nwd).replace(to_replace, replacement))

return nwd
warnings.warn(
"Using the NWD outside a project context"
" can not guarantee unique directories."
)
return pathlib.Path(NWD_PATH, node.__class__.__name__)


class NWDReplaceHandler(znflow.utils.IterableHandler):
Expand Down

0 comments on commit 703d14a

Please sign in to comment.