From 1f03a8de02459b6fae5b4a8a22dbe3fb1b8012d3 Mon Sep 17 00:00:00 2001 From: Benedikt Burger <67148916+BenediktBurger@users.noreply.github.com> Date: Wed, 20 Nov 2024 11:44:59 +0100 Subject: [PATCH] Improve inner workings of RPCServer. --- pyleco/json_utils/rpc_server_definition.py | 58 ++++++++++------------ tests/json_utils/test_rpc_server.py | 40 ++++++++++++++- 2 files changed, 66 insertions(+), 32 deletions(-) diff --git a/pyleco/json_utils/rpc_server_definition.py b/pyleco/json_utils/rpc_server_definition.py index f5b015ae..8440f90a 100644 --- a/pyleco/json_utils/rpc_server_definition.py +++ b/pyleco/json_utils/rpc_server_definition.py @@ -28,7 +28,7 @@ from typing import Any, Callable, Optional, Union from .errors import INTERNAL_ERROR, SERVER_ERROR, INVALID_REQUEST -from .json_objects import ResultResponse, ErrorResponse, DataError +from .json_objects import ResultResponse, ErrorResponse, DataError, ResponseType, ResponseBatch log = logging.getLogger(__name__) @@ -50,15 +50,8 @@ def __init__( self.method(name="rpc.discover")(self.discover) def method(self, name: Optional[str] = None, **kwargs) -> Callable[[Callable], None]: - if name is None: - - def method_registrar(method: Callable) -> None: - return self._register_method(name=method.__name__, method=method) - else: - - def method_registrar(method: Callable) -> None: - return self._register_method(name=name, method=method) - + def method_registrar(method: Callable) -> None: + return self._register_method(name=name or method.__name__, method=method) return method_registrar def _register_method(self, name: str, method: Callable) -> None: @@ -67,31 +60,34 @@ def _register_method(self, name: str, method: Callable) -> None: def process_request(self, data: Union[bytes, str]) -> Optional[str]: try: json_data = json.loads(data) - if isinstance(json_data, list): - results = [] - for element in json_data: - result = self._process_single_request(element) - if result is not None: - results.append(result.model_dump()) - if results: - return json.dumps(results, separators=(",", ":")) - else: - return None - elif isinstance(json_data, dict): - result = self._process_single_request(json_data) - if result: - return result.model_dump_json() - else: - return None - else: - return ErrorResponse( - id=None, - error=DataError.from_error(INVALID_REQUEST, json_data), - ).model_dump_json() + result = self.process_request_object(json_data=json_data) + return result.model_dump_json() if result else None except Exception as exc: log.exception(f"{type(exc).__name__}:", exc_info=exc) return ErrorResponse(id=None, error=INTERNAL_ERROR).model_dump_json() + def process_request_object( + self, json_data: object + ) -> Optional[Union[ResponseType, ResponseBatch]]: + result: Optional[Union[ResponseType, ResponseBatch]] + if isinstance(json_data, list): + result = ResponseBatch() + for element in json_data: + result_element = self._process_single_request(element) + if result_element is not None: + result.append(result_element) + elif isinstance(json_data, dict): + result = self._process_single_request(json_data) + else: + result = ErrorResponse( + id=None, + error=DataError.from_error(INVALID_REQUEST, json_data), + ) + if result: + return result + else: + return None + def _process_single_request( self, request: dict[str, Any] ) -> Union[ResultResponse, ErrorResponse, None]: diff --git a/tests/json_utils/test_rpc_server.py b/tests/json_utils/test_rpc_server.py index bb2fa9d5..be4e09f1 100644 --- a/tests/json_utils/test_rpc_server.py +++ b/tests/json_utils/test_rpc_server.py @@ -35,6 +35,7 @@ ResultResponse, ErrorResponse, DataError, + ResponseBatch, ) from pyleco.json_utils.errors import ( ServerError, @@ -245,7 +246,7 @@ def test_batch_entry_notification(self, rpc_server_local: RPCServer): {"jsonrpc": "2.0", "method": "simple"}, {"jsonrpc": "2.0", "method": "simple", "id": 4}, ] - result = json.loads(rpc_server_local.process_request(json.dumps(requests))) + result = json.loads(rpc_server_local.process_request(json.dumps(requests))) # type: ignore assert result == [{"jsonrpc": "2.0", "result": 7, "id": 4}] def test_batch_of_notifications(self, rpc_server_local: RPCServer): @@ -264,3 +265,40 @@ def test_notification(self, rpc_server_local: RPCServer): ] result = rpc_server_local.process_request(json.dumps(requests)) assert result is None + + +class Test_process_request_object: + def test_invalid_request(self, rpc_server_local: RPCServer): + result = rpc_server_local.process_request_object(7) + assert ( + result + == ErrorResponse( + id=None, error=DataError.from_error(INVALID_REQUEST, 7) + ) + ) + + def test_batch_entry_notification(self, rpc_server_local: RPCServer): + """A notification (request without id) shall not return anything.""" + requests = [ + {"jsonrpc": "2.0", "method": "simple"}, + {"jsonrpc": "2.0", "method": "simple", "id": 4}, + ] + result = rpc_server_local.process_request_object(requests) + assert result == ResponseBatch([ResultResponse(4, 7)]) + + def test_batch_of_notifications(self, rpc_server_local: RPCServer): + """A notification (request without id) shall not return anything.""" + requests = [ + {"jsonrpc": "2.0", "method": "simple"}, + {"jsonrpc": "2.0", "method": "simple"}, + ] + result = rpc_server_local.process_request_object(requests) + assert result is None + + def test_notification(self, rpc_server_local: RPCServer): + """A notification (request without id) shall not return anything.""" + requests = [ + {"jsonrpc": "2.0", "method": "simple"}, + ] + result = rpc_server_local.process_request_object(requests) + assert result is None