Skip to content

Commit

Permalink
Merge pull request #187 from pyiron/semantic_storage_path
Browse files Browse the repository at this point in the history
Semantic storage path
  • Loading branch information
liamhuber authored Jan 31, 2024
2 parents 6af0a7a + 9e389cb commit 4df004b
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 28 deletions.
12 changes: 7 additions & 5 deletions pyiron_workflow/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,11 +828,13 @@ def _takes_zero_arguments(self, callback):

@staticmethod
def _no_positional_args(func):
return all([
parameter.default != inspect.Parameter.empty
or parameter.kind == inspect.Parameter.VAR_KEYWORD
for parameter in inspect.signature(func).parameters.values()
])
return all(
[
parameter.default != inspect.Parameter.empty
or parameter.kind == inspect.Parameter.VAR_KEYWORD
for parameter in inspect.signature(func).parameters.values()
]
)

@property
def callback(self) -> callable:
Expand Down
28 changes: 27 additions & 1 deletion pyiron_workflow/composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ class Composite(Node, ABC):
- Have no other parent
- Can be replaced in-place with another node that has commensurate IO
- Have their working directory nested inside the composite's
- Are disallowed from having a label that conflicts with any of the parent's
other methods or attributes
- The length of a composite instance is its number of child nodes
- Running the composite...
- Runs the child nodes (either using manually specified execution signals, or
Expand Down Expand Up @@ -635,30 +637,54 @@ def _child_signal_connections(
) -> list[tuple[tuple[str, str], tuple[str, str]]]:
return self._get_connections_as_strings(self._get_signals_input)

@property
def node_labels(self) -> tuple[str, ...]:
    """tuple[str, ...]: The labels of all child nodes, in iteration order.

    Materialized as a tuple (the previous generator expression contradicted
    the annotation, was single-use, and could not be pickled when stored in
    ``__getstate__``).
    """
    return tuple(n.label for n in self)

def __getstate__(self):
    """Return a serialization-friendly state dict.

    Connections are flattened to label-string tuples, the bidict IO maps are
    converted to plain dicts (h5io cannot handle bidict's custom
    reconstructor), and each child node is stored directly under its own
    label while the ``nodes`` container itself is removed -- this keeps the
    storage path aligned with the graph path.

    Returns:
        dict: The state of this composite, ready for storage.
    """
    state = super().__getstate__()
    # Store connections as strings
    state["_child_data_connections"] = self._child_data_connections
    state["_child_signal_connections"] = self._child_signal_connections

    # Transform the IO maps into a datatype that plays well with h5io
    # (Bidict implements a custom reconstructor, which hurts us)
    state["_inputs_map"] = (
        None if self._inputs_map is None else dict(self._inputs_map)
    )
    state["_outputs_map"] = (
        None if self._outputs_map is None else dict(self._outputs_map)
    )

    # Remove the nodes container from the state and store each element (node)
    # right in the state -- the labels are guaranteed to not be attributes
    # already (children are forbidden from having labels that clash with the
    # parent's __dir__), so this is safe, and it makes sure that the storage
    # path matches the graph path
    del state["nodes"]
    # Materialize to a tuple: a bare generator would be single-use and is not
    # serializable
    state["node_labels"] = tuple(self.node_labels)
    for node in self:
        state[node.label] = node
    return state

def __setstate__(self, state):
    """Restore a composite from the flattened state produced by ``__getstate__``.

    Inverts the storage transformations: IO maps go back to bidict, and the
    ``nodes`` container is rebuilt from the per-label child entries.
    """
    # Purge child connection info from the state
    # NOTE(review): the popped values are presumably re-applied further down
    # in this method (continuation not visible here) -- confirm
    child_data_connections = state.pop("_child_data_connections")
    child_signal_connections = state.pop("_child_signal_connections")

    # Transform the IO maps back into the right class (bidict)
    state["_inputs_map"] = (
        None if state["_inputs_map"] is None else bidict(state["_inputs_map"])
    )
    state["_outputs_map"] = (
        None if state["_outputs_map"] is None else bidict(state["_outputs_map"])
    )

    # Reconstruct nodes from state: children were stored directly under their
    # own labels, keyed by the stored "node_labels" collection
    state["nodes"] = DotDict(
        {label: state[label] for label in state.pop("node_labels")}
    )

    super().__setstate__(state)

# Nodes purge their _parent information in their __getstate__
Expand Down
45 changes: 23 additions & 22 deletions tests/unit/test_workflow.py
Original file line number Diff line number Diff line change
def test_storage_values(self):
    """Round-trip a workflow through each storage backend and check that both
    workflow-level and deeply nested child data are reloaded intact.

    Storage is deleted in a ``finally`` block so a failing save/load cycle
    cannot leave stale files behind for other tests.
    """
    for storage_backend in ["h5io", "tinybase"]:
        with self.subTest(storage_backend):
            wf = Workflow("wf")
            try:
                wf.register("static.demo_nodes", domain="demo")
                wf.inp = wf.create.demo.AddThree(x=0)
                wf.out = wf.inp.outputs.add_three + 1
                wf_out = wf()
                three_result = wf.inp.three.outputs.add.value

                wf.save(backend=storage_backend)

                reloaded = Workflow("wf", storage_backend=storage_backend)
                self.assertEqual(
                    wf_out.out__add,
                    reloaded.outputs.out__add.value,
                    msg="Workflow-level data should get reloaded"
                )
                self.assertEqual(
                    three_result,
                    reloaded.inp.three.value,
                    msg="Child data arbitrarily deep should get reloaded"
                )
            finally:
                # Clean up after ourselves
                wf.storage.delete()

def test_storage_scopes(self):
wf = Workflow("wf")
Expand Down

0 comments on commit 4df004b

Please sign in to comment.