Skip to content

Commit

Permalink
Merge pull request #568 from pyiron/graph_creator_as_method
Browse files Browse the repository at this point in the history
Make Macro.graph_creator a normal method
  • Loading branch information
XzzX authored Jan 21, 2025
2 parents ddc4e76 + b77210d commit 84487b3
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 8 deletions.
12 changes: 6 additions & 6 deletions pyiron_workflow/nodes/macro.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from pyiron_snippets.factory import classfactory

from pyiron_workflow.compatibility import Self
from pyiron_workflow.io import Inputs
from pyiron_workflow.mixin.has_interface_mixins import HasChannel
from pyiron_workflow.mixin.injection import OutputsWithInjection
Expand Down Expand Up @@ -196,7 +197,6 @@ class Macro(Composite, StaticNode, ScrapesIO, ABC):
>>> class AddThreeMacro(Macro):
... _output_labels = ["three"]
...
... @staticmethod
... def graph_creator(self, x):
... add_three_macro(self, one__x=x)
... return self.three
Expand Down Expand Up @@ -252,7 +252,7 @@ def _setup_node(self) -> None:
super()._setup_node()

ui_nodes = self._prepopulate_ui_nodes_from_graph_creator_signature()
returned_has_channel_objects = self.graph_creator(self, *ui_nodes)
returned_has_channel_objects = self.graph_creator(*ui_nodes)
if returned_has_channel_objects is None:
returned_has_channel_objects = ()
elif isinstance(returned_has_channel_objects, HasChannel):
Expand All @@ -271,10 +271,9 @@ def _setup_node(self) -> None:
remaining_ui_nodes = self._purge_single_use_ui_nodes(ui_nodes)
self._configure_graph_execution(remaining_ui_nodes)

@staticmethod
@abstractmethod
def graph_creator(
self: Macro, *args, **kwargs # noqa: PLW0211
self: Self, *args, **kwargs
) -> HasChannel | tuple[HasChannel, ...] | None:
"""Build the graph the node will run."""

Expand Down Expand Up @@ -480,7 +479,8 @@ def macro_node_factory(
Create a new :class:`Macro` subclass using the given graph creator function.
Args:
graph_creator (callable): Function to create the graph for the :class:`Macro`.
graph_creator (callable): Function to create the graph for this subclass of
:class:`Macro`.
validate_output_labels (bool): Whether to validate the output labels against
the return values of the wrapped function.
use_cache (bool): Whether nodes of this type should default to caching their
Expand All @@ -495,7 +495,7 @@ def macro_node_factory(
graph_creator.__name__,
(Macro,), # Define parentage
{
"graph_creator": staticmethod(graph_creator),
"graph_creator": graph_creator,
"__module__": graph_creator.__module__,
"__qualname__": graph_creator.__qualname__,
"_output_labels": None if len(output_labels) == 0 else output_labels,
Expand Down
3 changes: 1 addition & 2 deletions tests/unit/nodes/test_macro.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,7 @@ def test_creation_from_subclass(self):
class MyMacro(Macro):
_output_labels = ("three__result",)

@staticmethod
def graph_creator(self, one__x): # noqa: PLW0211
def graph_creator(self, one__x):
add_three_macro(self, one__x)
return self.three

Expand Down

0 comments on commit 84487b3

Please sign in to comment.