diff --git a/pyiron_workflow/channels.py b/pyiron_workflow/channels.py
index 72d9b27d..04f2d055 100644
--- a/pyiron_workflow/channels.py
+++ b/pyiron_workflow/channels.py
@@ -10,6 +10,7 @@
 import typing
 from abc import ABC, abstractmethod
+import inspect
 from warnings import warn
 
 from pyiron_workflow.has_channel import HasChannel
@@ -728,11 +729,9 @@ def __round__(self):
 
     # Because we override __getattr__ we need to get and set state for serialization
     def __getstate__(self):
-        return self.__dict__
+        return dict(self.__dict__)
 
     def __setstate__(self, state):
-        # Update instead of overriding in case some other attributes were added on the
-        # main process while a remote process was working away
         self.__dict__.update(**state)
 
 
@@ -740,6 +739,7 @@ class SignalChannel(Channel, ABC):
     """
    Signal channels give the option control execution flow by triggering callback
    functions when the channel is called.
+    Callbacks must be methods on the parent node that require no positional arguments.
 
    Inputs optionally accept an output signal on call, which output signals always
    send when they call their input connections.
@@ -755,6 +755,10 @@ def __call__(self) -> None:
         pass
 
 
+class BadCallbackError(ValueError):
+    pass
+
+
 class InputSignal(SignalChannel):
     @property
     def connection_partner_type(self):
@@ -777,7 +781,36 @@ def __init__(
                 object.
         """
         super().__init__(label=label, node=node)
-        self.callback: callable = callback
+        if self._is_node_method(callback) and self._takes_zero_arguments(callback):
+            self._callback: str = callback.__name__
+        else:
+            raise BadCallbackError(
+                f"The channel {self.label} on {self.node.label} got an unexpected "
+                f"callback: {callback}. "
+                f"Lives on node: {self._is_node_method(callback)}; "
+                f"take no args: {self._takes_zero_arguments(callback)} "
+            )
+
+    def _is_node_method(self, callback):
+        try:
+            return callback == getattr(self.node, callback.__name__)
+        except AttributeError:
+            return False
+
+    def _takes_zero_arguments(self, callback):
+        return callable(callback) and self._no_positional_args(callback)
+
+    @staticmethod
+    def _no_positional_args(func):
+        return all([
+            parameter.default != inspect.Parameter.empty
+            or parameter.kind == inspect.Parameter.VAR_KEYWORD
+            for parameter in inspect.signature(func).parameters.values()
+        ])
+
+    @property
+    def callback(self) -> callable:
+        return getattr(self.node, self._callback)
 
     def __call__(self, other: typing.Optional[OutputSignal] = None) -> None:
         self.callback()
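
For reference, the new guard boils down to: a callback is accepted only if it is reachable as an attribute of the owning node and every parameter either has a default or is **kwargs, so it can be invoked with zero arguments. A minimal standalone sketch of that signature check (the helper name is illustrative, not part of the package):

import inspect

def no_positional_args(func):
    # True when every parameter has a default or is **kwargs, i.e. the
    # callable can safely be invoked with no arguments at all
    return all(
        p.default != inspect.Parameter.empty or p.kind == inspect.Parameter.VAR_KEYWORD
        for p in inspect.signature(func).parameters.values()
    )

def ok(x=0, **kwargs):  # nothing ever has to be passed positionally
    return x

def not_ok(x):  # x has no default, so a bare call would fail
    return x

assert no_positional_args(ok)
assert not no_positional_args(not_ok)
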
diff --git a/pyiron_workflow/composite.py b/pyiron_workflow/composite.py
index 39e5a45f..8f4f0cd8 100644
--- a/pyiron_workflow/composite.py
+++ b/pyiron_workflow/composite.py
@@ -177,30 +177,26 @@ def on_run(self):
         return self.run_graph
 
     @staticmethod
-    def run_graph(_nodes: dict[Node], _starting_nodes: list[Node]):
-        for node in _starting_nodes:
+    def run_graph(_composite: Composite):
+        for node in _composite.starting_nodes:
             node.run()
-        return _nodes
+        return _composite
 
     @property
     def run_args(self) -> dict:
-        return {"_nodes": self.nodes, "_starting_nodes": self.starting_nodes}
+        return {"_composite": self}
 
     def process_run_result(self, run_output):
-        if run_output is not self.nodes:
-            # Then we probably ran on a parallel process and have an unpacked future
-            self._update_children(run_output)
+        if run_output is not self:
+            self._parse_remotely_executed_self(run_output)
         return DotDict(self.outputs.to_value_dict())
 
-    def _update_children(self, children_from_another_process: DotDict[str, Node]):
-        """
-        If you receive a new dictionary of children, e.g. from unpacking a futures
-        object of your own children you sent off to another process for computation,
-        replace your own nodes with them, and set yourself as their parent.
-        """
-        for child in children_from_another_process.values():
-            child._parent = self
-        self.nodes = children_from_another_process
+    def _parse_remotely_executed_self(self, other_self):
+        # Un-parent existing nodes before ditching them
+        for node in self:
+            node._parent = None
+        other_self.running = False  # It's done now
+        self.__setstate__(other_self.__getstate__())
 
     def disconnect_run(self) -> list[tuple[Channel, Channel]]:
         """
@@ -604,3 +600,10 @@ def tidy_working_directory(self):
         for node in self:
             node.tidy_working_directory()
         super().tidy_working_directory()
+
+    def __setstate__(self, state):
+        super().__setstate__(state)
+        # Nodes purge their _parent information in their __getstate__
+        # so return it to them:
+        for node in self:
+            node._parent = self
diff --git a/pyiron_workflow/interfaces.py b/pyiron_workflow/interfaces.py
index 90e19338..9e30390f 100644
--- a/pyiron_workflow/interfaces.py
+++ b/pyiron_workflow/interfaces.py
@@ -147,10 +147,10 @@ def __getitem__(self, item):
             ) from e
 
     def __getstate__(self):
-        return self.__dict__
+        return dict(self.__dict__)
 
     def __setstate__(self, state):
-        self.__dict__ = state
+        self.__dict__.update(**state)
 
     def register(self, package_identifier: str, domain: Optional[str] = None) -> None:
         """
diff --git a/pyiron_workflow/io.py b/pyiron_workflow/io.py
index 18ed8d01..f1cf6232 100644
--- a/pyiron_workflow/io.py
+++ b/pyiron_workflow/io.py
@@ -157,7 +157,7 @@ def to_dict(self):
 
     def __getstate__(self):
         # Compatibility with python <3.11
-        return self.__dict__
+        return dict(self.__dict__)
 
     def __setstate__(self, state):
         # Because we override getattr, we need to use __dict__ assignment directly in
diff --git a/pyiron_workflow/macro.py b/pyiron_workflow/macro.py
index 56da65fd..1790af94 100644
--- a/pyiron_workflow/macro.py
+++ b/pyiron_workflow/macro.py
@@ -7,7 +7,7 @@
 from functools import partialmethod
 import inspect
-from typing import get_type_hints, Literal, Optional
+from typing import get_type_hints, Literal, Optional, TYPE_CHECKING
 
 from bidict import bidict
 
@@ -17,6 +17,9 @@
 from pyiron_workflow.io import Outputs, Inputs
 from pyiron_workflow.output_parser import ParseOutput
 
+if TYPE_CHECKING:
+    from pyiron_workflow.channels import Channel
+
 
 class Macro(Composite):
     """
@@ -471,9 +474,40 @@ def inputs(self) -> Inputs:
     def outputs(self) -> Outputs:
         return self._outputs
 
-    def _update_children(self, children_from_another_process):
-        super()._update_children(children_from_another_process)
-        self._rebuild_data_io()
+    def _parse_remotely_executed_self(self, other_self):
+        local_connection_data = [
+            [(c, c.label, c.connections) for c in io_panel]
+            for io_panel in [
+                self.inputs,
+                self.outputs,
+                self.signals.input,
+                self.signals.output,
+            ]
+        ]
+
+        super()._parse_remotely_executed_self(other_self)
+
+        for old_data, io_panel in zip(
+            local_connection_data,
+            [self.inputs, self.outputs, self.signals.input, self.signals.output]
+            # Get fresh copies of the IO panels post-update
+        ):
+            for original_channel, label, connections in old_data:
+                new_channel = io_panel[label]  # Fetch it from the fresh IO panel
+                new_channel.connections = connections
+                for other_channel in connections:
+                    self._replace_connection(
+                        other_channel, original_channel, new_channel
+                    )
+
+    @staticmethod
+    def _replace_connection(
+        channel: Channel, old_connection: Channel, new_connection: Channel
+    ):
+        """Brute-force replace an old connection in a channel with a new one"""
+        channel.connections = [
+            c if c is not old_connection else new_connection for c in channel
+        ]
 
     def _configure_graph_execution(self):
         run_signals = self.disconnect_run()
diff --git a/pyiron_workflow/node.py b/pyiron_workflow/node.py
index 8d0b615b..9bc949b1 100644
--- a/pyiron_workflow/node.py
+++ b/pyiron_workflow/node.py
@@ -1040,7 +1040,7 @@ def replace_with(self, other: Node | type[Node]):
             warnings.warn(f"Could not replace_node {self.label}, as it has no parent.")
 
     def __getstate__(self):
-        state = self.__dict__
+        state = dict(self.__dict__)
         state["_parent"] = None
         # I am not at all confident that removing the parent here is the _right_
         # solution.
@@ -1073,7 +1073,7 @@ def __getstate__(self):
         # _but_ if the user is just passing instructions on how to _build_ an executor,
         # we'll trust that those serialize OK (this way we can, hopefully, eventually
         # support nesting executors!)
-        return self.__dict__
+        return state
 
     def __setstate__(self, state):
         # Update instead of overriding in case some other attributes were added on the
diff --git a/tests/unit/test_channels.py b/tests/unit/test_channels.py
index 341badc8..05472438 100644
--- a/tests/unit/test_channels.py
+++ b/tests/unit/test_channels.py
@@ -2,7 +2,7 @@
 
 from pyiron_workflow.channels import (
     Channel, InputData, OutputData, InputSignal, AccumulatingInputSignal, OutputSignal,
-    NotData, ChannelConnectionError
+    NotData, ChannelConnectionError, BadCallbackError
 )
 
 
@@ -15,7 +15,6 @@ def __init__(self):
     def update(self):
         self.foo.append(self.foo[-1] + 1)
 
-
 class InputChannel(Channel):
     """Just to de-abstract the base class"""
     def __str__(self):
@@ -451,6 +450,54 @@ def test_aggregating_call(self):
             msg="All signals, including vestigial ones, should get cleared on call"
         )
 
+    def test_callbacks(self):
+        class Extended(DummyNode):
+            def method_with_args(self, x):
+                return x + 1
+
+            def method_with_only_kwargs(self, x=0):
+                return x + 1
+
+            @staticmethod
+            def staticmethod_without_args():
+                return 42
+
+            @staticmethod
+            def staticmethod_with_args(x):
+                return x + 1
+
+            @classmethod
+            def classmethod_without_args(cls):
+                return 42
+
+            @classmethod
+            def classmethod_with_args(cls, x):
+                return x + 1
+
+        def doesnt_belong_to_node():
+            return 42
+
+        node = Extended()
+        with self.subTest("Callbacks that belong to the node and take no arguments"):
+            for callback in [
+                node.update,
+                node.method_with_only_kwargs,
+                node.staticmethod_without_args,
+                node.classmethod_without_args
+            ]:
+                with self.subTest(callback.__name__):
+                    InputSignal(label="inp", node=node, callback=callback)
+
+        with self.subTest("Invalid callbacks"):
+            for callback in [
+                node.method_with_args,
+                node.staticmethod_with_args,
+                node.classmethod_with_args,
+                doesnt_belong_to_node,
+            ]:
+                with self.subTest(callback.__name__):
+                    with self.assertRaises(BadCallbackError):
+                        InputSignal(label="inp", node=node, callback=callback)
 
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/tests/unit/test_io.py b/tests/unit/test_io.py
index efae3620..8eb85e61 100644
--- a/tests/unit/test_io.py
+++ b/tests/unit/test_io.py
@@ -152,14 +152,18 @@ def test_to_list(self):
 
 class TestSignalIO(unittest.TestCase):
     def setUp(self) -> None:
-        node = DummyNode()
+        class Extended(DummyNode):
+            @staticmethod
+            def do_nothing():
+                pass
+
+        node = Extended()
+
 
-        def do_nothing():
-            pass
 
         signals = Signals()
-        signals.input.run = InputSignal("run", node, do_nothing)
-        signals.input.foo = InputSignal("foo", node, do_nothing)
+        signals.input.run = InputSignal("run", node, node.do_nothing)
+        signals.input.foo = InputSignal("foo", node, node.do_nothing)
         signals.output.ran = OutputSignal("ran", node)
         signals.output.bar = OutputSignal("bar", node)
diff --git a/tests/unit/test_macro.py b/tests/unit/test_macro.py
index 4afd1f3a..3b9cdcb0 100644
--- a/tests/unit/test_macro.py
+++ b/tests/unit/test_macro.py
@@ -238,6 +238,10 @@ def test_with_executor(self):
         returned_nodes = result.result(timeout=120)  # Wait for the process to finish
         sleep(1)
+        self.assertFalse(
+            macro.running,
+            msg="Macro should be done running"
+        )
         self.assertIsNot(
             original_one,
             returned_nodes.one,
@@ -270,7 +274,7 @@ def test_with_executor(self):
         self.assertIs(
             downstream.inputs.x.connections[0],
             macro.outputs.three__result,
-            msg="The macro should still be connected to "
+            msg=f"The macro output should still be connected to downstream"
         )
         sleep(0.2)  # Give a moment for the ran signal to emit and downstream to run
         # I'm a bit surprised this sleep is necessary
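
A user-facing consequence, visible in the test_io.py change above: free functions are no longer accepted as InputSignal callbacks, so existing callbacks have to become attributes of the node that owns the signal. A hypothetical before/after sketch (CallbackOwner is invented; DummyNode stands for the helper node class used in the tests; InputSignal and BadCallbackError are the real names from pyiron_workflow.channels):

from pyiron_workflow.channels import InputSignal, BadCallbackError

class CallbackOwner(DummyNode):  # hypothetical node; DummyNode as in the tests
    @staticmethod
    def do_nothing():
        pass

node = CallbackOwner()
InputSignal(label="run", node=node, callback=node.do_nothing)  # accepted

def free_function():  # previously fine, now rejected: not an attribute of the node
    pass

try:
    InputSignal(label="free", node=node, callback=free_function)
except BadCallbackError:
    pass  # raised because the callback does not live on the node
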