Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
samwaseda committed Feb 11, 2025
1 parent 1323218 commit 1763d18
Showing 1 changed file with 67 additions and 54 deletions.
121 changes: 67 additions & 54 deletions pyiron_workflow/nodes/composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from pyiron_workflow.mixin.semantics import SemanticParent
from pyiron_workflow.node import Node
from pyiron_workflow.topology import set_run_connections_according_to_dag
from pyiron_workflow.nodes.export import export_to_dict
from pyiron_workflow.channels import Channel, NOT_DATA

Check failure on line 20 in pyiron_workflow/nodes/composite.py

View workflow job for this annotation

GitHub Actions / ruff-check

Ruff (I001)

pyiron_workflow/nodes/composite.py:6:1: I001 Import block is un-sorted or un-formatted

if TYPE_CHECKING:
from pyiron_workflow.channels import (
Expand Down Expand Up @@ -47,69 +47,82 @@ def _get_scoped_label(channel: Channel, io_: str) -> str:
return channel.scoped_label.replace("__", f".{io_}.")


def export_to_dict(workflow: Composite, with_values: bool = True) -> dict:
def export_node_to_dict(
node: Node, with_values: bool = True
) -> dict:
data = {"inputs": {}, "outputs": {}, "function": node.node_function}
for io_ in ["inputs", "outputs"]:
for inp in getattr(node, io_):
data[io_][inp.label] = _extract_data(inp)
return data


def export_composite_to_dict(
workflow: Composite, with_values: bool = True
) -> dict:
data = {"inputs": {}, "outputs": {}}
if isinstance(workflow, Composite):
data["nodes"] = {}
data["edges"] = []
for inp in workflow.inputs:
if inp.value_receiver is not None:
data["nodes"] = {}
data["edges"] = []
for inp in workflow.inputs:
if inp.value_receiver is not None:
data["edges"].append(
(
f"inputs.{inp.scoped_label}",
_get_scoped_label(inp.value_receiver, "inputs"),
)
)
for node in workflow:
label = node.label
if isinstance(node, Composite):
data["nodes"][label] = export_composite_to_dict(
node, with_values=with_values
)
else:
data["nodes"][label] = export_node_to_dict(
node, with_values=with_values
)
for out in node.outputs:
if _is_internal_connection(out, workflow, "inputs"):
data["edges"].append(
(
f"inputs.{inp.scoped_label}",
_get_scoped_label(inp.value_receiver, "inputs"),
_get_scoped_label(out, "outputs"),
_get_scoped_label(out.connections[0], "inputs"),
)
)
for node in workflow:
label = node.label
data["nodes"][label] = get_universal_dict(node, with_values=with_values)
for out in node.outputs:
if _is_internal_connection(out, workflow, "inputs"):
data["edges"].append(
(
_get_scoped_label(out, "outputs"),
_get_scoped_label(out.connections[0], "inputs"),
)
)
elif out.value_receiver is not None:
data["edges"].append(
(
_get_scoped_label(out, "outputs"),
f"outputs.{out.value_receiver.scoped_label}",
)
elif out.value_receiver is not None:
data["edges"].append(
(
_get_scoped_label(out, "outputs"),
f"outputs.{out.value_receiver.scoped_label}",
)
for io_ in ["inputs", "outputs"]:
for inp in getattr(workflow, io_):
data[io_][inp.scoped_label] = _extract_data(inp)
else:
for io_ in ["inputs", "outputs"]:
for inp in getattr(workflow, io_):
data[io_][inp.label] = _extract_data(inp)
data["function"] = workflow.node_function
)
for io_ in ["inputs", "outputs"]:
for inp in getattr(workflow, io_):
data[io_][inp.scoped_label] = _extract_data(inp)
return data


def _get_graph_as_dict(composite: Composite) -> dict:
if not isinstance(composite, Composite):
return composite
return {
"object": composite,
"nodes": {n.full_label: _get_graph_as_dict(n) for n in composite},
"edges": {
"data": {
(out.full_label, inp.full_label): (out, inp)
for n in composite
for out in n.outputs
for inp in out.connections
},
"signal": {
(out.full_label, inp.full_label): (out, inp)
for n in composite
for out in n.signals.output
for inp in out.connections
if not isinstance(composite, Composite):
return composite
return {
"object": composite,
"nodes": {n.full_label: _get_graph_as_dict(n) for n in composite},
"edges": {
"data": {
(out.full_label, inp.full_label): (out, inp)
for n in composite
for out in n.outputs
for inp in out.connections
},
"signal": {
(out.full_label, inp.full_label): (out, inp)
for n in composite
for out in n.signals.output
for inp in out.connections
},
},
},
}
}


class FailedChildError(RuntimeError):
Expand Down Expand Up @@ -523,7 +536,7 @@ def graph_as_dict(self) -> dict:
return _get_graph_as_dict(self)

def export_to_dict(self, with_values: bool = True) -> dict:
return export_to_dict(self, with_values=with_values)
return export_composite_to_dict(self, with_values=with_values)

def _get_connections_as_strings(
self, panel_getter: Callable
Expand Down

0 comments on commit 1763d18

Please sign in to comment.