Skip to content

Commit

Permalink
Explicitly state whether to return binary values.
Browse files Browse the repository at this point in the history
  • Loading branch information
BenediktBurger committed Jun 4, 2024
1 parent c53d71c commit 6ba7917
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 62 deletions.
16 changes: 12 additions & 4 deletions pyleco/utils/listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@
# THE SOFTWARE.
#

from __future__ import annotations
import logging
from threading import Thread, Event
from time import sleep
from typing import Callable, Optional
from typing import Any, Callable, Optional, Union

from ..core import PROXY_SENDING_PORT, COORDINATOR_PORT
from .pipe_handler import PipeHandler, CommunicatorPipe
Expand Down Expand Up @@ -132,22 +133,29 @@ def get_communicator(self, **kwargs) -> CommunicatorPipe:
kwargs.setdefault("timeout", self.timeout)
return self.message_handler.get_communicator(**kwargs)

def register_rpc_method(self, method: Callable, **kwargs) -> None:
def register_rpc_method(self, method: Callable[..., Any], **kwargs) -> None:
"""Register a method for calling with the current message handler.
If you restart the listening, you have to register the method anew.
"""
self.message_handler.register_rpc_method(method=method, **kwargs)

def register_binary_rpc_method(
self, method: Callable, accept_binary_input: bool = False, **kwargs
self,
method: Callable[..., Union[Any, tuple[Any, list[bytes]]]],
accept_binary_input: bool = False,
return_binary_output: bool = False,
**kwargs,
) -> None:
"""Register a binary method for calling with the current message handler.
If you restart the listening, you have to register the method anew.
"""
self.message_handler.register_binary_rpc_method(

Check warning on line 154 in pyleco/utils/listener.py

View check run for this annotation

Codecov / codecov/patch

pyleco/utils/listener.py#L154

Added line #L154 was not covered by tests
method=method, accept_binary_input=accept_binary_input, **kwargs
method=method,
accept_binary_input=accept_binary_input,
return_binary_output=return_binary_output,
**kwargs,
)

def stop_listen(self) -> None:
Expand Down
65 changes: 34 additions & 31 deletions pyleco/utils/message_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from json import JSONDecodeError
import logging
import time
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Optional, Union, TypeVar

import zmq

Expand All @@ -48,6 +48,9 @@
heartbeat_interval = 10 # s


ReturnValue = TypeVar("ReturnValue")


class MessageHandler(BaseCommunicator, ExtendedComponentProtocol):
"""Maintain connection to the Coordinator and listen to incoming messages.
Expand Down Expand Up @@ -135,65 +138,65 @@ def setup_socket(self, host: str, port: int, protocol: str, context: zmq.Context
self.log.info(f"MessageHandler connecting to {host}:{port}")
self.socket.connect(f"{protocol}://{host}:{port}")

def register_rpc_method(self, method: Callable, **kwargs) -> None:
def register_rpc_method(self, method: Callable[..., Any], **kwargs) -> None:
"""Register a method to be available via rpc calls."""
self.rpc.method(**kwargs)(method)

def _handle_possible_binary_return_value(
self, return_value: Union[Any, bytes, list[bytes]]
) -> Optional[Any]:
if isinstance(return_value, (bytearray, bytes, memoryview)):
self.additional_response_payload = [return_value]
return None
elif isinstance(return_value, list) and isinstance(
return_value[0], (bytearray, bytes, memoryview)
):
self.additional_response_payload = return_value
return None
else:
return return_value
def _handle_binary_return_value(
self, return_value: tuple[ReturnValue, list[bytes]]
) -> ReturnValue:
self.additional_response_payload = return_value[1]
return return_value[0]

@staticmethod
def _pass_through(return_value: ReturnValue) -> ReturnValue:
return return_value

def _generate_binary_capable_method(
self,
method: Callable[..., Union[Any, bytes, list[bytes]]],
method: Callable[..., Union[ReturnValue, tuple[ReturnValue, list[bytes]]]],
accept_binary_input: bool = False,
) -> Callable[..., Any]:
return_binary_output: bool = False,
) -> Callable[..., ReturnValue]:
returner = self._handle_binary_return_value if return_binary_output else self._pass_through
if accept_binary_input is True:

@wraps(method)
def modified_method(*args, **kwargs): # type: ignore
def modified_method(*args, **kwargs) -> ReturnValue: # type: ignore
return_value = method(
*args, additional_payload=self.current_message.payload[1:], **kwargs
)
return self._handle_possible_binary_return_value(return_value=return_value)
return returner(return_value=return_value) # type: ignore
else:

@wraps(method)
def modified_method(*args, **kwargs):
def modified_method(*args, **kwargs) -> ReturnValue:
return_value = method(*args, **kwargs)
return self._handle_possible_binary_return_value(return_value=return_value)
return returner(return_value=return_value) # type: ignore
try:
modified_method.__doc__ += "\n(binary method)" # type: ignore
modified_method.__doc__ += "\n(binary method)" # type: ignore[operator]
except TypeError:
modified_method.__doc__ = "binary method"
return modified_method
modified_method.__doc__ = "(binary method)"
return modified_method # type: ignore

def register_binary_rpc_method(
self,
method: Callable[..., Union[Any, bytes, list[bytes]]],
method: Callable[..., Union[Any, tuple[Any, list[bytes]]]],
accept_binary_input: bool = False,
return_binary_output: bool = False,
**kwargs,
) -> None:
"""Register a method which accepts binary input and/or returns binary values.
If a method should accept binary input, set the `accept_binary_input=True` and the method
must accept the additional payload as an `additional_payload` parameter.
If a method returns a binary object or a list of binary objects, they are sent as
additional_payload with the json response value `None`.
:param accept_binary_input: the method must accept the additional payload as an
`additional_payload` parameter.
:param return_binary_output: the method must return a tuple of a JSON-able python object
(e.g. `None`) and of a list of bytes objects, to be sent as additional payload.
"""
modified_method = self._generate_binary_capable_method(
method=method, accept_binary_input=accept_binary_input
method=method,
accept_binary_input=accept_binary_input,
return_binary_output=return_binary_output,
)
self.register_rpc_method(modified_method, **kwargs)

Expand Down
8 changes: 5 additions & 3 deletions tests/acceptance_tests/test_director_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,14 @@ def binary_method_manually() -> None:
except IndexError:
pass

def binary_method_created(additional_payload: list[bytes]) -> bytes:
def binary_method_created(additional_payload: list[bytes]) -> tuple[None, list[bytes]]:
"""Receive binary data and return it."""
return additional_payload[0] * 2
return None, [additional_payload[0] * 2]

actor.register_rpc_method(binary_method_manually)
actor.register_binary_rpc_method(binary_method_created, accept_binary_input=True)
actor.register_binary_rpc_method(
binary_method_created, accept_binary_input=True, return_binary_output=True
)
actor.connect()
actor.rpc.method()(actor.device.triple)
actor.register_device_method(actor.device.triple)
Expand Down
51 changes: 27 additions & 24 deletions tests/utils/test_message_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,36 +540,19 @@ def test_binary_payload_sent(self, handler_b: MessageHandler):
assert response.data == {"jsonrpc": "2.0", "id": 8, "result": None}


@pytest.mark.parametrize(
"return_value",
("asfd", 123.456, ["abc", 3, 9], 90),
)
def test_handle_possible_binary_return_value_unmodified_json(handler: MessageHandler, return_value):
result = handler._handle_possible_binary_return_value(return_value)
assert result == return_value
assert handler.additional_response_payload is None

@pytest.mark.parametrize(
"return_value, payload",
(
(b"abcd", [b"abcd"]),
([b"ab"], [b"ab"]),
([b"ab", b"cd"], [b"ab", b"cd"]),
),
)
def test_handle_possible_binary_return_value_with_binary(
handler: MessageHandler, return_value, payload
):
result = handler._handle_possible_binary_return_value(return_value)
def test_handle_binary_return_value(handler: MessageHandler):
payload = [b"abc", b"def"]
result = handler._handle_binary_return_value((None, payload))
assert result is None
assert handler.additional_response_payload == payload


class Test_generate_binary_method:
@pytest.fixture
def binary_method(self):
def binary_method(index: int, additional_payload: list[bytes]) -> bytes:
def binary_method(index: int, additional_payload: list[bytes]) -> tuple[None, list[bytes]]:
"""Docstring of binary method."""
return additional_payload[index]
return None, [additional_payload[index]]
return binary_method

@pytest.fixture(params=(True, False))
Expand All @@ -578,7 +561,9 @@ def modified_binary_method(self, handler: MessageHandler, binary_method, request
"rec", "send", data=b"", additional_payload=[b"0", b"1", b"2", b"3"]
)
self._accept_binary_input = abi = request.param
mod = handler._generate_binary_capable_method(binary_method, accept_binary_input=abi)
mod = handler._generate_binary_capable_method(
binary_method, accept_binary_input=abi, return_binary_output=True
)
self.handler = handler
return mod

Expand All @@ -588,6 +573,14 @@ def test_name(self, binary_method, modified_binary_method):
def test_docstring(self, modified_binary_method, binary_method):
assert modified_binary_method.__doc__ == binary_method.__doc__ + "\n(binary method)"

def test_docstring_without_original_docstring(self, handler: MessageHandler):
def binary_method(additional_payload):
return 7
mod = handler._generate_binary_capable_method(
binary_method, accept_binary_input=True, return_binary_output=False
)
assert mod.__doc__ == "(binary method)"

def test_annotation(self, modified_binary_method, binary_method):
assert modified_binary_method.__annotations__ == binary_method.__annotations__

Expand All @@ -607,6 +600,16 @@ def test_functionality_args(self, modified_binary_method):
assert modified_binary_method(1, [b"0", b"1", b"2", b"3"]) is None
assert self.handler.additional_response_payload == [b"1"]

def test_no_binary_return(self, handler: MessageHandler):
handler.current_message = Message("rec", "send", data=b"", additional_payload=[b"0"])

def binary_method(additional_payload = None):
return 7
mod = handler._generate_binary_capable_method(
binary_method, accept_binary_input=True, return_binary_output=False
)
assert mod() == 7


class Test_listen:
@pytest.fixture
Expand Down

0 comments on commit 6ba7917

Please sign in to comment.