diff --git a/CHANGELOG.md b/CHANGELOG.md index 8dac82fb..123e3562 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,12 @@ # CHANGELOG +## [unreleased] + +### Added + +* Add convenience functions for using additional frames for binary payload ([#82](https://github.com/pymeasure/pyleco/pull/82)) + + ## [0.3.2] 2024-5-07 ### Fixed diff --git a/pyleco/core/data_message.py b/pyleco/core/data_message.py index e61476c0..3737bf0d 100644 --- a/pyleco/core/data_message.py +++ b/pyleco/core/data_message.py @@ -24,7 +24,7 @@ from __future__ import annotations from json import JSONDecodeError -from typing import Any, Optional, Union +from typing import Any, Iterable, Optional, Union from .serialization import deserialize_data, generate_conversation_id, serialize_data, MessageTypes @@ -42,6 +42,7 @@ def __init__(self, data: Optional[Union[bytes, str, Any]] = None, conversation_id: Optional[bytes] = None, message_type: Union[MessageTypes, int] = MessageTypes.NOT_DEFINED, + additional_payload: Optional[Iterable[bytes]] = None, **kwargs) -> None: super().__init__(**kwargs) self.topic = topic.encode() if isinstance(topic, str) else topic @@ -61,6 +62,8 @@ def __init__(self, self.payload = [] else: self.payload = [serialize_data(data)] + if additional_payload is not None: + self.payload.extend(additional_payload) @classmethod def from_frames(cls, topic: bytes, header: bytes, *payload: bytes): @@ -71,8 +74,7 @@ def from_frames(cls, topic: bytes, header: bytes, *payload: bytes): frames = socket.recv_multipart() message = DataMessage.from_frames(*frames) """ - message = cls(topic=topic, header=header) - message.payload = list(payload) + message = cls(topic=topic, header=header, additional_payload=payload) return message def to_frames(self) -> list[bytes]: diff --git a/pyleco/core/internal_protocols.py b/pyleco/core/internal_protocols.py index 60ce25e2..7c6ae246 100644 --- a/pyleco/core/internal_protocols.py +++ b/pyleco/core/internal_protocols.py @@ -59,43 +59,76 @@ def sign_out(self) -> None: ... # pragma: no cover def send_message(self, message: Message) -> None: ... # pragma: no cover - def read_message(self, conversation_id: Optional[bytes], timeout: Optional[float] = None - ) -> Message: ... # pragma: no cover + def read_message( + self, conversation_id: Optional[bytes], timeout: Optional[float] = None + ) -> Message: ... # pragma: no cover - def ask_message(self, message: Message, timeout: Optional[float] = None - ) -> Message: ... # pragma: no cover + def ask_message( + self, message: Message, timeout: Optional[float] = None + ) -> Message: ... # pragma: no cover def close(self) -> None: ... # pragma: no cover # Utilities - def send(self, - receiver: Union[bytes, str], - conversation_id: Optional[bytes] = None, - data: Optional[Any] = None, - **kwargs) -> None: + def send( + self, + receiver: Union[bytes, str], + conversation_id: Optional[bytes] = None, + data: Optional[Any] = None, + **kwargs, + ) -> None: """Send a message based on kwargs.""" - self.send_message(message=Message( - receiver=receiver, conversation_id=conversation_id, data=data, **kwargs - )) - - def ask(self, receiver: Union[bytes, str], conversation_id: Optional[bytes] = None, - data: Optional[Any] = None, - timeout: Optional[float] = None, - **kwargs) -> Message: + self.send_message( + message=Message(receiver=receiver, conversation_id=conversation_id, data=data, **kwargs) + ) + + def ask( + self, + receiver: Union[bytes, str], + conversation_id: Optional[bytes] = None, + data: Optional[Any] = None, + timeout: Optional[float] = None, + **kwargs, + ) -> Message: """Send a message based on kwargs and retrieve the response.""" - return self.ask_message(message=Message( - receiver=receiver, conversation_id=conversation_id, data=data, **kwargs), - timeout=timeout) - - def interpret_rpc_response(self, response_message: Message) -> Any: - return self.rpc_generator.get_result_from_response(response_message.payload[0]) - - def ask_rpc(self, receiver: Union[bytes, str], method: str, timeout: Optional[float] = None, - **kwargs) -> Any: + return self.ask_message( + message=Message( + receiver=receiver, conversation_id=conversation_id, data=data, **kwargs + ), + timeout=timeout, + ) + + def interpret_rpc_response( + self, response_message: Message, extract_additional_payload: bool = False + ) -> Union[Any, tuple[Any, list[bytes]]]: + """Retrieve the return value of a RPC response and optionally the additional payload.""" + result = self.rpc_generator.get_result_from_response(response_message.payload[0]) + if extract_additional_payload: + return result, response_message.payload[1:] + else: + return result + + def ask_rpc( + self, + receiver: Union[bytes, str], + method: str, + timeout: Optional[float] = None, + additional_payload: Optional[Iterable[bytes]] = None, + extract_additional_payload: bool = False, + **kwargs, + ) -> Any: + """Send a JSON-RPC request (with method \\**kwargs) and return the response value.""" string = self.rpc_generator.build_request_str(method=method, **kwargs) - response = self.ask(receiver=receiver, data=string, message_type=MessageTypes.JSON, - timeout=timeout) - return self.interpret_rpc_response(response) + response = self.ask( + receiver=receiver, + data=string, + message_type=MessageTypes.JSON, + additional_payload=additional_payload, + timeout=timeout, + ) + return self.interpret_rpc_response( + response, extract_additional_payload=extract_additional_payload + ) class SubscriberProtocol(Protocol): diff --git a/pyleco/core/leco_protocols.py b/pyleco/core/leco_protocols.py index 865961a6..788217d8 100644 --- a/pyleco/core/leco_protocols.py +++ b/pyleco/core/leco_protocols.py @@ -105,7 +105,7 @@ def shut_down(self) -> None: ... class CoordinatorProtocol(ComponentProtocol, Protocol): - """A command protocol Coordinator""" + """A command protocol Coordinator.""" def sign_in(self) -> None: ... diff --git a/pyleco/core/message.py b/pyleco/core/message.py index c50b8a76..1bcb8036 100644 --- a/pyleco/core/message.py +++ b/pyleco/core/message.py @@ -24,7 +24,7 @@ from __future__ import annotations from json import JSONDecodeError -from typing import Any, Optional, Union +from typing import Any, Iterable, Optional, Union from . import VERSION_B @@ -64,6 +64,7 @@ def __init__(self, conversation_id: Optional[bytes] = None, message_id: Optional[bytes] = None, message_type: Union[MessageTypes, int] = MessageTypes.NOT_DEFINED, + additional_payload: Optional[Iterable[bytes]] = None, ) -> None: self.receiver = receiver.encode() if isinstance(receiver, str) else receiver self.sender = sender.encode() if isinstance(sender, str) else sender @@ -81,6 +82,8 @@ def __init__(self, self.payload = [] else: self.payload = [serialize_data(data)] + if additional_payload is not None: + self.payload.extend(additional_payload) @classmethod def from_frames(cls, version: bytes, receiver: bytes, sender: bytes, header: bytes, @@ -92,9 +95,8 @@ def from_frames(cls, version: bytes, receiver: bytes, sender: bytes, header: byt frames = socket.recv_multipart() message = Message.from_frames(*frames) """ - inst = cls(receiver, sender, header=header) + inst = cls(receiver, sender, header=header, additional_payload=payload) inst.version = version - inst.payload = list(payload) return inst def to_frames(self) -> list[bytes]: diff --git a/pyleco/directors/director.py b/pyleco/directors/director.py index c2bda6b8..8e6f9fc1 100644 --- a/pyleco/directors/director.py +++ b/pyleco/directors/director.py @@ -24,7 +24,7 @@ from __future__ import annotations import logging -from typing import Any, Optional, Sequence, Union +from typing import Any, Iterable, Optional, Sequence, Union from ..core.internal_protocols import CommunicatorProtocol from ..utils.communicator import Communicator @@ -88,10 +88,9 @@ def __exit__(self, exc_type, exc_value, exc_traceback) -> None: # Message handling def ask_message(self, actor: Optional[Union[bytes, str]] = None, data: Optional[Any] = None, **kwargs) -> Message: - cid0 = generate_conversation_id() actor = self._actor_check(actor) log.debug(f"Asking {actor!r} with message '{data}'.") - response = self.communicator.ask(actor, conversation_id=cid0, data=data, **kwargs) + response = self.communicator.ask(actor, data=data, **kwargs) log.debug(f"Data '{response.data}' received.") return response @@ -113,10 +112,23 @@ def _prepare_call_action_params(self, args: tuple[Any, ...], return params # Remote control synced - def ask_rpc(self, method: str, actor: Optional[Union[bytes, str]] = None, **kwargs) -> Any: + def ask_rpc( + self, + method: str, + actor: Optional[Union[bytes, str]] = None, + additional_payload: Optional[Iterable[bytes]] = None, + extract_additional_payload: bool = False, + **kwargs, + ) -> Any: """Remotely call the `method` procedure on the `actor` and return the return value.""" receiver = self._actor_check(actor) - return self.communicator.ask_rpc(receiver=receiver, method=method, **kwargs) + return self.communicator.ask_rpc( + receiver=receiver, + method=method, + additional_payload=additional_payload, + extract_additional_payload=extract_additional_payload, + **kwargs, + ) # Component def get_rpc_capabilities(self, actor: Optional[Union[bytes, str]] = None) -> dict: @@ -165,23 +177,48 @@ def call_action(self, action: str, *args, actor: Optional[Union[bytes, str]] = N return self.ask_rpc("call_action", action=action, actor=actor, **params) # Async methods: Just send, read later. - def send(self, actor: Optional[Union[bytes, str]] = None, data=None, **kwargs) -> bytes: + def send( + self, + actor: Optional[Union[bytes, str]] = None, + data=None, + additional_payload: Optional[Iterable[bytes]] = None, + **kwargs, + ) -> bytes: """Send a request and return the conversation_id.""" actor = self._actor_check(actor) cid0 = generate_conversation_id() - self.communicator.send(actor, conversation_id=cid0, data=data, **kwargs) + self.communicator.send( + actor, conversation_id=cid0, data=data, additional_payload=additional_payload, **kwargs + ) return cid0 - def ask_rpc_async(self, method: str, actor: Optional[Union[bytes, str]] = None, - **kwargs) -> bytes: + def ask_rpc_async( + self, + method: str, + actor: Optional[Union[bytes, str]] = None, + additional_payload: Optional[Iterable[bytes]] = None, + **kwargs, + ) -> bytes: """Send a rpc request, the response can be read later with :meth:`read_rpc_response`.""" string = self.generator.build_request_str(method=method, **kwargs) - return self.send(actor=actor, data=string, message_type=MessageTypes.JSON) - - def read_rpc_response(self, conversation_id: Optional[bytes] = None, **kwargs) -> Any: + return self.send( + actor=actor, + data=string, + message_type=MessageTypes.JSON, + additional_payload=additional_payload, + ) + + def read_rpc_response( + self, + conversation_id: Optional[bytes] = None, + extract_additional_payload: bool = False, + **kwargs, + ) -> Any: """Read the response value corresponding to a request with a certain `conversation_id`.""" response_message = self.communicator.read_message(conversation_id=conversation_id, **kwargs) - return self.communicator.interpret_rpc_response(response_message=response_message) + return self.communicator.interpret_rpc_response( + response_message=response_message, extract_additional_payload=extract_additional_payload + ) # Actor def get_parameters_async(self, parameters: Union[str, Sequence[str]], diff --git a/pyleco/management/test_tasks/test_task.py b/pyleco/management/test_tasks/test_task.py index cbe11000..62f9fe7f 100644 --- a/pyleco/management/test_tasks/test_task.py +++ b/pyleco/management/test_tasks/test_task.py @@ -8,7 +8,7 @@ from pyleco.actors.actor import Actor -class FakeInstrument: +class FakeInstrument: # pragma: no cover _prop1 = 5 def __init__(self): diff --git a/pyleco/test.py b/pyleco/test.py index 8d1da5fd..edbce43c 100644 --- a/pyleco/test.py +++ b/pyleco/test.py @@ -23,7 +23,7 @@ # from __future__ import annotations -from typing import Any, Optional, Sequence, Union +from typing import Any, Iterable, Optional, Sequence, Union from .core.message import Message from .core.internal_protocols import CommunicatorProtocol @@ -219,14 +219,26 @@ def __init__(self, remote_class, **kwargs): super().__init__(**kwargs) self.remote_class = remote_class - def ask_rpc(self, method: str, actor: Optional[Union[bytes, str]] = None, **kwargs) -> Any: + def ask_rpc( + self, + method: str, + actor: Optional[Union[bytes, str]] = None, + additional_payload: Optional[Iterable[bytes]] = None, + extract_additional_payload: bool = False, + **kwargs, + ) -> Any: assert hasattr(self.remote_class, method), f"Remote class does not have method '{method}'." self.method = method self.kwargs = kwargs return self.return_value - def ask_rpc_async(self, method: str, actor: Optional[Union[bytes, str]] = None, - **kwargs) -> bytes: + def ask_rpc_async( + self, + method: str, + actor: Optional[Union[bytes, str]] = None, + additional_payload: Optional[Iterable[bytes]] = None, + **kwargs, + ) -> bytes: assert hasattr(self.remote_class, method), f"Remote class does not have method '{method}'." self.method = method self.kwargs = kwargs diff --git a/pyleco/utils/data_publisher.py b/pyleco/utils/data_publisher.py index 46add4ca..04f726c9 100644 --- a/pyleco/utils/data_publisher.py +++ b/pyleco/utils/data_publisher.py @@ -25,7 +25,7 @@ from __future__ import annotations import logging import pickle -from typing import Any, Optional, Union +from typing import Any, Iterable, Optional, Union import zmq @@ -49,13 +49,15 @@ class DataPublisher: full_name: str - def __init__(self, - full_name: str, - host: str = "localhost", - port: int = PROXY_RECEIVING_PORT, - log: Optional[logging.Logger] = None, - context: Optional[zmq.Context] = None, - **kwargs) -> None: + def __init__( + self, + full_name: str, + host: str = "localhost", + port: int = PROXY_RECEIVING_PORT, + log: Optional[logging.Logger] = None, + context: Optional[zmq.Context] = None, + **kwargs, + ) -> None: if log is None: self.log = logging.getLogger(f"{__name__}.Publisher") else: @@ -87,17 +89,22 @@ def send_message(self, message: DataMessage) -> None: """Send a data protocol message.""" self.socket.send_multipart(message.to_frames()) - def send_data(self, data: Any, - topic: Optional[Union[bytes, str]] = None, - conversation_id: Optional[bytes] = None, - message_type: Union[MessageTypes, int] = MessageTypes.NOT_DEFINED, - ) -> None: + def send_data( + self, + data: Any, + topic: Optional[Union[bytes, str]] = None, + conversation_id: Optional[bytes] = None, + message_type: Union[MessageTypes, int] = MessageTypes.NOT_DEFINED, + additional_payload: Optional[Iterable[bytes]] = None, + ) -> None: """Send the `data` via the data protocol.""" - message = DataMessage(topic=topic or self.full_name, - data=data, - conversation_id=conversation_id, - message_type=message_type - ) + message = DataMessage( + topic=topic or self.full_name, + data=data, + conversation_id=conversation_id, + message_type=message_type, + additional_payload=additional_payload, + ) self.send_message(message) def send_legacy(self, data: dict[str, Any]) -> None: diff --git a/pyleco/utils/listener.py b/pyleco/utils/listener.py index 0d93fe4b..0bdbe8f7 100644 --- a/pyleco/utils/listener.py +++ b/pyleco/utils/listener.py @@ -22,10 +22,11 @@ # THE SOFTWARE. # +from __future__ import annotations import logging from threading import Thread, Event from time import sleep -from typing import Callable, Optional +from typing import Any, Callable, Optional, Union from ..core import PROXY_SENDING_PORT, COORDINATOR_PORT from .pipe_handler import PipeHandler, CommunicatorPipe @@ -132,13 +133,31 @@ def get_communicator(self, **kwargs) -> CommunicatorPipe: kwargs.setdefault("timeout", self.timeout) return self.message_handler.get_communicator(**kwargs) - def register_rpc_method(self, method: Callable, **kwargs) -> None: + def register_rpc_method(self, method: Callable[..., Any], **kwargs) -> None: """Register a method for calling with the current message handler. If you restart the listening, you have to register the method anew. """ self.message_handler.register_rpc_method(method=method, **kwargs) + def register_binary_rpc_method( + self, + method: Callable[..., Union[Any, tuple[Any, list[bytes]]]], + accept_binary_input: bool = False, + return_binary_output: bool = False, + **kwargs, + ) -> None: + """Register a binary method for calling with the current message handler. + + If you restart the listening, you have to register the method anew. + """ + self.message_handler.register_binary_rpc_method( + method=method, + accept_binary_input=accept_binary_input, + return_binary_output=return_binary_output, + **kwargs, + ) + def stop_listen(self) -> None: """Stop the listener Thread.""" try: diff --git a/pyleco/utils/message_handler.py b/pyleco/utils/message_handler.py index 44781454..ee9e0c9a 100644 --- a/pyleco/utils/message_handler.py +++ b/pyleco/utils/message_handler.py @@ -23,10 +23,11 @@ # from __future__ import annotations +from functools import wraps from json import JSONDecodeError import logging import time -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional, Union, TypeVar import zmq @@ -47,6 +48,9 @@ heartbeat_interval = 10 # s +ReturnValue = TypeVar("ReturnValue") + + class MessageHandler(BaseCommunicator, ExtendedComponentProtocol): """Maintain connection to the Coordinator and listen to incoming messages. @@ -66,6 +70,9 @@ class MessageHandler(BaseCommunicator, ExtendedComponentProtocol): name: str + current_message: Message + additional_response_payload: Optional[list[bytes]] = None + def __init__( self, name: str, @@ -131,10 +138,81 @@ def setup_socket(self, host: str, port: int, protocol: str, context: zmq.Context self.log.info(f"MessageHandler connecting to {host}:{port}") self.socket.connect(f"{protocol}://{host}:{port}") - def register_rpc_method(self, method: Callable, **kwargs) -> None: + def register_rpc_method(self, method: Callable[..., Any], **kwargs) -> None: """Register a method to be available via rpc calls.""" self.rpc.method(**kwargs)(method) + def _handle_binary_return_value( + self, return_value: tuple[ReturnValue, list[bytes]] + ) -> ReturnValue: + self.additional_response_payload = return_value[1] + return return_value[0] + + @staticmethod + def _pass_through(return_value: ReturnValue) -> ReturnValue: + return return_value + + def _generate_binary_capable_method( + self, + method: Callable[..., Union[ReturnValue, tuple[ReturnValue, list[bytes]]]], + accept_binary_input: bool = False, + return_binary_output: bool = False, + ) -> Callable[..., ReturnValue]: + returner = self._handle_binary_return_value if return_binary_output else self._pass_through + if accept_binary_input is True: + + @wraps(method) + def modified_method(*args, **kwargs) -> ReturnValue: # type: ignore + if args: + args_l = list(args) + if args_l[-1] is None: + args_l[-1] = self.current_message.payload[1:] + else: + args_l.append(self.current_message.payload[1:]) + args = args_l # type: ignore[assignment] + else: + kwargs["additional_payload"] = self.current_message.payload[1:] + return_value = method( + *args, **kwargs + ) + return returner(return_value=return_value) # type: ignore + else: + + @wraps(method) + def modified_method(*args, **kwargs) -> ReturnValue: + return_value = method(*args, **kwargs) + return returner(return_value=return_value) # type: ignore + + doc_addition = ( + f"(binary{' input' * accept_binary_input}{' output' * return_binary_output} method)" + ) + try: + modified_method.__doc__ += "\n" + doc_addition # type: ignore[operator] + except TypeError: + modified_method.__doc__ = doc_addition + return modified_method # type: ignore + + def register_binary_rpc_method( + self, + method: Callable[..., Union[Any, tuple[Any, list[bytes]]]], + accept_binary_input: bool = False, + return_binary_output: bool = False, + **kwargs, + ) -> None: + """Register a method which accepts binary input and/or returns binary values. + + :param accept_binary_input: the method must accept the additional payload as an + `additional_payload=None` parameter (default value must be present as `None`!). + :param return_binary_output: the method must return a tuple of a JSON-able python object + (e.g. `None`) and of a list of bytes objects, to be sent as additional payload. + """ + modified_method = self._generate_binary_capable_method( + method=method, + accept_binary_input=accept_binary_input, + return_binary_output=return_binary_output, + ) + self.register_rpc_method(modified_method, **kwargs) + def register_rpc_methods(self) -> None: """Register methods for RPC.""" self.register_rpc_method(self.shut_down) @@ -251,6 +329,8 @@ def handle_json_request(self, message: Message) -> None: self.send_message(response) def process_json_message(self, message: Message) -> Message: + self.current_message = message + self.additional_response_payload = None self.log.info(f"Handling commands of {message}.") reply = self.rpc.process_request(message.payload[0]) response = Message( @@ -258,6 +338,7 @@ def process_json_message(self, message: Message) -> Message: conversation_id=message.conversation_id, message_type=MessageTypes.JSON, data=reply, + additional_payload=self.additional_response_payload ) return response diff --git a/tests/acceptance_tests/test_director_actor.py b/tests/acceptance_tests/test_director_actor.py index 74e23ce3..94c39315 100644 --- a/tests/acceptance_tests/test_director_actor.py +++ b/tests/acceptance_tests/test_director_actor.py @@ -22,6 +22,7 @@ # THE SOFTWARE. # +from __future__ import annotations import logging import threading from time import sleep @@ -69,6 +70,23 @@ def triple(self, factor: float = 1, factor2: float = 1) -> float: def start_actor(event: threading.Event): actor = Actor("actor", FakeInstrument, port=PORT) + + def binary_method_manually() -> None: + """Receive binary data and return it. Do all the binary things manually.""" + payload = actor.current_message.payload[1:] + try: + actor.additional_response_payload = [payload[0] * 2] + except IndexError: + pass + + def binary_method_created(additional_payload: list[bytes]) -> tuple[None, list[bytes]]: + """Receive binary data and return it. Create binary method by registering it.""" + return None, [additional_payload[0] * 2] + + actor.register_rpc_method(binary_method_manually) + actor.register_binary_rpc_method( + binary_method_created, accept_binary_input=True, return_binary_output=True + ) actor.connect() actor.rpc.method()(actor.device.triple) actor.register_device_method(actor.device.triple) @@ -133,3 +151,17 @@ def test_method_via_rpc2(director: Director): def test_device_method_via_rpc(director: Director): assert director.ask_rpc(method="device.triple", factor=5) == 15 + + +def test_binary_data_transfer(director: Director): + assert director.ask_rpc( + method="binary_method_manually", + additional_payload=[b"123"], + extract_additional_payload=True, + ) == (None, [b"123123"]) + + +def test_binary_data_transfer_created(director: Director): + assert director.ask_rpc( + method="binary_method_created", additional_payload=[b"123"], extract_additional_payload=True + ) == (None, [b"123123"]) diff --git a/tests/core/test_data_message.py b/tests/core/test_data_message.py index ceddb243..d443f781 100644 --- a/tests/core/test_data_message.py +++ b/tests/core/test_data_message.py @@ -54,6 +54,14 @@ def test_header_param_incompatible_with_header_element_params(self, key, value): with pytest.raises(ValueError, match="header"): DataMessage(topic="topic", header=b"whatever", **{key: value}) + def test_additional_payload(self): + message = DataMessage("topic", data=b"0", additional_payload=[b"1", b"2"]) + assert message.payload == [b"0", b"1", b"2"] + + def test_additional_payload_without_data(self): + message = DataMessage("topic", additional_payload=[b"1", b"2"]) + assert message.payload == [b"1", b"2"] + def test_data_message_str_topic(): assert DataMessage(topic="topic").topic == b"topic" diff --git a/tests/core/test_internal_protocols.py b/tests/core/test_internal_protocols.py index 7a345954..acd4c74e 100644 --- a/tests/core/test_internal_protocols.py +++ b/tests/core/test_internal_protocols.py @@ -84,6 +84,51 @@ def test_error(self, communicator: FakeCommunicator): with pytest.raises(JSONRPCError): communicator.interpret_rpc_response(message) + def test_json_binary_response(self, communicator: FakeCommunicator): + message = Message( + receiver="rec", + data={"jsonrpc": "2.0", "result": None, "id": 7}, + additional_payload=[b"abcd", b"efgh"], + ) + assert communicator.interpret_rpc_response(message, extract_additional_payload=True) == ( + None, + [ + b"abcd", + b"efgh", + ], + ) + + def test_ignore_additional_payload_if_not_desired(self, communicator: FakeCommunicator): + message = Message( + receiver="rec", + data={"jsonrpc": "2.0", "result": None, "id": 7}, + additional_payload=[b"abcd"], + ) + assert ( + communicator.interpret_rpc_response(message, extract_additional_payload=False) is None + ) + + def test_without_additional_payload_return_empty_list(self, communicator: FakeCommunicator): + message = Message( + receiver="rec", + data={"jsonrpc": "2.0", "result": None, "id": 7}, + ) + assert communicator.interpret_rpc_response(message, extract_additional_payload=True) == ( + None, + [], + ) + + def test_json_value_and_binary_payload(self, communicator: FakeCommunicator): + message = Message( + receiver="rec", + data={"jsonrpc": "2.0", "result": 6, "id": 7}, + additional_payload=[b"abcd"], + ) + assert communicator.interpret_rpc_response(message, extract_additional_payload=True) == ( + 6, + [b"abcd"], + ) + class Test_ask_rpc: response = Message(receiver="communicator", sender="rec", conversation_id=cid, @@ -112,6 +157,27 @@ def test_sent(self, communicator_asked: FakeCommunicator): "params": {'par1': 5}, })] + def test_sent_with_additional_payload(self, communicator_asked: FakeCommunicator): + communicator_asked.ask_rpc( + receiver="rec", method="test_method", par1=5, additional_payload=[b"12345"] + ) + sent = communicator_asked._s[0] + assert communicator_asked._s == [ + Message( + receiver="rec", + sender="communicator", + conversation_id=sent.conversation_id, + message_type=MessageTypes.JSON, + data={ + "jsonrpc": "2.0", + "method": "test_method", + "id": 1, + "params": {"par1": 5}, + }, + additional_payload=[b"12345"], + ) + ] + def test_read(self, communicator_asked: FakeCommunicator): result = communicator_asked.ask_rpc(receiver="rec", method="test_method", par1=5) assert result == 5 diff --git a/tests/core/test_message.py b/tests/core/test_message.py index 71b1c6a4..17a940e4 100644 --- a/tests/core/test_message.py +++ b/tests/core/test_message.py @@ -84,6 +84,14 @@ def test_message_data_str_to_binary_data(self): message = Message(b"rec", data="some string") assert message.payload[0] == b"some string" + def test_additional_binary_data(self): + message = Message(b"rec", data=b"0", additional_payload=[b"1", b"2"]) + assert message.payload == [b"0", b"1", b"2"] + + def test_additional_payload_without_data(self): + message = Message(b"rec", additional_payload=[b"1", b"2"]) + assert message.payload == [b"1", b"2"] + @pytest.mark.parametrize("key, value", (("conversation_id", b"content"), ("message_id", b"mid"), ("message_type", 7), diff --git a/tests/directors/test_director.py b/tests/directors/test_director.py index 89a523e4..e0a2b826 100644 --- a/tests/directors/test_director.py +++ b/tests/directors/test_director.py @@ -77,6 +77,19 @@ def test_default_actor(self, director: Director): assert director._actor_check("") == "actor" +def test_ask_message(director: Director): + rec = Message("director", "actor", conversation_id=cid) + director.communicator._r = [rec] # type: ignore + result = director.ask_message() + assert result == rec + sent = director.communicator._s[0] # type: ignore + assert sent == Message( + "actor", + "director", + conversation_id=cid, + ) + + def test_get_rpc_capabilities(director: Director): data = {"name": "actor", "methods": []} director.communicator._r = [ # type: ignore @@ -123,6 +136,23 @@ def test_read_rpc_response(director: Director): assert director.read_rpc_response(conversation_id=cid) == 7.5 +def test_read_binary_rpc_response(director: Director): + director.communicator._r = [ # type: ignore + Message( + "director", + "actor", + conversation_id=cid, + message_type=MessageTypes.JSON, + data={"id": 1, "result": None, "jsonrpc": "2.0"}, + additional_payload=[b"123"], + ) + ] + assert director.read_rpc_response(conversation_id=cid, extract_additional_payload=True) == ( + None, + [b"123"], + ) + + def test_get_properties_async(director: Director): properties = ["a", "some"] cid = director.get_parameters_async(parameters=properties) diff --git a/tests/utils/test_data_publisher.py b/tests/utils/test_data_publisher.py index d783c33b..c6f3575d 100644 --- a/tests/utils/test_data_publisher.py +++ b/tests/utils/test_data_publisher.py @@ -59,6 +59,13 @@ def test_call_publisher_sends(publisher: DataPublisher): assert message.payload[0] == b"data" +def test_send_data(publisher: DataPublisher): + publisher.send_data( + data=b"data", topic=b"topic", conversation_id=b"cid", additional_payload=[b"1"] + ) + assert publisher.socket._s == [[b"topic", b"cid\x00", b"data", b"1"]] + + def test_send_message(publisher: DataPublisher): message = DataMessage.from_frames(b"topic", b"header", b"data") publisher.send_message(message=message) diff --git a/tests/utils/test_message_handler.py b/tests/utils/test_message_handler.py index 553e8f5c..7ceddc31 100644 --- a/tests/utils/test_message_handler.py +++ b/tests/utils/test_message_handler.py @@ -453,8 +453,12 @@ def test_handle_corrupted_message(self, handler: MessageHandler, def test_handle_undecodable_message(self, handler: MessageHandler, caplog: pytest.LogCaptureFixture): """An invalid message should not cause the message handler to crash.""" - message = Message(b"N3.handler", b"N3.COORDINATOR", message_type=MessageTypes.JSON) - message.payload = [b"()"] + message = Message( + b"N3.handler", + b"N3.COORDINATOR", + message_type=MessageTypes.JSON, + additional_payload=[b"()"], + ) handler.socket._r = [message.to_frames()] # type: ignore handler.read_and_handle_message() assert caplog.records[-1].msg.startswith("Could not decode") @@ -494,6 +498,171 @@ def test_handle_json_not_request(self, handler: MessageHandler): assert error.message == INVALID_REQUEST.message +class Test_process_json_message_with_created_binary: + payload_in: list[bytes] + payload_out: list[bytes] + + @pytest.fixture( + params=( + # normally created binary method + {"method": "do_binary", "params": [5]}, # with a list + {"method": "do_binary", "params": {"data": 5}}, # a dictionary + # manually created binary method + {"method": "do_binary_manually", "params": [5]}, + {"method": "do_binary_manually", "params": {"data": 5}}, + ), + ids=( + "created, list", + "created, dict", + "manual, list", + "manual, dict", + ), + ) + def data(self, request): + """Create a request with a list and a dict of other parameters.""" + d = {"jsonrpc": "2.0", "id": 8} + d.update(request.param) + return d + + @pytest.fixture + def handler_b(self, handler: MessageHandler): + test_class = self + class SpecialHandler(MessageHandler): + def do_binary_manually(self, data: int) -> int: + test_class.payload_in = self.current_message.payload[1:] + self.additional_response_payload = test_class.payload_out + return data + + def do_binary( + self, data: int, additional_payload: Optional[list[bytes]] = None + ) -> tuple[int, list[bytes]]: + test_class.payload_in = additional_payload # type: ignore + return data, test_class.payload_out + + handler = SpecialHandler(name=handler_name.split(".")[1], context=FakeContext()) # type: ignore + handler.namespace = handler_name.split(".")[0] + handler.stop_event = SimpleEvent() + handler.timeout = 0.1 + + handler.register_rpc_method(handler.do_binary_manually) + handler.register_binary_rpc_method( + handler.do_binary, accept_binary_input=True, return_binary_output=True + ) + return handler + + def test_message_stored(self, handler_b: MessageHandler, data): + m_in = Message("abc", data=data, message_type=MessageTypes.JSON) + handler_b.process_json_message(m_in) + assert handler_b.current_message == m_in + + def test_empty_additional_payload(self, handler_b: MessageHandler, data): + m_in = Message("abc", data=data, message_type=MessageTypes.JSON) + handler_b.process_json_message(m_in) + assert handler_b.additional_response_payload is None + + def test_binary_payload_available(self, handler_b: MessageHandler, data): + m_in = Message( + "abc", data=data, message_type=MessageTypes.JSON, additional_payload=[b"def"] + ) + self.payload_out = [] + handler_b.process_json_message(m_in) + assert self.payload_in == [b"def"] + + def test_binary_payload_sent(self, handler_b: MessageHandler, data): + m_in = Message("abc", data=data, message_type=MessageTypes.JSON) + self.payload_out = [b"ghi"] + response = handler_b.process_json_message(m_in) + assert response.payload[1:] == [b"ghi"] + assert response.data == {"jsonrpc": "2.0", "id": 8, "result": 5} + + +def test_handle_binary_return_value(handler: MessageHandler): + payload = [b"abc", b"def"] + result = handler._handle_binary_return_value((None, payload)) + assert result is None + assert handler.additional_response_payload == payload + + +class Test_generate_binary_method: + @pytest.fixture + def binary_method(self): + def binary_method(index: int, additional_payload: list[bytes]) -> tuple[None, list[bytes]]: + """Docstring of binary method.""" + return None, [additional_payload[index]] + return binary_method + + @pytest.fixture(params=(True, False)) + def modified_binary_method(self, handler: MessageHandler, binary_method, request): + handler.current_message = Message( + "rec", "send", data=b"", additional_payload=[b"0", b"1", b"2", b"3"] + ) + self._accept_binary_input = request.param + mod = handler._generate_binary_capable_method( + binary_method, accept_binary_input=self._accept_binary_input, return_binary_output=True + ) + self.handler = handler + return mod + + def test_name(self, binary_method, modified_binary_method): + assert modified_binary_method.__name__ == binary_method.__name__ + + def test_docstring(self, modified_binary_method, binary_method): + doc_addition = ( + "(binary input output method)" + if self._accept_binary_input + else "(binary output method)" + ) + assert modified_binary_method.__doc__ == binary_method.__doc__ + "\n" + doc_addition + + @pytest.mark.parametrize( + "input, output, string", + ( + (False, False, "(binary method)"), + (True, False, "(binary input method)"), + (False, True, "(binary output method)"), + (True, True, "(binary input output method)"), + ), + ) + def test_docstring_without_original_docstring( + self, handler: MessageHandler, input, output, string + ): + def binary_method(additional_payload): + return 7 + mod = handler._generate_binary_capable_method( + binary_method, accept_binary_input=input, return_binary_output=output + ) + assert mod.__doc__ == string + + def test_annotation(self, modified_binary_method, binary_method): + assert modified_binary_method.__annotations__ == binary_method.__annotations__ + + def test_functionality_kwargs(self, modified_binary_method): + if self._accept_binary_input: + assert modified_binary_method(index=1) is None + else: + assert ( + modified_binary_method(index=1, additional_payload=[b"0", b"1", b"2", b"3"]) is None + ) + assert self.handler.additional_response_payload == [b"1"] + + def test_functionality_args(self, modified_binary_method): + if self._accept_binary_input: + assert modified_binary_method(1) is None + else: + assert modified_binary_method(1, [b"0", b"1", b"2", b"3"]) is None + assert self.handler.additional_response_payload == [b"1"] + + def test_binary_input_from_message(self, handler: MessageHandler): + handler.current_message = Message("rec", "send", data=b"", additional_payload=[b"0"]) + + def binary_method(additional_payload = None): + return 7 + mod = handler._generate_binary_capable_method( + binary_method, accept_binary_input=True, return_binary_output=False + ) + assert mod() == 7 + + class Test_listen: @pytest.fixture def handler_l(self, handler: MessageHandler, fake_cid_generation):