Skip to content

Commit

Permalink
[patch] Multiple dispatch for function and macro decorators (#407)
Browse files Browse the repository at this point in the history
  • Loading branch information
liamhuber authored Aug 6, 2024
1 parent 6c555a2 commit 1857963
Show file tree
Hide file tree
Showing 16 changed files with 106 additions and 38 deletions.
2 changes: 1 addition & 1 deletion docs/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ But the intent is to collect them together into a workflow and leverage existing
... def Permutations(n, choose=None):
... return math.perm(n, choose)
>>>
>>> @Workflow.wrap.as_macro_node()
>>> @Workflow.wrap.as_macro_node
... def PermutationDifference(self, n, choose=None):
... self.p = Permutations(n, choose=choose)
... self.plus_1 = AddOne(n)
Expand Down
8 changes: 4 additions & 4 deletions notebooks/deepdive.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -715,7 +715,7 @@
}
],
"source": [
"@as_function_node()\n",
"@as_function_node\n",
"def Linear(x):\n",
" return x\n",
"\n",
Expand Down Expand Up @@ -1242,7 +1242,7 @@
"source": [
"wf = Workflow(\"simple\")\n",
"\n",
"@Workflow.wrap.as_function_node()\n",
"@Workflow.wrap.as_function_node\n",
"def AddOne(x):\n",
" y = x + 1\n",
" return y\n",
Expand Down Expand Up @@ -1867,7 +1867,7 @@
}
],
"source": [
"@Workflow.wrap.as_macro_node()\n",
"@Workflow.wrap.as_macro_node\n",
"def AddThree(macro, x: int = 0):\n",
" \"\"\"\n",
" The function decorator `as_macro_node` expects the decorated function \n",
Expand Down Expand Up @@ -4530,7 +4530,7 @@
}
],
"source": [
"@Workflow.wrap.as_function_node()\n",
"@Workflow.wrap.as_function_node\n",
"def FiveApart(a: int, b: int, c: int, d: int, e: str = \"foobar\"):\n",
" return a, b, c, d, e,\n",
"\n",
Expand Down
2 changes: 1 addition & 1 deletion notebooks/quickstart.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
"metadata": {},
"outputs": [],
"source": [
"@Workflow.wrap.as_function_node()\n",
"@Workflow.wrap.as_function_node\n",
"def AddOne(x):\n",
" y = x + 1\n",
" return y\n",
Expand Down
4 changes: 2 additions & 2 deletions pyiron_workflow/nodes/composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,8 +421,8 @@ def graph_as_dict(self) -> dict:
for n in self
for out in n.signals.output
for inp in out.connections
}
}
},
},
}

@property
Expand Down
2 changes: 1 addition & 1 deletion pyiron_workflow/nodes/for_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ def for_node(
Note that if we had simply returned each input individually, without any output
labels on the node, we'd need to specify a map on the for-node so that the
(looped) input and output columns on the resulting dataframe are all unique:
>>> @Workflow.wrap.as_function_node()
>>> @Workflow.wrap.as_function_node
... def FiveApart(a: int, b: int, c: int, d: int, e: str = "foobar"):
... return a, b, c, d, e,
>>>
Expand Down
13 changes: 8 additions & 5 deletions pyiron_workflow/nodes/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pyiron_snippets.factory import classfactory

from pyiron_workflow.mixin.preview import ScrapesIO
from pyiron_workflow.nodes.multiple_distpatch import dispatch_output_labels
from pyiron_workflow.nodes.static_io import StaticNode


Expand Down Expand Up @@ -192,9 +193,10 @@ class Function(StaticNode, ScrapesIO, ABC):
that fixes some of the node behaviour -- i.e. the :meth:`node_function`.
This can be done most easily with the :func:`as_function_node` decorator, which
takes a function and returns a node class. It also allows us to provide labels
for the return values, :param:output_labels, which are otherwise scraped from
the text of the function definition:
takes a function and returns a node class. This can be used in the usual way,
but the decorator itself also optionally accepts some arguments. Namely, it
also allows us to provide labels for the return values, :param:output_labels,
which are otherwise scraped from the text of the function definition:
>>> from pyiron_workflow import as_function_node
>>>
Expand Down Expand Up @@ -237,7 +239,7 @@ class Function(StaticNode, ScrapesIO, ABC):
Let's put together a couple of nodes and then run in a "pull" paradigm to get
the final node to run everything "upstream" then run itself:
>>> @as_function_node()
>>> @as_function_node
... def adder_node(x: int = 0, y: int = 0) -> int:
... sum = x + y
... return sum
Expand All @@ -264,7 +266,7 @@ class Function(StaticNode, ScrapesIO, ABC):
(like cyclic graphs).
Here's our simple example from above using this other paradigm:
>>> @as_function_node()
>>> @as_function_node
... def adder_node(x: int = 0, y: int = 0) -> int:
... sum = x + y
... return sum
Expand Down Expand Up @@ -391,6 +393,7 @@ def function_node_factory(
)


@dispatch_output_labels
def as_function_node(
*output_labels: str,
validate_output_labels=True,
Expand Down
14 changes: 8 additions & 6 deletions pyiron_workflow/nodes/macro.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from pyiron_workflow.mixin.has_interface_mixins import HasChannel
from pyiron_workflow.io import Outputs, Inputs
from pyiron_workflow.mixin.preview import ScrapesIO
from pyiron_workflow.nodes.multiple_distpatch import dispatch_output_labels
from pyiron_workflow.nodes.static_io import StaticNode

if TYPE_CHECKING:
Expand Down Expand Up @@ -170,12 +171,12 @@ class Macro(Composite, StaticNode, ScrapesIO, ABC):
If there's a particular macro we're going to use again and again, we might want
to consider making a new class for it using the decorator, just like we do for
function nodes. If no output labels are explicitly provided, these are scraped
from the function return value, just like for function nodes (except the
initial `macro.` (or whatever the first argument is named) on any return values
is ignored):
function nodes. If no output labels are explicitly provided as arguments to the
decorator itself, these are scraped from the function return value, just like
for function nodes (except the initial `macro` (or `self` or whatever the first
argument is named) on any return values is ignored):
>>> @Macro.wrap.as_macro_node()
>>> @Macro.wrap.as_macro_node
... def AddThreeMacro(self, x):
... add_three_macro(self, one__x=x)
... # We could also simply have decorated that function to begin with
Expand Down Expand Up @@ -206,7 +207,7 @@ class Macro(Composite, StaticNode, ScrapesIO, ABC):
to do this. Let's explore these by going back to our `add_three_macro` and
replacing each of its children with a node that adds 2 instead of 1.
>>> @Macro.wrap.as_function_node()
>>> @Macro.wrap.as_function_node
... def add_two(x):
... result = x + 2
... return result
Expand Down Expand Up @@ -511,6 +512,7 @@ def macro_node_factory(
)


@dispatch_output_labels
def as_macro_node(
*output_labels: str, validate_output_labels: bool = True, use_cache: bool = True
):
Expand Down
29 changes: 29 additions & 0 deletions pyiron_workflow/nodes/multiple_distpatch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from __future__ import annotations

"""
Shared code for various node-creating decorators facilitating multiple dispatch (i.e.
using decorators with or without arguments, contextually.
"""


class MultipleDispatchError(ValueError):
"""
Raise from callables using multiple dispatch when no interpretation of input
matches an expected case.
"""


def dispatch_output_labels(single_dispatch_decorator):
def multi_dispatch_decorator(*output_labels, **kwargs):
if len(output_labels) > 0 and callable(output_labels[0]):
if len(output_labels) > 1:
raise MultipleDispatchError(
f"Output labels must all be strings (for decorator usage with an "
f"argument), or a callable must be provided alone -- got "
f"{output_labels}."
)
return single_dispatch_decorator(**kwargs)(output_labels[0])
else:
return single_dispatch_decorator(*output_labels, **kwargs)

return multi_dispatch_decorator
4 changes: 2 additions & 2 deletions pyiron_workflow/nodes/standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from pyiron_workflow.nodes.function import Function, as_function_node


@as_function_node()
@as_function_node
def UserInput(user_input):
"""
Returns the user input as it is.
Expand Down Expand Up @@ -177,7 +177,7 @@ def Int(x):
return int(x)


@as_function_node()
@as_function_node
def PureCall(fnc: callable):
"""
Return a call without any arguments
Expand Down
2 changes: 1 addition & 1 deletion pyiron_workflow/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class Workflow(ParentMost, Composite):
>>> from pyiron_workflow.workflow import Workflow
>>>
>>> @Workflow.wrap.as_function_node()
>>> @Workflow.wrap.as_function_node
... def fnc(x=0):
... return x + 1
>>>
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/test_provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ class TestProvenance(unittest.TestCase):
"""

def setUp(self) -> None:
@Workflow.wrap.as_function_node()
@Workflow.wrap.as_function_node
def Slow(t):
sleep(t)
return t

@Workflow.wrap.as_macro_node()
@Workflow.wrap.as_macro_node
def Provenance(self, t):
self.fast = Workflow.create.standard.UserInput(t)
self.slow = Slow(t)
Expand Down
2 changes: 1 addition & 1 deletion tests/static/demo_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,6 @@ def dynamic(x):
return x + 1


Dynamic = Workflow.wrap.as_function_node()(dynamic)
Dynamic = Workflow.wrap.as_function_node(dynamic)

nodes = [OptionallyAdd, AddThree, AddPlusOne, Dynamic]
4 changes: 2 additions & 2 deletions tests/unit/nodes/test_for_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def test_dynamic_length(self):
)

def test_column_mapping(self):
@as_function_node()
@as_function_node
def FiveApart(
a: int = 0,
b: int = 1,
Expand Down Expand Up @@ -338,7 +338,7 @@ def test_body_node_executor(self):
def test_with_connections(self):
length_y = 3

@as_macro_node()
@as_macro_node
def LoopInside(self, x: list, y: int):
self.to_list = inputs_to_list(
length_y, y, y, y
Expand Down
48 changes: 41 additions & 7 deletions tests/unit/nodes/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import unittest

from pyiron_workflow.channels import NOT_DATA
from pyiron_workflow.nodes.function import function_node, as_function_node
from pyiron_workflow.nodes.function import function_node, as_function_node, Function
from pyiron_workflow.io import ConnectionCopyError, ValueCopyError
from pyiron_workflow.nodes.multiple_distpatch import MultipleDispatchError


def throw_error(x: Optional[int] = None):
Expand Down Expand Up @@ -157,11 +158,11 @@ def test_default_label(self):
self.assertEqual(plus_one.__name__, n.label)

def test_availability_of_node_function(self):
@as_function_node()
@as_function_node
def linear(x):
return x

@as_function_node()
@as_function_node
def bilinear(x, y):
xy = linear.node_function(x) * linear.node_function(y)
return xy
Expand Down Expand Up @@ -315,12 +316,12 @@ def plus_one_hinted(x: int = 0) -> int:
)

def test_copy_values(self):
@as_function_node()
@as_function_node
def reference(x=0, y: int = 0, z: int | float = 0, omega=None, extra_here=None):
out = 42
return out

@as_function_node()
@as_function_node
def all_floats(x=1.1, y=1.1, z=1.1, omega=NOT_DATA, extra_there=None) -> float:
out = 42.1
return out
Expand Down Expand Up @@ -365,7 +366,7 @@ def all_floats(x=1.1, y=1.1, z=1.1, omega=NOT_DATA, extra_there=None) -> float:
# Note also that these nodes each have extra channels the other doesn't that
# are simply ignored

@as_function_node()
@as_function_node
def extra_channel(x=1, y=1, z=1, not_present=42):
out = 42
return out
Expand Down Expand Up @@ -493,7 +494,7 @@ def returns_foo() -> Foo:
def test_void_return(self):
"""Test extensions to the `ScrapesIO` mixin."""

@as_function_node()
@as_function_node
def NoReturn(x):
y = x + 1

Expand All @@ -516,6 +517,39 @@ def test_pickle(self):
reloaded.outputs.to_value_dict()
)

def test_decoration(self):
with self.subTest("@as_function_node(*output_labels, ...)"):
WithDecoratorSignature = as_function_node("z")(plus_one)
self.assertTrue(
issubclass(WithDecoratorSignature, Function),
msg="Sanity check"
)
self.assertListEqual(
["z"],
list(WithDecoratorSignature.preview_outputs().keys()),
msg="Decorator should capture new output label"
)

with self.subTest("@as_function_node"):
WithoutDecoratorSignature = as_function_node(plus_one)
self.assertTrue(
issubclass(WithoutDecoratorSignature, Function),
msg="Sanity check"
)
self.assertListEqual(
["y"], # "Default" copied here from the function definition return
list(WithoutDecoratorSignature.preview_outputs().keys()),
msg="Decorator should capture new output label"
)

with self.assertRaises(
MultipleDispatchError,
msg="This shouldn't be accessible from a regular decorator usage pattern, "
"but make sure that mixing-and-matching argument-free calls and calls "
"directly providing the wrapped node fail cleanly"
):
as_function_node(plus_one, "z")


if __name__ == '__main__':
unittest.main()
4 changes: 2 additions & 2 deletions tests/unit/nodes/test_macro.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,7 @@ def test_storage_for_modified_macros(self):
def test_output_label_stripping(self):
"""Test extensions to the `ScrapesIO` mixin."""

@as_macro_node()
@as_macro_node
def OutputScrapedFromFilteredReturn(macro):
macro.foo = macro.create.standard.UserInput()
return macro.foo
Expand All @@ -554,7 +554,7 @@ def OutputScrapedFromFilteredReturn(macro):
ValueError,
msg="Return values with extra dots are not permissible as scraped labels"
):
@as_macro_node()
@as_macro_node
def ReturnHasDot(macro):
macro.foo = macro.create.standard.UserInput()
return macro.foo.outputs.user_input
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def PlusOne(x: int = 0):
return x + 1


@Workflow.wrap.as_function_node()
@Workflow.wrap.as_function_node
def five(sleep_time=0.):
sleep(sleep_time)
five = 5
Expand Down

0 comments on commit 1857963

Please sign in to comment.