Commit

Merge pull request #281 from pyiron/scrape_macro_output
[patch] Scrape macro output
liamhuber authored Apr 11, 2024
2 parents c3cfc15 + e20199f commit 3f09dad
Showing 2 changed files with 112 additions and 33 deletions.
95 changes: 70 additions & 25 deletions pyiron_workflow/macro.py
Expand Up @@ -5,8 +5,9 @@

from __future__ import annotations

from abc import ABC
from abc import ABC, abstractmethod
import inspect
import re
from typing import Any, get_args, get_type_hints, Literal, Optional, TYPE_CHECKING
import warnings

Expand All @@ -29,23 +30,22 @@ class Macro(Composite, ABC):
then builds a static IO interface for this graph.
This callable must use the macro object itself as the first argument (e.g. adding
nodes to it).
As with :class:`Workflow` objects, macros leverage `inputs_map` and `outputs_map` to
control macro-level IO access to child IO.
As with :class:`Workflow`, default behaviour is to expose all unconnected child IO.
The provided callable may optionally specify further args and kwargs, which are used
to pre-populate the macro with :class:`UserInput` nodes;
The provided callable may optionally specify further args and kwargs; these are
used to pre-populate the macro with :class:`UserInput` nodes, although they may
later be trimmed if the IO can be connected directly to child node IO without any
loss of functionality.
This can be especially helpful when more than one child node needs access to the
same input value.
Similarly, the callable may return any number of child nodes' output channels (or
the node itself in the case of single-output nodes) as long as a commensurate
number of labels for these outputs were provided to the class constructor.
These function-like definitions of the graph creator callable can be used
to build only input xor output, or both together.
Each that is used switches its IO map to a "whitelist" paradigm, so any I/O _not_
provided in the callable signature/return values and output labels will be disabled.
Manual modifications of the IO maps inside the callable always take priority over
this whitelisting behaviour, so you always retain full control over what IO is
exposed, and the whitelisting is only for your convenience.
Macro input channel labels are scraped from the signature of the graph creator;
for output, output labels can be provided explicitly as a class attribute or, as a
fallback, they are scraped from the graph creator code return statement (stripping
off the "{first argument}.", where {first argument} is whatever the name of the
first argument is).
Macro IO is _value linked_ to the child IO, so that their values stay synchronized,
but the child nodes of a macro form an isolated sub-graph.
Expand All @@ -60,15 +60,15 @@ class Macro(Composite, ABC):
If only _one_ of these is specified, you'll get an error, but if you've provided
both then no further checks of their validity/reasonableness are performed, so be
careful.
Unlike :class:`Workflow`, this execution flow automation is set up once at instantiation;
Unlike :class:`Workflow`, this execution flow automation is set up once at
instantiation;
If the macro is modified post-facto, you may need to manually re-invoke
:meth:`configure_graph_execution`.
Promises (in addition to parent class promises):
- IO is...
- Only built at instantiation, after child node replacement, or at request, so
it is "static" for improved efficiency
- Statically defined at the class level
- By value, i.e. the macro has its own IO channel instances and children are
duly encapsulated inside their own sub-graph
- Value-linked to the values of their corresponding child nodes' IO -- i.e.
Expand Down Expand Up @@ -171,32 +171,35 @@ class Macro(Composite, ABC):
If there's a particular macro we're going to use again and again, we might want
to consider making a new class for it using the decorator, just like we do for
function nodes:
>>> @Macro.wrap.as_macro_node("three__result")
... def AddThreeMacro(macro, one__x):
... add_three_macro(macro, one__x=one__x)
function nodes. If no output labels are explicitly provided, these are scraped
from the function return value, just like for function nodes (except the
initial `macro.` (or whatever the first argument is named) on any return values
is ignored):
>>> @Macro.wrap.as_macro_node()
... def AddThreeMacro(macro, x):
... add_three_macro(macro, one__x=x)
... # We could also simply have decorated that function to begin with
... return macro.three
>>>
>>> macro = AddThreeMacro()
>>> macro(one__x=0).three__result
>>> macro(x=0).three
3
An alternative (not recommended) is to make a new child class of
:class:`Macro` that overrides the :meth:`graph_creator` arg such that
the same graph is always created.
>>> class AddThreeMacro(Macro):
... _provided_output_labels = ["three__result"]
... _provided_output_labels = ["three"]
...
... @staticmethod
... def graph_creator(macro, one__x):
... add_three_macro(macro, one__x=one__x)
... def graph_creator(macro, x):
... add_three_macro(macro, one__x=x)
... return macro.three
>>>
>>> macro = AddThreeMacro()
>>> macro(one__x=0).three__result
>>> macro(x=0).three
3
We can also modify an existing macro at runtime by replacing nodes within it, as
Expand Down Expand Up @@ -298,14 +301,24 @@ def __init__(

self.set_input_values(**kwargs)

@staticmethod
@abstractmethod
def graph_creator(self, *args, **kwargs) -> callable:
"""Build the graph the node will run."""

@classmethod
def _validate_output_labels(cls) -> tuple[str]:
"""
Ensure that output_labels, if provided, are commensurate with graph creator
return values, if provided, and return them as a tuple.
"""
graph_creator_returns = ParseOutput(cls.graph_creator).output
output_labels = cls._provided_output_labels
output_labels = cls._get_output_labels()
if output_labels is not None and len(set(output_labels)) != len(output_labels):
raise ValueError(
f"{cls.__name__} must not have degenerate output labels: "
f"{output_labels}"
)
if graph_creator_returns is not None or output_labels is not None:
error_suffix = (
f"but {cls.__name__} macro class got return values: "
Expand Down Expand Up @@ -368,8 +381,40 @@ def _get_output_labels(cls):
"""
Return output labels provided on the class if not None.
"""
if cls._provided_output_labels is None:
cls._scrape_output_labels()
return cls._provided_output_labels

@classmethod
def _scrape_output_labels(cls):
"""
Inspect :meth:`graph_creator` to scrape out strings representing the
returned values.
_Only_ works for functions with a single `return` expression in their body.
It will return expressions and function calls just fine, thus good practice is
to create well-named variables and return those so that the output labels stay
dot-accessible.
"""
parsed_outputs = ParseOutput(cls.graph_creator).output
if parsed_outputs is None:
cls._provided_output_labels = None
else:
self_argument = list(cls._input_args().keys())[0]
cleaned_labels = [
# Strip off the first argument, e.g. self.foo just becomes foo
re.sub(r"^" + re.escape(f"{self_argument}."), "", p)
for p in parsed_outputs
]
if any("." in label for label in cleaned_labels):
raise ValueError(
f"Tried to scrape cleaned labels for {cls.__name__}, but at least "
f"one of {cleaned_labels} still contains a '.' -- please provide "
f"explicit labels"
)
cls._provided_output_labels = cleaned_labels

@classmethod
def preview_input_channels(cls) -> dict[str, tuple[Any, Any]]:
"""
Expand Down
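The label-scraping step added to macro.py above can be illustrated with a small standalone sketch. It does not use pyiron_workflow: the helper names `scrape_return_labels` and `clean_labels` are hypothetical, and the standard-library `ast` module stands in for the project's `ParseOutput`, whose internals are not shown in this diff. Only the regex-based prefix stripping and the "no remaining dots" check are taken directly from the changed code.

```python
from __future__ import annotations

import ast
import re
import textwrap


def scrape_return_labels(source: str) -> list[str] | None:
    """Return the source text of each value in a function's single return.

    Hypothetical stand-in for ParseOutput: gives None unless the function
    has exactly one `return` statement with a value.
    """
    func = ast.parse(textwrap.dedent(source)).body[0]
    returns = [node for node in ast.walk(func) if isinstance(node, ast.Return)]
    if len(returns) != 1 or returns[0].value is None:
        return None
    value = returns[0].value
    items = value.elts if isinstance(value, ast.Tuple) else [value]
    return [ast.unparse(item) for item in items]


def clean_labels(first_arg: str, parsed: list[str]) -> list[str]:
    """Strip a leading '<first_arg>.' and reject labels that still hold a dot."""
    cleaned = [
        # Same substitution as in the diff: e.g. macro.foo just becomes foo
        re.sub(r"^" + re.escape(f"{first_arg}."), "", p)
        for p in parsed
    ]
    if any("." in label for label in cleaned):
        raise ValueError(
            f"At least one of {cleaned} still contains a '.' -- "
            f"please provide explicit labels"
        )
    return cleaned


SOURCE = '''
def graph_creator(macro, x):
    macro.three = x
    return macro.three
'''

labels = clean_labels("macro", scrape_return_labels(SOURCE))
print(labels)  # ['three']
```

A return value such as `macro.foo.outputs.user_input` still contains dots after the prefix is stripped, so `clean_labels` raises `ValueError`, which mirrors the `ReturnHasDot` test case in this commit.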
50 changes: 42 additions & 8 deletions tests/unit/test_macro.py
Expand Up @@ -431,6 +431,38 @@ def LabelsAndReturnsMatch(macro):

LabelsAndReturnsMatch() # Both is fine

@as_macro_node()
def OutputScrapedFromCleanReturn(macro):
macro.foo = macro.create.standard.UserInput()
my_out = macro.foo
return my_out

self.assertListEqual(
["my_out"],
list(OutputScrapedFromCleanReturn.preview_output_channels().keys()),
msg="Output labels should get scraped from code, just like for functions"
)

@as_macro_node()
def OutputScrapedFromFilteredReturn(macro):
macro.foo = macro.create.standard.UserInput()
return macro.foo

self.assertListEqual(
["foo"],
list(OutputScrapedFromFilteredReturn.preview_output_channels().keys()),
msg="The first, self-like argument, should get stripped from output labels"
)

with self.assertRaises(
ValueError,
msg="Return values shouldn't have extra dots"
):
@as_macro_node()
def ReturnHasDot(macro):
macro.foo = macro.create.standard.UserInput()
return macro.foo.outputs.user_input

with self.assertRaises(
ValueError,
msg="The number of output labels and return values must match"
Expand All @@ -442,20 +474,22 @@ def MissingReturn(macro):

with self.assertRaises(
TypeError,
msg="Output labels must be there if return values are"
msg="Return values must be there if output labels are"
):
@as_macro_node()
def MissingLabel(macro):
@as_macro_node("some_label")
def MissingReturn(macro):
macro.foo = macro.create.standard.UserInput()
return macro.foo

with self.assertRaises(
TypeError,
msg="Return values must be there if output labels are"
ValueError,
msg="Degenerate output labels should not be allowed"
):
@as_macro_node("some_label")
def MissingLabel(macro):
@as_macro_node()
def DegenerateOutput(macro):
macro.foo = macro.create.standard.UserInput()
macro.bar = macro.create.standard.UserInput(macro.foo)
bar = macro.foo
return bar, macro.bar

def test_functionlike_io_parsing(self):
"""
Expand Down
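The degenerate-label check exercised by the `DegenerateOutput` test can be sketched in isolation. This is a minimal illustration, assuming the labels have already been scraped and cleaned; `check_unique_labels` is a hypothetical free function, whereas in the commit the same comparison lives inside the `_validate_output_labels` classmethod.

```python
from __future__ import annotations


def check_unique_labels(class_name: str, output_labels: list[str] | None) -> None:
    """Raise ValueError if any output label appears more than once."""
    if output_labels is not None and len(set(output_labels)) != len(output_labels):
        raise ValueError(
            f"{class_name} must not have degenerate output labels: "
            f"{output_labels}"
        )


# Distinct labels, or no labels at all, pass silently
check_unique_labels("Fine", ["foo", "bar"])
check_unique_labels("NoLabels", None)

# `return bar, macro.bar` scrapes to "bar" twice once the prefix is stripped
message = None
try:
    check_unique_labels("DegenerateOutput", ["bar", "bar"])
except ValueError as err:
    message = str(err)
print(message)
```

Comparing the length of the label list with the length of its set is a cheap way to detect duplicates without needing to report which label collided; the full list is included in the error message instead.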

0 comments on commit 3f09dad