Skip to content

Commit

Permalink
Merge pull request #769 from pyiron/node_input_as_args
Browse files Browse the repository at this point in the history
Node input as args
  • Loading branch information
liamhuber authored Aug 2, 2023
2 parents 33f770c + 00362ba commit f0434bd
Show file tree
Hide file tree
Showing 7 changed files with 711 additions and 138 deletions.
368 changes: 277 additions & 91 deletions notebooks/workflow_example.ipynb

Large diffs are not rendered by default.

33 changes: 32 additions & 1 deletion pyiron_contrib/workflow/composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ class Composite(Node, ABC):
By default, `run()` will be called on all owned nodes that have output connections but no
input connections (i.e. the upstream-most nodes), but this can be overridden to
specify particular nodes to use instead.
The `run()` method (and `update()`, and calling the workflow, when these result in
a run), return a new dot-accessible dictionary of keys and values created from the
composite output IO panel.
Does not specify `input` and `output` as demanded by the parent class; this
requirement is still passed on to children.
Expand Down Expand Up @@ -92,15 +95,33 @@ def __init__(
label: str,
*args,
parent: Optional[Composite] = None,
run_on_updates: bool = True,
strict_naming: bool = True,
**kwargs,
):
super().__init__(*args, label=label, parent=parent, **kwargs)
super().__init__(
*args,
label=label,
parent=parent,
run_on_updates=run_on_updates,
**kwargs
)
self.strict_naming: bool = strict_naming
self.nodes: DotDict[str:Node] = DotDict()
self.add: NodeAdder = NodeAdder(self)
self.starting_nodes: None | list[Node] = None

@property
def executor(self) -> None:
    """Composite nodes never carry an executor of their own; always `None`."""
    return None

@executor.setter
def executor(self, new_executor):
    # Only the no-op assignment of `None` is allowed; anything else would
    # imply parallel execution of the whole composite, which isn't supported.
    if new_executor is None:
        return
    raise NotImplementedError(
        "Running composite nodes with an executor is not yet supported"
    )

def to_dict(self):
return {
"label": self.label,
Expand All @@ -115,12 +136,22 @@ def upstream_nodes(self) -> list[Node]:
if node.outputs.connected and not node.inputs.connected
]

@property
def on_run(self):
    """What to execute on run: the static graph runner (it receives the
    composite itself via `run_args` rather than as a bound method)."""
    return self.run_graph

@staticmethod
def run_graph(self):
    """Run the owned graph and collect the composite's output.

    Fires `self.starting_nodes` when an explicit starting set was given,
    otherwise the upstream-most nodes; returns a dot-accessible dictionary
    of the composite output IO panel's values.
    """
    if self.starting_nodes is None:
        nodes_to_fire = self.upstream_nodes
    else:
        nodes_to_fire = self.starting_nodes
    for node in nodes_to_fire:
        node.run()
    return DotDict(self.outputs.to_value_dict())

@property
def run_args(self) -> dict:
    """Arguments for `on_run`: the graph runner only needs the composite itself."""
    return dict(self=self)

def add_node(self, node: Node, label: Optional[str] = None) -> None:
"""
Expand Down
73 changes: 61 additions & 12 deletions pyiron_contrib/workflow/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ class Function(Node):
Function nodes wrap an arbitrary python function.
Node IO, including type hints, is generated automatically from the provided
function.
Input data for the wrapped function can be provided as any valid combination of
`*arg` and `**kwarg` at both initialization and on calling the node.
On running, the function node executes this wrapped function with its current input
and uses the results to populate the node output.
Expand Down Expand Up @@ -64,6 +66,15 @@ class Function(Node):
call, such that output data gets pushed after the node stops running but before
the `ran` signal fires: run, process and push result, ran.
After a node is instantiated, its input can be updated as `*args` and/or `**kwargs`
on call.
This invokes an `update()` call, which can in turn invoke `run()` if
`run_on_updates` is set to `True`.
`run()` returns the output of the executed function, or a futures object if the
node is set to use an executor.
Calling the node or executing an `update()` returns the same thing as running, if
the node is run, or `None` if it is not set to run on updates or not ready to run.
Args:
node_function (callable): The function determining the behaviour of the node.
label (str): The node's label. (Defaults to the node function's name.)
Expand Down Expand Up @@ -155,6 +166,14 @@ class Function(Node):
>>> plus_minus_1.outputs.to_value_dict()
{'p1': 2, 'm1': 1}
Input data can be provided to both initialization and on call as ordered args
or keyword kwargs.
When running, updating, or calling the node, the output of the wrapped function
(if it winds up getting run in the conditional cases of updating and calling) is
returned:
>>> plus_minus_1(2, y=3)
(3, 2)
Finally, we might stop these updates from happening automatically, even when
all the input data is present and available:
>>> plus_minus_1 = Function(
Expand All @@ -167,8 +186,7 @@ class Function(Node):
With these flags set, the node requires us to manually call a run:
>>> plus_minus_1.run()
>>> plus_minus_1.outputs.to_value_dict()
{'p1': 1, 'm1': -1}
(-1, 1)
So function nodes have the most basic level of protection that they won't run
if they haven't seen any input data.
Expand Down Expand Up @@ -335,6 +353,7 @@ class Function(Node):
def __init__(
self,
node_function: callable,
*args,
label: Optional[str] = None,
run_on_updates: bool = True,
update_on_instantiation: bool = True,
Expand All @@ -346,6 +365,7 @@ def __init__(
super().__init__(
label=label if label is not None else node_function.__name__,
parent=parent,
run_on_updates=run_on_updates,
# **kwargs,
)

Expand All @@ -365,14 +385,7 @@ def __init__(
)
self._verify_that_channels_requiring_update_all_exist()

self.run_on_updates = False
# Temporarily disable running on updates to set all initial values at once
for k, v in kwargs.items():
if k in self.inputs.labels:
self.inputs[k] = v
elif k not in self._init_keywords:
warnings.warn(f"The keyword '{k}' was received but not used.")
self.run_on_updates = run_on_updates # Restore provided value
self._batch_update_input(*args, **kwargs)

if update_on_instantiation:
self.update()
Expand Down Expand Up @@ -527,6 +540,12 @@ def on_run(self):
def run_args(self) -> dict:
    """Assemble kwargs for the wrapped function: the current input values, plus
    the node instance itself when the function signature takes `self`."""
    call_kwargs = self.inputs.to_value_dict()
    if "self" in self._input_args:
        if self.executor is not None:
            # The node instance can't sensibly be shipped to another process.
            raise NotImplementedError(
                f"The node {self.label} cannot be run on an executor because it "
                f"uses the `self` argument and this functionality is not yet "
                f"implemented"
            )
        call_kwargs["self"] = self
    return call_kwargs

Expand All @@ -551,8 +570,34 @@ def process_run_result(self, function_output):
for out, value in zip(self.outputs, function_output):
out.update(value)

def __call__(self) -> None:
self.run()
def _convert_input_args_and_kwargs_to_input_kwargs(self, *args, **kwargs):
reverse_keys = list(self._input_args.keys())[::-1]
if len(args) > len(reverse_keys):
raise ValueError(
f"Received {len(args)} positional arguments, but the node {self.label}"
f"only accepts {len(reverse_keys)} inputs."
)

positional_keywords = reverse_keys[-len(args):] if len(args) > 0 else [] # -0:
if len(set(positional_keywords).intersection(kwargs.keys())) > 0:
raise ValueError(
f"Cannot use {set(positional_keywords).intersection(kwargs.keys())} "
f"as both positional _and_ keyword arguments; args {args}, kwargs {kwargs}, reverse_keys {reverse_keys}, positional_keyworkds {positional_keywords}"
)

for arg in args:
key = positional_keywords.pop()
kwargs[key] = arg

return kwargs

def _batch_update_input(self, *args, **kwargs):
    """Batch-assign input, additionally accepting positional values mapped onto
    the input channels in function-signature order."""
    as_kwargs = self._convert_input_args_and_kwargs_to_input_kwargs(*args, **kwargs)
    return super()._batch_update_input(**as_kwargs)

def __call__(self, *args, **kwargs) -> "Any | tuple | Future | None":
    """
    Update input from positional and/or keyword values, then delegate to the
    parent call (batch-update plus `update()`).

    Returns:
        Whatever the parent's call returns: the run output (possibly a
        futures object) if the node winds up running, otherwise `None`.
    """
    as_kwargs = self._convert_input_args_and_kwargs_to_input_kwargs(*args, **kwargs)
    return super().__call__(**as_kwargs)

def to_dict(self):
return {
Expand All @@ -577,6 +622,7 @@ class Slow(Function):
def __init__(
self,
node_function: callable,
*args,
label: Optional[str] = None,
run_on_updates=False,
update_on_instantiation=False,
Expand All @@ -586,6 +632,7 @@ def __init__(
):
super().__init__(
node_function,
*args,
label=label,
run_on_updates=run_on_updates,
update_on_instantiation=update_on_instantiation,
Expand All @@ -608,6 +655,7 @@ class SingleValue(Function, HasChannel):
def __init__(
self,
node_function: callable,
*args,
label: Optional[str] = None,
run_on_updates=True,
update_on_instantiation=True,
Expand All @@ -617,6 +665,7 @@ def __init__(
):
super().__init__(
node_function,
*args,
label=label,
run_on_updates=run_on_updates,
update_on_instantiation=update_on_instantiation,
Expand Down
54 changes: 46 additions & 8 deletions pyiron_contrib/workflow/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@

from __future__ import annotations

import warnings
from abc import ABC, abstractmethod
from concurrent.futures import Future
from typing import Optional, TYPE_CHECKING
from typing import Any, Optional, TYPE_CHECKING

from pyiron_contrib.executors import CloudpickleProcessPoolExecutor
from pyiron_contrib.workflow.files import DirectoryObject
Expand Down Expand Up @@ -44,6 +45,16 @@ class Node(HasToDict, ABC):
By default, nodes' signals input comes with `run` and `ran` IO ports which force
the `run()` method and which emit after `finish_run()` is completed, respectively.
The `run()` method returns a representation of the node output (possibly a futures
object, if the node is running on an executor), and consequently `update()` also
returns this output if the node is `ready` and has `run_on_updates = True`.
Calling an already instantiated node allows its input channels to be updated using
keyword arguments corresponding to the channel labels, performing a batch-update of
all supplied input and then calling `update()`.
As such, calling the node _also_ returns a representation of the output (or `None`
if the node is not set to run on updates, or is otherwise unready to run).
Nodes have a status, which is currently represented by the `running` and `failed`
boolean flags.
Their value is controlled automatically in the defined `run` and `finish_run`
Expand Down Expand Up @@ -153,7 +164,7 @@ def outputs(self) -> Outputs:

@property
@abstractmethod
def on_run(self) -> callable[..., tuple]:
def on_run(self) -> callable[..., Any | tuple]:
"""
What the node actually does!
"""
Expand All @@ -166,7 +177,7 @@ def run_args(self) -> dict:
"""
return {}

def process_run_result(self, run_output: tuple) -> None:
def process_run_result(self, run_output: Any | tuple) -> None:
"""
What to _do_ with the results of `on_run` once you have them.
Expand All @@ -175,7 +186,7 @@ def process_run_result(self, run_output: tuple) -> None:
"""
pass

def run(self) -> None:
def run(self) -> Any | tuple | Future:
"""
Executes the functionality of the node defined in `on_run`.
Handles the status of the node, and communicating with any remote
Expand All @@ -194,18 +205,19 @@ def run(self) -> None:
self.running = False
self.failed = True
raise e
self.finish_run(run_output)
return self.finish_run(run_output)
elif isinstance(self.executor, CloudpickleProcessPoolExecutor):
self.future = self.executor.submit(self.on_run, **self.run_args)
self.future.add_done_callback(self.finish_run)
return self.future
else:
raise NotImplementedError(
"We currently only support executing the node functionality right on "
"the main python process or with a "
"pyiron_contrib.workflow.util.CloudpickleProcessPoolExecutor."
)

def finish_run(self, run_output: tuple | Future):
def finish_run(self, run_output: tuple | Future) -> Any | tuple:
"""
Switch the node status, process the run result, then fire the ran signal.
Expand All @@ -223,6 +235,7 @@ def finish_run(self, run_output: tuple | Future):
try:
self.process_run_result(run_output)
self.signals.output.ran()
return run_output
except Exception as e:
self.failed = True
raise e
Expand All @@ -233,9 +246,9 @@ def _build_signal_channels(self) -> Signals:
signals.output.ran = OutputSignal("ran", self)
return signals

def update(self) -> None:
def update(self) -> Any | tuple | Future | None:
if self.run_on_updates and self.ready:
self.run()
return self.run()

@property
def working_directory(self):
Expand Down Expand Up @@ -275,3 +288,28 @@ def fully_connected(self):
and self.outputs.fully_connected
and self.signals.fully_connected
)

def _batch_update_input(self, **kwargs):
"""
Temporarily disable running on updates to set all input values at once.
Args:
**kwargs: input label - input value (including channels for connection)
pairs.
"""
run_on_updates, self.run_on_updates = self.run_on_updates, False
for k, v in kwargs.items():
if k in self.inputs.labels:
self.inputs[k] = v
else:
warnings.warn(
f"The keyword '{k}' was not found among input labels. If you are "
f"trying to update a node keyword, please use attribute assignment "
f"directly instead of calling, e.g. "
f"`my_node_instance.run_on_updates = False`."
)
self.run_on_updates = run_on_updates # Restore provided value

def __call__(self, **kwargs) -> None:
self._batch_update_input(**kwargs)
return self.update()
27 changes: 25 additions & 2 deletions pyiron_contrib/workflow/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,18 @@ class Workflow(Composite):
>>> print(wf.outputs.second_y.value)
2
These input keys can be used when calling the workflow to update the input. In
our example, the nodes update automatically when their input gets updated, so
all we need to do to see updated workflow output is update the input:
>>> out = wf(first_x=10)
>>> out
{'second_y': 12}
Note: this _looks_ like a dictionary, but has some extra convenience that we
can dot-access data:
>>> out.second_y
12
Workflows also give access to packages of pre-built nodes under different
namespaces, e.g.
>>> wf = Workflow("with_prebuilt")
Expand Down Expand Up @@ -118,8 +130,19 @@ class Workflow(Composite):
integrity of workflows when they're used somewhere else?
"""

def __init__(self, label: str, *nodes: Node, strict_naming=True):
super().__init__(label=label, parent=None, strict_naming=strict_naming)
def __init__(
self,
label: str,
*nodes: Node,
run_on_updates: bool = True,
strict_naming=True
):
super().__init__(
label=label,
parent=None,
run_on_updates=run_on_updates,
strict_naming=strict_naming,
)

for node in nodes:
self.add_node(node)
Expand Down
Loading

0 comments on commit f0434bd

Please sign in to comment.