Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Callbacks as strings #174

Merged
merged 20 commits into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 42 additions & 1 deletion pyiron_workflow/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import typing
from abc import ABC, abstractmethod
import inspect
from warnings import warn

from pyiron_workflow.has_channel import HasChannel
Expand Down Expand Up @@ -740,6 +741,7 @@ class SignalChannel(Channel, ABC):
"""
Signal channels give the option control execution flow by triggering callback
functions when the channel is called.
Callbacks must be methods on the parent node that require no positional arguments.
Inputs optionally accept an output signal on call, which output signals always
send when they call their input connections.

Expand All @@ -755,6 +757,10 @@ def __call__(self) -> None:
pass


class BadCallbackError(ValueError):
pass


class InputSignal(SignalChannel):
@property
def connection_partner_type(self):
Expand All @@ -777,7 +783,42 @@ def __init__(
object.
"""
super().__init__(label=label, node=node)
self.callback: callable = callback
if self._is_node_method(callback) and self._takes_zero_arguments(callback):
self._callback: str = callback.__name__
else:
raise BadCallbackError(
f"The channel {self.label} on {self.node.label} got an unexpected "
f"callback: {callback}. "
f"Lives on node: {self._is_node_method(callback)}; "
f"take no args: {self._takes_zero_arguments(callback)} "
)

def _is_node_method(self, callback):
try:
return callback == getattr(self.node, callback.__name__)
except AttributeError:
return False

def _takes_zero_arguments(self, callback):
return callable(callback) and self._no_positional_args(callback)

@staticmethod
def _no_positional_args(func):
return (
sum(
1
for parameter in inspect.signature(func).parameters.values()
if (
parameter.default == inspect.Parameter.empty
and parameter.kind != inspect._ParameterKind.VAR_KEYWORD
)
)
== 0
)

@property
def callback(self) -> callable:
return getattr(self.node, self._callback)

def __call__(self, other: typing.Optional[OutputSignal] = None) -> None:
self.callback()
Expand Down
51 changes: 49 additions & 2 deletions tests/unit/test_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from pyiron_workflow.channels import (
Channel, InputData, OutputData, InputSignal, AccumulatingInputSignal, OutputSignal,
NotData, ChannelConnectionError
NotData, ChannelConnectionError, BadCallbackError
)


Expand All @@ -15,7 +15,6 @@ def __init__(self):
def update(self):
self.foo.append(self.foo[-1] + 1)


class InputChannel(Channel):
"""Just to de-abstract the base class"""
def __str__(self):
Expand Down Expand Up @@ -451,6 +450,54 @@ def test_aggregating_call(self):
msg="All signals, including vestigial ones, should get cleared on call"
)

def test_callbacks(self):
class Extended(DummyNode):
def method_with_args(self, x):
return x + 1

def method_with_only_kwargs(self, x=0):
return x + 1

@staticmethod
def staticmethod_without_args():
return 42

@staticmethod
def staticmethod_with_args(x):
return x + 1

@classmethod
def classmethod_without_args(cls):
return 42

@classmethod
def classmethod_with_args(cls, x):
return x + 1

def doesnt_belong_to_node():
return 42

node = Extended()
with self.subTest("Callbacks that belong to the node and take no arguments"):
for callback in [
node.update,
node.method_with_only_kwargs,
node.staticmethod_without_args,
node.classmethod_without_args
]:
with self.subTest(callback.__name__):
InputSignal(label="inp", node=node, callback=callback)

with self.subTest("Invalid callbacks"):
for callback in [
node.method_with_args,
node.staticmethod_with_args,
node.classmethod_with_args,
doesnt_belong_to_node,
]:
with self.subTest(callback.__name__):
with self.assertRaises(BadCallbackError):
InputSignal(label="inp", node=node, callback=callback)

if __name__ == '__main__':
unittest.main()
14 changes: 9 additions & 5 deletions tests/unit/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,14 +152,18 @@ def test_to_list(self):

class TestSignalIO(unittest.TestCase):
def setUp(self) -> None:
node = DummyNode()
class Extended(DummyNode):
@staticmethod
def do_nothing():
pass

node = Extended()


def do_nothing():
pass

signals = Signals()
signals.input.run = InputSignal("run", node, do_nothing)
signals.input.foo = InputSignal("foo", node, do_nothing)
signals.input.run = InputSignal("run", node, node.do_nothing)
signals.input.foo = InputSignal("foo", node, node.do_nothing)
signals.output.ran = OutputSignal("ran", node)
signals.output.bar = OutputSignal("bar", node)

Expand Down
Loading