Skip to content

Commit

Permalink
Merge pull request #282 from pyiron/static_io_parent
Browse files Browse the repository at this point in the history
[minor] Extract a parent class for pulling IO data from a class method
  • Loading branch information
liamhuber authored Apr 11, 2024
2 parents 3f09dad + d6ee64c commit 357afc2
Show file tree
Hide file tree
Showing 12 changed files with 1,376 additions and 1,199 deletions.
295 changes: 214 additions & 81 deletions notebooks/deepdive.ipynb

Large diffs are not rendered by default.

1,015 changes: 457 additions & 558 deletions notebooks/quickstart.ipynb

Large diffs are not rendered by default.

202 changes: 20 additions & 182 deletions pyiron_workflow/function.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,19 @@
from __future__ import annotations

from abc import ABC, abstractmethod
import inspect
import warnings
from typing import Any, get_args, get_type_hints, Literal, Optional, TYPE_CHECKING
from typing import Any, Literal, Optional, TYPE_CHECKING

from pyiron_workflow.channels import InputData, NOT_DATA
from pyiron_workflow.channels import InputData
from pyiron_workflow.injection import OutputDataWithInjection
from pyiron_workflow.io import Inputs, Outputs
from pyiron_workflow.node import Node
from pyiron_workflow.output_parser import ParseOutput
from pyiron_workflow.io_preview import DecoratedNode, decorated_node_decorator_factory
from pyiron_workflow.snippets.colors import SeabornColors

if TYPE_CHECKING:
from pyiron_workflow.composite import Composite


class Function(Node, ABC):
class Function(DecoratedNode, ABC):
"""
Function nodes wrap an arbitrary python function.
Expand Down Expand Up @@ -301,8 +298,6 @@ class Function(Node, ABC):
guaranteed.
"""

_provided_output_labels: tuple[str] | None = None

def __init__(
self,
*args,
Expand Down Expand Up @@ -333,73 +328,14 @@ def node_function(*args, **kwargs) -> callable:
"""What the node _does_."""

@classmethod
def _type_hints(cls) -> dict:
"""The result of :func:`typing.get_type_hints` on the :meth:`node_function`."""
return get_type_hints(cls.node_function)

@classmethod
def preview_output_channels(cls) -> dict[str, Any]:
"""
Gives a class-level peek at the expected output channels.
Returns:
dict[str, tuple[Any, Any]]: The channel name and its corresponding type
hint.
"""
labels = cls._get_output_labels()
try:
type_hints = cls._type_hints()["return"]
if len(labels) > 1:
type_hints = get_args(type_hints)
if not isinstance(type_hints, tuple):
raise TypeError(
f"With multiple return labels expected to get a tuple of type "
f"hints, but got type {type(type_hints)}"
)
if len(type_hints) != len(labels):
raise ValueError(
f"Expected type hints and return labels to have matching "
f"lengths, but got {len(type_hints)} hints and "
f"{len(labels)} labels: {type_hints}, {labels}"
)
else:
# If there's only one hint, wrap it in a tuple, so we can zip it with
# *return_labels and iterate over both at once
type_hints = (type_hints,)
except KeyError: # If there are no return hints
type_hints = [None] * len(labels)
# Note that this nicely differs from `NoneType`, which is the hint when
# `None` is actually the hint!
return {label: hint for label, hint in zip(labels, type_hints)}
def _io_defining_function(cls) -> callable:
return cls.node_function

@classmethod
def _get_output_labels(cls):
"""
Return output labels provided on the class if not None, else scrape them from
:meth:`node_function`.
Note: When the user explicitly provides output channels, they are taking
responsibility that these are correct, e.g. in terms of quantity, order, etc.
"""
if cls._provided_output_labels is None:
return cls._scrape_output_labels()
else:
return cls._provided_output_labels

@classmethod
def _scrape_output_labels(cls):
"""
Inspect :meth:`node_function` to scrape out strings representing the
returned values.
_Only_ works for functions with a single `return` expression in their body.
It will return expressions and function calls just fine, thus good practice is
to create well-named variables and return those so that the output labels stay
dot-accessible.
"""
parsed_outputs = ParseOutput(cls.node_function).output
return [None] if parsed_outputs is None else parsed_outputs
def preview_outputs(cls) -> dict[str, Any]:
preview = super(Function, cls).preview_outputs()
return preview if len(preview) > 0 else {"None": type(None)}
# If clause facilitates functions with no return value

@property
def outputs(self) -> Outputs:
Expand All @@ -414,49 +350,9 @@ def _build_output_channels(self):
owner=self,
type_hint=hint,
)
for label, hint in self.preview_output_channels().items()
for label, hint in self.preview_outputs().items()
]

@classmethod
def preview_input_channels(cls) -> dict[str, tuple[Any, Any]]:
"""
Gives a class-level peek at the expected input channels.
Returns:
dict[str, tuple[Any, Any]]: The channel name and a tuple of its
corresponding type hint and default value.
"""
type_hints = cls._type_hints()
scraped: dict[str, tuple[Any, Any]] = {}
for label, param in cls._input_args().items():
if label in cls._init_keywords():
# We allow users to parse arbitrary kwargs as channel initialization
# So don't let them choose bad channel names
raise ValueError(
f"The Input channel name {label} is not valid. Please choose a "
f"name _not_ among {cls._init_keywords()}"
)

try:
type_hint = type_hints[label]
except KeyError:
type_hint = None

default = (
NOT_DATA if param.default is inspect.Parameter.empty else param.default
)

scraped[label] = (type_hint, default)
return scraped

@classmethod
def _input_args(cls):
return inspect.signature(cls.node_function).parameters

@classmethod
def _init_keywords(cls):
return list(inspect.signature(cls.__init__).parameters.keys())

@property
def inputs(self) -> Inputs:
if self._inputs is None:
Expand All @@ -471,7 +367,7 @@ def _build_input_channels(self):
default=default,
type_hint=type_hint,
)
for label, (type_hint, default) in self.preview_input_channels().items()
for label, (type_hint, default) in self.preview_inputs().items()
]

@property
Expand All @@ -495,7 +391,7 @@ def process_run_result(self, function_output: Any | tuple) -> Any | tuple:
return function_output

def _convert_input_args_and_kwargs_to_input_kwargs(self, *args, **kwargs):
reverse_keys = list(self._input_args().keys())[::-1]
reverse_keys = list(self._get_input_args().keys())[::-1]
if len(args) > len(reverse_keys):
raise ValueError(
f"Received {len(args)} positional arguments, but the node {self.label}"
Expand Down Expand Up @@ -559,6 +455,9 @@ def color(self) -> str:
return SeabornColors.green


as_function_node = decorated_node_decorator_factory(Function, Function.node_function)


def function_node(
node_function: callable,
*args,
Expand All @@ -569,6 +468,7 @@ def function_node(
storage_backend: Optional[Literal["h5io", "tinybase"]] = None,
save_after_run: bool = False,
output_labels: Optional[str | tuple[str]] = None,
validate_output_labels: bool = True,
**kwargs,
):
"""
Expand Down Expand Up @@ -604,7 +504,9 @@ def function_node(
elif isinstance(output_labels, str):
output_labels = (output_labels,)

return as_function_node(*output_labels)(node_function)(
return as_function_node(
*output_labels, validate_output_labels=validate_output_labels
)(node_function)(
*args,
label=label,
parent=parent,
Expand All @@ -614,67 +516,3 @@ def function_node(
save_after_run=save_after_run,
**kwargs,
)


def as_function_node(*output_labels: str):
"""
A decorator for dynamically creating node classes from functions.
Decorates a function.
Returns a `Function` subclass whose name is the camel-case version of the function
node, and whose signature is modified to exclude the node function and output labels
(which are explicitly defined in the process of using the decorator).
Args:
*output_labels (str): A name for each return value of the node function OR an
empty tuple. When empty, scrapes output labels automatically from the
source code of the wrapped function. This can be useful when returned
values are not well named, e.g. to make the output channel dot-accessible
if it would otherwise have a label that requires item-string-based access.
Additionally, specifying a _single_ label for a wrapped function that
returns a tuple of values ensures that a _single_ output channel (holding
the tuple) is created, instead of one channel for each return value. The
default approach of extracting labels from the function source code also
requires that the function body contain _at most_ one `return` expression,
so providing explicit labels can be used to circumvent this
(at your own risk), or to circumvent un-inspectable source code (e.g. a
function that exists only in memory).
"""
output_labels = None if len(output_labels) == 0 else output_labels

# One really subtle thing is that we manually parse the function type hints right
# here and include these as a class-level attribute.
# This is because on (de)(cloud)pickling a function node, somehow the node function
# method attached to it gets its `__globals__` attribute changed; it retains stuff
# _inside_ the function, but loses imports it used from the _outside_ -- i.e. type
# hints! I (@liamhuber) don't deeply understand _why_ (de)pickling is modifying the
# __globals__ in this way, but the result is that type hints cannot be parsed after
# the change.
# The final piece of the puzzle here is that because the node function is a _class_
# level attribute, if you (de)pickle a node, _new_ instances of that node wind up
# having their node function's `__globals__` trimmed down in this way!
# So to keep the type hint parsing working, we snag and interpret all the type hints
# at wrapping time, when we are guaranteed to have all the globals available, and
# also slap them on as a class-level attribute. These get safely packed and returned
# when (de)pickling so we can keep processing type hints without trouble.
def as_node(node_function: callable):
node_class = type(
node_function.__name__,
(Function,), # Define parentage
{
"node_function": staticmethod(node_function),
"_provided_output_labels": output_labels,
"__module__": node_function.__module__,
},
)
try:
node_class.preview_output_channels()
except ValueError as e:
raise ValueError(
f"Failed to create a new {Function.__name__} child class "
f"dynamically from {node_function.__name__} -- probably due to a "
f"mismatch among output labels, returned values, and return type hints."
) from e
return node_class

return as_node
2 changes: 1 addition & 1 deletion pyiron_workflow/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@

from __future__ import annotations

import warnings
from abc import ABC, abstractmethod
from typing import Any
import warnings

from pyiron_workflow.channels import (
Channel,
Expand Down
Loading

0 comments on commit 357afc2

Please sign in to comment.