From 1857963b8e7648ec86e3bd3de9f94dda12763a2c Mon Sep 17 00:00:00 2001 From: Liam Huber Date: Tue, 6 Aug 2024 08:25:18 -0700 Subject: [PATCH] [patch] Multiple dispatch for function and macro decorators (#407) --- docs/README.md | 2 +- notebooks/deepdive.ipynb | 8 ++-- notebooks/quickstart.ipynb | 2 +- pyiron_workflow/nodes/composite.py | 4 +- pyiron_workflow/nodes/for_loop.py | 2 +- pyiron_workflow/nodes/function.py | 13 +++--- pyiron_workflow/nodes/macro.py | 14 +++--- pyiron_workflow/nodes/multiple_distpatch.py | 29 +++++++++++++ pyiron_workflow/nodes/standard.py | 4 +- pyiron_workflow/workflow.py | 2 +- tests/integration/test_provenance.py | 4 +- tests/static/demo_nodes.py | 2 +- tests/unit/nodes/test_for_loop.py | 4 +- tests/unit/nodes/test_function.py | 48 ++++++++++++++++++--- tests/unit/nodes/test_macro.py | 4 +- tests/unit/test_workflow.py | 2 +- 16 files changed, 106 insertions(+), 38 deletions(-) create mode 100644 pyiron_workflow/nodes/multiple_distpatch.py diff --git a/docs/README.md b/docs/README.md index e220327f..3474dd0d 100644 --- a/docs/README.md +++ b/docs/README.md @@ -61,7 +61,7 @@ But the intent is to collect them together into a workflow and leverage existing ... def Permutations(n, choose=None): ... return math.perm(n, choose) >>> ->>> @Workflow.wrap.as_macro_node() +>>> @Workflow.wrap.as_macro_node ... def PermutationDifference(self, n, choose=None): ... self.p = Permutations(n, choose=choose) ... self.plus_1 = AddOne(n) diff --git a/notebooks/deepdive.ipynb b/notebooks/deepdive.ipynb index e0816e96..dfb381d8 100644 --- a/notebooks/deepdive.ipynb +++ b/notebooks/deepdive.ipynb @@ -715,7 +715,7 @@ } ], "source": [ - "@as_function_node()\n", + "@as_function_node\n", "def Linear(x):\n", " return x\n", "\n", @@ -1242,7 +1242,7 @@ "source": [ "wf = Workflow(\"simple\")\n", "\n", - "@Workflow.wrap.as_function_node()\n", + "@Workflow.wrap.as_function_node\n", "def AddOne(x):\n", " y = x + 1\n", " return y\n", @@ -1867,7 +1867,7 @@ } ], "source": [ - "@Workflow.wrap.as_macro_node()\n", + "@Workflow.wrap.as_macro_node\n", "def AddThree(macro, x: int = 0):\n", " \"\"\"\n", " The function decorator `as_macro_node` expects the decorated function \n", @@ -4530,7 +4530,7 @@ } ], "source": [ - "@Workflow.wrap.as_function_node()\n", + "@Workflow.wrap.as_function_node\n", "def FiveApart(a: int, b: int, c: int, d: int, e: str = \"foobar\"):\n", " return a, b, c, d, e,\n", "\n", diff --git a/notebooks/quickstart.ipynb b/notebooks/quickstart.ipynb index 5e404655..827b79c5 100644 --- a/notebooks/quickstart.ipynb +++ b/notebooks/quickstart.ipynb @@ -37,7 +37,7 @@ "metadata": {}, "outputs": [], "source": [ - "@Workflow.wrap.as_function_node()\n", + "@Workflow.wrap.as_function_node\n", "def AddOne(x):\n", " y = x + 1\n", " return y\n", diff --git a/pyiron_workflow/nodes/composite.py b/pyiron_workflow/nodes/composite.py index 594505ab..d82d39fd 100644 --- a/pyiron_workflow/nodes/composite.py +++ b/pyiron_workflow/nodes/composite.py @@ -421,8 +421,8 @@ def graph_as_dict(self) -> dict: for n in self for out in n.signals.output for inp in out.connections - } - } + }, + }, } @property diff --git a/pyiron_workflow/nodes/for_loop.py b/pyiron_workflow/nodes/for_loop.py index 6ab10199..3b4f3c55 100644 --- a/pyiron_workflow/nodes/for_loop.py +++ b/pyiron_workflow/nodes/for_loop.py @@ -533,7 +533,7 @@ def for_node( Note that if we had simply returned each input individually, without any output labels on the node, we'd need to specify a map on the for-node so that the (looped) input and output columns on the resulting dataframe are all unique: - >>> @Workflow.wrap.as_function_node() + >>> @Workflow.wrap.as_function_node ... def FiveApart(a: int, b: int, c: int, d: int, e: str = "foobar"): ... return a, b, c, d, e, >>> diff --git a/pyiron_workflow/nodes/function.py b/pyiron_workflow/nodes/function.py index d779f2e8..e1302a78 100644 --- a/pyiron_workflow/nodes/function.py +++ b/pyiron_workflow/nodes/function.py @@ -7,6 +7,7 @@ from pyiron_snippets.factory import classfactory from pyiron_workflow.mixin.preview import ScrapesIO +from pyiron_workflow.nodes.multiple_distpatch import dispatch_output_labels from pyiron_workflow.nodes.static_io import StaticNode @@ -192,9 +193,10 @@ class Function(StaticNode, ScrapesIO, ABC): that fixes some of the node behaviour -- i.e. the :meth:`node_function`. This can be done most easily with the :func:`as_function_node` decorator, which - takes a function and returns a node class. It also allows us to provide labels - for the return values, :param:output_labels, which are otherwise scraped from - the text of the function definition: + takes a function and returns a node class. This can be used in the usual way, + but the decorator itself also optionally accepts some arguments. Namely, it + also allows us to provide labels for the return values, :param:output_labels, + which are otherwise scraped from the text of the function definition: >>> from pyiron_workflow import as_function_node >>> @@ -237,7 +239,7 @@ class Function(StaticNode, ScrapesIO, ABC): Let's put together a couple of nodes and then run in a "pull" paradigm to get the final node to run everything "upstream" then run itself: - >>> @as_function_node() + >>> @as_function_node ... def adder_node(x: int = 0, y: int = 0) -> int: ... sum = x + y ... return sum @@ -264,7 +266,7 @@ class Function(StaticNode, ScrapesIO, ABC): (like cyclic graphs). Here's our simple example from above using this other paradigm: - >>> @as_function_node() + >>> @as_function_node ... def adder_node(x: int = 0, y: int = 0) -> int: ... sum = x + y ... return sum @@ -391,6 +393,7 @@ def function_node_factory( ) +@dispatch_output_labels def as_function_node( *output_labels: str, validate_output_labels=True, diff --git a/pyiron_workflow/nodes/macro.py b/pyiron_workflow/nodes/macro.py index 9a0eeea9..97707192 100644 --- a/pyiron_workflow/nodes/macro.py +++ b/pyiron_workflow/nodes/macro.py @@ -15,6 +15,7 @@ from pyiron_workflow.mixin.has_interface_mixins import HasChannel from pyiron_workflow.io import Outputs, Inputs from pyiron_workflow.mixin.preview import ScrapesIO +from pyiron_workflow.nodes.multiple_distpatch import dispatch_output_labels from pyiron_workflow.nodes.static_io import StaticNode if TYPE_CHECKING: @@ -170,12 +171,12 @@ class Macro(Composite, StaticNode, ScrapesIO, ABC): If there's a particular macro we're going to use again and again, we might want to consider making a new class for it using the decorator, just like we do for - function nodes. If no output labels are explicitly provided, these are scraped - from the function return value, just like for function nodes (except the - initial `macro.` (or whatever the first argument is named) on any return values - is ignored): + function nodes. If no output labels are explicitly provided as arguments to the + decorator itself, these are scraped from the function return value, just like + for function nodes (except the initial `macro` (or `self` or whatever the first + argument is named) on any return values is ignored): - >>> @Macro.wrap.as_macro_node() + >>> @Macro.wrap.as_macro_node ... def AddThreeMacro(self, x): ... add_three_macro(self, one__x=x) ... # We could also simply have decorated that function to begin with @@ -206,7 +207,7 @@ class Macro(Composite, StaticNode, ScrapesIO, ABC): to do this. Let's explore these by going back to our `add_three_macro` and replacing each of its children with a node that adds 2 instead of 1. - >>> @Macro.wrap.as_function_node() + >>> @Macro.wrap.as_function_node ... def add_two(x): ... result = x + 2 ... return result @@ -511,6 +512,7 @@ def macro_node_factory( ) +@dispatch_output_labels def as_macro_node( *output_labels: str, validate_output_labels: bool = True, use_cache: bool = True ): diff --git a/pyiron_workflow/nodes/multiple_distpatch.py b/pyiron_workflow/nodes/multiple_distpatch.py new file mode 100644 index 00000000..6c4a71f0 --- /dev/null +++ b/pyiron_workflow/nodes/multiple_distpatch.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +""" +Shared code for various node-creating decorators facilitating multiple dispatch (i.e. +using decorators with or without arguments, contextually. +""" + + +class MultipleDispatchError(ValueError): + """ + Raise from callables using multiple dispatch when no interpretation of input + matches an expected case. + """ + + +def dispatch_output_labels(single_dispatch_decorator): + def multi_dispatch_decorator(*output_labels, **kwargs): + if len(output_labels) > 0 and callable(output_labels[0]): + if len(output_labels) > 1: + raise MultipleDispatchError( + f"Output labels must all be strings (for decorator usage with an " + f"argument), or a callable must be provided alone -- got " + f"{output_labels}." + ) + return single_dispatch_decorator(**kwargs)(output_labels[0]) + else: + return single_dispatch_decorator(*output_labels, **kwargs) + + return multi_dispatch_decorator diff --git a/pyiron_workflow/nodes/standard.py b/pyiron_workflow/nodes/standard.py index 49518d96..8a6d79e2 100644 --- a/pyiron_workflow/nodes/standard.py +++ b/pyiron_workflow/nodes/standard.py @@ -14,7 +14,7 @@ from pyiron_workflow.nodes.function import Function, as_function_node -@as_function_node() +@as_function_node def UserInput(user_input): """ Returns the user input as it is. @@ -177,7 +177,7 @@ def Int(x): return int(x) -@as_function_node() +@as_function_node def PureCall(fnc: callable): """ Return a call without any arguments diff --git a/pyiron_workflow/workflow.py b/pyiron_workflow/workflow.py index d34e58ae..5f80d851 100644 --- a/pyiron_workflow/workflow.py +++ b/pyiron_workflow/workflow.py @@ -70,7 +70,7 @@ class Workflow(ParentMost, Composite): >>> from pyiron_workflow.workflow import Workflow >>> - >>> @Workflow.wrap.as_function_node() + >>> @Workflow.wrap.as_function_node ... def fnc(x=0): ... return x + 1 >>> diff --git a/tests/integration/test_provenance.py b/tests/integration/test_provenance.py index 1cfbf6c0..b5947902 100644 --- a/tests/integration/test_provenance.py +++ b/tests/integration/test_provenance.py @@ -12,12 +12,12 @@ class TestProvenance(unittest.TestCase): """ def setUp(self) -> None: - @Workflow.wrap.as_function_node() + @Workflow.wrap.as_function_node def Slow(t): sleep(t) return t - @Workflow.wrap.as_macro_node() + @Workflow.wrap.as_macro_node def Provenance(self, t): self.fast = Workflow.create.standard.UserInput(t) self.slow = Slow(t) diff --git a/tests/static/demo_nodes.py b/tests/static/demo_nodes.py index 8f876a77..94591204 100644 --- a/tests/static/demo_nodes.py +++ b/tests/static/demo_nodes.py @@ -31,6 +31,6 @@ def dynamic(x): return x + 1 -Dynamic = Workflow.wrap.as_function_node()(dynamic) +Dynamic = Workflow.wrap.as_function_node(dynamic) nodes = [OptionallyAdd, AddThree, AddPlusOne, Dynamic] diff --git a/tests/unit/nodes/test_for_loop.py b/tests/unit/nodes/test_for_loop.py index cf0905c6..93d3b25c 100644 --- a/tests/unit/nodes/test_for_loop.py +++ b/tests/unit/nodes/test_for_loop.py @@ -229,7 +229,7 @@ def test_dynamic_length(self): ) def test_column_mapping(self): - @as_function_node() + @as_function_node def FiveApart( a: int = 0, b: int = 1, @@ -338,7 +338,7 @@ def test_body_node_executor(self): def test_with_connections(self): length_y = 3 - @as_macro_node() + @as_macro_node def LoopInside(self, x: list, y: int): self.to_list = inputs_to_list( length_y, y, y, y diff --git a/tests/unit/nodes/test_function.py b/tests/unit/nodes/test_function.py index 1d2212e5..43fb936f 100644 --- a/tests/unit/nodes/test_function.py +++ b/tests/unit/nodes/test_function.py @@ -3,8 +3,9 @@ import unittest from pyiron_workflow.channels import NOT_DATA -from pyiron_workflow.nodes.function import function_node, as_function_node +from pyiron_workflow.nodes.function import function_node, as_function_node, Function from pyiron_workflow.io import ConnectionCopyError, ValueCopyError +from pyiron_workflow.nodes.multiple_distpatch import MultipleDispatchError def throw_error(x: Optional[int] = None): @@ -157,11 +158,11 @@ def test_default_label(self): self.assertEqual(plus_one.__name__, n.label) def test_availability_of_node_function(self): - @as_function_node() + @as_function_node def linear(x): return x - @as_function_node() + @as_function_node def bilinear(x, y): xy = linear.node_function(x) * linear.node_function(y) return xy @@ -315,12 +316,12 @@ def plus_one_hinted(x: int = 0) -> int: ) def test_copy_values(self): - @as_function_node() + @as_function_node def reference(x=0, y: int = 0, z: int | float = 0, omega=None, extra_here=None): out = 42 return out - @as_function_node() + @as_function_node def all_floats(x=1.1, y=1.1, z=1.1, omega=NOT_DATA, extra_there=None) -> float: out = 42.1 return out @@ -365,7 +366,7 @@ def all_floats(x=1.1, y=1.1, z=1.1, omega=NOT_DATA, extra_there=None) -> float: # Note also that these nodes each have extra channels the other doesn't that # are simply ignored - @as_function_node() + @as_function_node def extra_channel(x=1, y=1, z=1, not_present=42): out = 42 return out @@ -493,7 +494,7 @@ def returns_foo() -> Foo: def test_void_return(self): """Test extensions to the `ScrapesIO` mixin.""" - @as_function_node() + @as_function_node def NoReturn(x): y = x + 1 @@ -516,6 +517,39 @@ def test_pickle(self): reloaded.outputs.to_value_dict() ) + def test_decoration(self): + with self.subTest("@as_function_node(*output_labels, ...)"): + WithDecoratorSignature = as_function_node("z")(plus_one) + self.assertTrue( + issubclass(WithDecoratorSignature, Function), + msg="Sanity check" + ) + self.assertListEqual( + ["z"], + list(WithDecoratorSignature.preview_outputs().keys()), + msg="Decorator should capture new output label" + ) + + with self.subTest("@as_function_node"): + WithoutDecoratorSignature = as_function_node(plus_one) + self.assertTrue( + issubclass(WithoutDecoratorSignature, Function), + msg="Sanity check" + ) + self.assertListEqual( + ["y"], # "Default" copied here from the function definition return + list(WithoutDecoratorSignature.preview_outputs().keys()), + msg="Decorator should capture new output label" + ) + + with self.assertRaises( + MultipleDispatchError, + msg="This shouldn't be accessible from a regular decorator usage pattern, " + "but make sure that mixing-and-matching argument-free calls and calls " + "directly providing the wrapped node fail cleanly" + ): + as_function_node(plus_one, "z") + if __name__ == '__main__': unittest.main() diff --git a/tests/unit/nodes/test_macro.py b/tests/unit/nodes/test_macro.py index 21433a9e..eb2145d2 100644 --- a/tests/unit/nodes/test_macro.py +++ b/tests/unit/nodes/test_macro.py @@ -539,7 +539,7 @@ def test_storage_for_modified_macros(self): def test_output_label_stripping(self): """Test extensions to the `ScrapesIO` mixin.""" - @as_macro_node() + @as_macro_node def OutputScrapedFromFilteredReturn(macro): macro.foo = macro.create.standard.UserInput() return macro.foo @@ -554,7 +554,7 @@ def OutputScrapedFromFilteredReturn(macro): ValueError, msg="Return values with extra dots are not permissible as scraped labels" ): - @as_macro_node() + @as_macro_node def ReturnHasDot(macro): macro.foo = macro.create.standard.UserInput() return macro.foo.outputs.user_input diff --git a/tests/unit/test_workflow.py b/tests/unit/test_workflow.py index 053e0b9a..649ae46c 100644 --- a/tests/unit/test_workflow.py +++ b/tests/unit/test_workflow.py @@ -24,7 +24,7 @@ def PlusOne(x: int = 0): return x + 1 -@Workflow.wrap.as_function_node() +@Workflow.wrap.as_function_node def five(sleep_time=0.): sleep(sleep_time) five = 5