From 938e3b09697fec9a6d1cff9a5fab651040221ef2 Mon Sep 17 00:00:00 2001 From: Benedikt Burger <67148916+BenediktBurger@users.noreply.github.com> Date: Tue, 4 Jun 2024 11:45:09 +0200 Subject: [PATCH] Explicitly state whether to return binary values. --- pyleco/utils/listener.py | 15 +++-- pyleco/utils/message_handler.py | 65 ++++++++++--------- tests/acceptance_tests/test_director_actor.py | 8 ++- tests/utils/test_message_handler.py | 51 ++++++++------- 4 files changed, 77 insertions(+), 62 deletions(-) diff --git a/pyleco/utils/listener.py b/pyleco/utils/listener.py index 2cd1002b..b2c1ddf6 100644 --- a/pyleco/utils/listener.py +++ b/pyleco/utils/listener.py @@ -25,7 +25,7 @@ import logging from threading import Thread, Event from time import sleep -from typing import Callable, Optional +from typing import Any, Callable, Optional, Union from ..core import PROXY_SENDING_PORT, COORDINATOR_PORT from .pipe_handler import PipeHandler, CommunicatorPipe @@ -132,7 +132,7 @@ def get_communicator(self, **kwargs) -> CommunicatorPipe: kwargs.setdefault("timeout", self.timeout) return self.message_handler.get_communicator(**kwargs) - def register_rpc_method(self, method: Callable, **kwargs) -> None: + def register_rpc_method(self, method: Callable[..., Any], **kwargs) -> None: """Register a method for calling with the current message handler. If you restart the listening, you have to register the method anew. @@ -140,14 +140,21 @@ def register_rpc_method(self, method: Callable, **kwargs) -> None: self.message_handler.register_rpc_method(method=method, **kwargs) def register_binary_rpc_method( - self, method: Callable, accept_binary_input: bool = False, **kwargs + self, + method: Callable[..., Union[Any, tuple[Any, list[bytes]]]], + accept_binary_input: bool = False, + return_binary_output: bool = False, + **kwargs, ) -> None: """Register a binary method for calling with the current message handler. If you restart the listening, you have to register the method anew. """ self.message_handler.register_binary_rpc_method( - method=method, accept_binary_input=accept_binary_input, **kwargs + method=method, + accept_binary_input=accept_binary_input, + return_binary_output=return_binary_output, + **kwargs, ) def stop_listen(self) -> None: diff --git a/pyleco/utils/message_handler.py b/pyleco/utils/message_handler.py index 4a64b3b6..e88b8058 100644 --- a/pyleco/utils/message_handler.py +++ b/pyleco/utils/message_handler.py @@ -27,7 +27,7 @@ from json import JSONDecodeError import logging import time -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional, Union, TypeVar import zmq @@ -48,6 +48,9 @@ heartbeat_interval = 10 # s +ReturnValue = TypeVar("ReturnValue") + + class MessageHandler(BaseCommunicator, ExtendedComponentProtocol): """Maintain connection to the Coordinator and listen to incoming messages. @@ -135,65 +138,65 @@ def setup_socket(self, host: str, port: int, protocol: str, context: zmq.Context self.log.info(f"MessageHandler connecting to {host}:{port}") self.socket.connect(f"{protocol}://{host}:{port}") - def register_rpc_method(self, method: Callable, **kwargs) -> None: + def register_rpc_method(self, method: Callable[..., Any], **kwargs) -> None: """Register a method to be available via rpc calls.""" self.rpc.method(**kwargs)(method) - def _handle_possible_binary_return_value( - self, return_value: Union[Any, bytes, list[bytes]] - ) -> Optional[Any]: - if isinstance(return_value, (bytearray, bytes, memoryview)): - self.additional_response_payload = [return_value] - return None - elif isinstance(return_value, list) and isinstance( - return_value[0], (bytearray, bytes, memoryview) - ): - self.additional_response_payload = return_value - return None - else: - return return_value + def _handle_binary_return_value( + self, return_value: tuple[ReturnValue, list[bytes]] + ) -> ReturnValue: + self.additional_response_payload = return_value[1] + return return_value[0] + + @staticmethod + def _pass_through(return_value: ReturnValue) -> ReturnValue: + return return_value def _generate_binary_capable_method( self, - method: Callable[..., Union[Any, bytes, list[bytes]]], + method: Callable[..., Union[ReturnValue, tuple[ReturnValue, list[bytes]]]], accept_binary_input: bool = False, - ) -> Callable[..., Any]: + return_binary_output: bool = False, + ) -> Callable[..., ReturnValue]: + returner = self._handle_binary_return_value if return_binary_output else self._pass_through if accept_binary_input is True: @wraps(method) - def modified_method(*args, **kwargs): # type: ignore + def modified_method(*args, **kwargs) -> ReturnValue: # type: ignore return_value = method( *args, additional_payload=self.current_message.payload[1:], **kwargs ) - return self._handle_possible_binary_return_value(return_value=return_value) + return returner(return_value=return_value) # type: ignore else: @wraps(method) - def modified_method(*args, **kwargs): + def modified_method(*args, **kwargs) -> ReturnValue: return_value = method(*args, **kwargs) - return self._handle_possible_binary_return_value(return_value=return_value) + return returner(return_value=return_value) # type: ignore try: - modified_method.__doc__ += "\n(binary method)" # type: ignore + modified_method.__doc__ += "\n(binary method)" # type: ignore[operator] except TypeError: - modified_method.__doc__ = "binary method" - return modified_method + modified_method.__doc__ = "(binary method)" + return modified_method # type: ignore def register_binary_rpc_method( self, - method: Callable[..., Union[Any, bytes, list[bytes]]], + method: Callable[..., Union[Any, tuple[Any, list[bytes]]]], accept_binary_input: bool = False, + return_binary_output: bool = False, **kwargs, ) -> None: """Register a method which accepts binary input and/or returns binary values. - If a method should accept binary input, set the `accept_binary_input=True` and the method - must accept the additional payload as an `additional_payload` parameter. - - If a method returns a binary object or a list of binary objects, they are sent as - additional_payload with the json response value `None`. + :param accept_binary_input: the method must accept the additional payload as an + `additional_payload` parameter. + :param return_binary_output: the method must return a tuple of a JSON-able python object + (e.g. `None`) and of a list of bytes objects, to be sent as additional payload. """ modified_method = self._generate_binary_capable_method( - method=method, accept_binary_input=accept_binary_input + method=method, + accept_binary_input=accept_binary_input, + return_binary_output=return_binary_output, ) self.register_rpc_method(modified_method, **kwargs) diff --git a/tests/acceptance_tests/test_director_actor.py b/tests/acceptance_tests/test_director_actor.py index 9a6b90c6..6421312e 100644 --- a/tests/acceptance_tests/test_director_actor.py +++ b/tests/acceptance_tests/test_director_actor.py @@ -79,12 +79,14 @@ def binary_method_manually() -> None: except IndexError: pass - def binary_method_created(additional_payload: list[bytes]) -> bytes: + def binary_method_created(additional_payload: list[bytes]) -> tuple[None, list[bytes]]: """Receive binary data and return it.""" - return additional_payload[0] * 2 + return None, [additional_payload[0] * 2] actor.register_rpc_method(binary_method_manually) - actor.register_binary_rpc_method(binary_method_created, accept_binary_input=True) + actor.register_binary_rpc_method( + binary_method_created, accept_binary_input=True, return_binary_output=True + ) actor.connect() actor.rpc.method()(actor.device.triple) actor.register_device_method(actor.device.triple) diff --git a/tests/utils/test_message_handler.py b/tests/utils/test_message_handler.py index 1b8b7869..de8cd884 100644 --- a/tests/utils/test_message_handler.py +++ b/tests/utils/test_message_handler.py @@ -540,36 +540,19 @@ def test_binary_payload_sent(self, handler_b: MessageHandler): assert response.data == {"jsonrpc": "2.0", "id": 8, "result": None} -@pytest.mark.parametrize( - "return_value", - ("asfd", 123.456, ["abc", 3, 9], 90), -) -def test_handle_possible_binary_return_value_unmodified_json(handler: MessageHandler, return_value): - result = handler._handle_possible_binary_return_value(return_value) - assert result == return_value - assert handler.additional_response_payload is None - -@pytest.mark.parametrize( - "return_value, payload", - ( - (b"abcd", [b"abcd"]), - ([b"ab"], [b"ab"]), - ([b"ab", b"cd"], [b"ab", b"cd"]), - ), -) -def test_handle_possible_binary_return_value_with_binary( - handler: MessageHandler, return_value, payload -): - result = handler._handle_possible_binary_return_value(return_value) +def test_handle_binary_return_value(handler: MessageHandler): + payload = [b"abc", b"def"] + result = handler._handle_binary_return_value((None, payload)) assert result is None assert handler.additional_response_payload == payload + class Test_generate_binary_method: @pytest.fixture def binary_method(self): - def binary_method(index: int, additional_payload: list[bytes]) -> bytes: + def binary_method(index: int, additional_payload: list[bytes]) -> tuple[None, list[bytes]]: """Docstring of binary method.""" - return additional_payload[index] + return None, [additional_payload[index]] return binary_method @pytest.fixture(params=(True, False)) @@ -578,7 +561,9 @@ def modified_binary_method(self, handler: MessageHandler, binary_method, request "rec", "send", data=b"", additional_payload=[b"0", b"1", b"2", b"3"] ) self._accept_binary_input = abi = request.param - mod = handler._generate_binary_capable_method(binary_method, accept_binary_input=abi) + mod = handler._generate_binary_capable_method( + binary_method, accept_binary_input=abi, return_binary_output=True + ) self.handler = handler return mod @@ -588,6 +573,14 @@ def test_name(self, binary_method, modified_binary_method): def test_docstring(self, modified_binary_method, binary_method): assert modified_binary_method.__doc__ == binary_method.__doc__ + "\n(binary method)" + def test_docstring_without_original_docstring(self, handler: MessageHandler): + def binary_method(additional_payload): + return 7 + mod = handler._generate_binary_capable_method( + binary_method, accept_binary_input=True, return_binary_output=False + ) + assert mod.__doc__ == "(binary method)" + def test_annotation(self, modified_binary_method, binary_method): assert modified_binary_method.__annotations__ == binary_method.__annotations__ @@ -607,6 +600,16 @@ def test_functionality_args(self, modified_binary_method): assert modified_binary_method(1, [b"0", b"1", b"2", b"3"]) is None assert self.handler.additional_response_payload == [b"1"] + def test_no_binary_return(self, handler: MessageHandler): + handler.current_message = Message("rec", "send", data=b"", additional_payload=[b"0"]) + + def binary_method(additional_payload = None): + return 7 + mod = handler._generate_binary_capable_method( + binary_method, accept_binary_input=True, return_binary_output=False + ) + assert mod() == 7 + class Test_listen: @pytest.fixture