From 92fb6d107088e3d69dbb7825ff7c4dfb26c0cdd5 Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Mon, 20 Nov 2023 08:42:49 -0800 Subject: [PATCH] Add some unit tests Signed-off-by: Daniel Widdis --- .../transport/extension_transport_request.py | 6 +-- .../protobuf/test_extension_request_proto.py | 36 +++++++++++++++ .../test_extension_transport_request.py | 45 +++++++++++++++++++ 3 files changed, 83 insertions(+), 4 deletions(-) create mode 100644 tests/protobuf/test_extension_request_proto.py create mode 100644 tests/transport/test_extension_transport_request.py diff --git a/src/opensearch_sdk_py/transport/extension_transport_request.py b/src/opensearch_sdk_py/transport/extension_transport_request.py index fc01e82..6f0ae03 100644 --- a/src/opensearch_sdk_py/transport/extension_transport_request.py +++ b/src/opensearch_sdk_py/transport/extension_transport_request.py @@ -9,7 +9,6 @@ # https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/extensions/ExtensionRequest.java -from typing import Optional from opensearch_sdk_py.protobuf.ExtensionRequestProto_pb2 import ExtensionRequest from opensearch_sdk_py.transport.request_type import RequestType @@ -22,13 +21,12 @@ class ExtensionTransportRequest(TransportRequest): def __init__( self, request_type: "RequestType", - unique_id: Optional[str] = None, + unique_id: str = "", ) -> None: super().__init__() self.er = ExtensionRequest() self.er.requestType = request_type.value - if unique_id is not None: - self.er.identity.uniqueId = unique_id + self.er.identity.uniqueId = unique_id def read_from(self, input: StreamInput) -> "ExtensionTransportRequest": super().read_from(input) diff --git a/tests/protobuf/test_extension_request_proto.py b/tests/protobuf/test_extension_request_proto.py new file mode 100644 index 0000000..c51e4d8 --- /dev/null +++ b/tests/protobuf/test_extension_request_proto.py @@ -0,0 +1,36 @@ +# +# Copyright OpenSearch Contributors +# SPDX-License-Identifier: Apache-2.0 +# +# The OpenSearch Contributors require contributions made to +# this file be licensed under the Apache-2.0 license or a +# compatible open source license. +# + +import unittest + +from opensearch_sdk_py.protobuf import ExtensionRequestProto_pb2 +from opensearch_sdk_py.transport.request_type import RequestType + + +class TestExtensionRequestProto_pb2(unittest.TestCase): + def test_extension_request(self) -> None: + request = ExtensionRequestProto_pb2.ExtensionRequest() + request.requestType = RequestType.GET_SETTINGS.value + request.identity.uniqueId = "test" + serialized_str = request.SerializeToString() + + parsed_request = ExtensionRequestProto_pb2.ExtensionRequest() + parsed_request.ParseFromString(serialized_str) + self.assertEqual(parsed_request.requestType, RequestType.GET_SETTINGS.value) + self.assertEqual(parsed_request.identity.uniqueId, "test") + + def test_extension_request_no_id(self) -> None: + request = ExtensionRequestProto_pb2.ExtensionRequest() + request.requestType = RequestType.GET_SETTINGS.value + serialized_str = request.SerializeToString() + + parsed_request = ExtensionRequestProto_pb2.ExtensionRequest() + parsed_request.ParseFromString(serialized_str) + self.assertEqual(parsed_request.requestType, RequestType.GET_SETTINGS.value) + self.assertEqual(parsed_request.identity.uniqueId, "") diff --git a/tests/transport/test_extension_transport_request.py b/tests/transport/test_extension_transport_request.py new file mode 100644 index 0000000..3fcd15d --- /dev/null +++ b/tests/transport/test_extension_transport_request.py @@ -0,0 +1,45 @@ +# +# Copyright OpenSearch Contributors +# SPDX-License-Identifier: Apache-2.0 +# +# The OpenSearch Contributors require contributions made to +# this file be licensed under the Apache-2.0 license or a +# compatible open source license. +# + +import unittest + +from opensearch_sdk_py.transport.extension_transport_request import ExtensionTransportRequest +from opensearch_sdk_py.transport.request_type import RequestType +from opensearch_sdk_py.transport.stream_input import StreamInput +from opensearch_sdk_py.transport.stream_output import StreamOutput + + +class TestExtensionTransportRequest(unittest.TestCase): + def test_extension_transport_request(self) -> None: + etr = ExtensionTransportRequest(RequestType.REQUEST_EXTENSION_ENVIRONMENT_SETTINGS, "test") + self.assertEqual(etr.er.requestType, RequestType.REQUEST_EXTENSION_ENVIRONMENT_SETTINGS.value) + self.assertEqual(etr.er.identity.uniqueId, "test") + + out = StreamOutput() + etr.write_to(out) + etr = ExtensionTransportRequest(RequestType.GET_SETTINGS) + self.assertEqual(etr.er.requestType, RequestType.GET_SETTINGS.value) + self.assertEqual(etr.er.identity.uniqueId, "") + etr.read_from(input=StreamInput(out.getvalue())) + self.assertEqual(etr.er.requestType, RequestType.REQUEST_EXTENSION_ENVIRONMENT_SETTINGS.value) + self.assertEqual(etr.er.identity.uniqueId, "test") + + def test_extension_transport_request_no_id(self) -> None: + etr = ExtensionTransportRequest(RequestType.REQUEST_EXTENSION_ENVIRONMENT_SETTINGS) + self.assertEqual(etr.er.requestType, RequestType.REQUEST_EXTENSION_ENVIRONMENT_SETTINGS.value) + self.assertEqual(etr.er.identity.uniqueId, "") + + out = StreamOutput() + etr.write_to(out) + etr = ExtensionTransportRequest(RequestType.GET_SETTINGS) + self.assertEqual(etr.er.requestType, RequestType.GET_SETTINGS.value) + self.assertEqual(etr.er.identity.uniqueId, "") + etr.read_from(input=StreamInput(out.getvalue())) + self.assertEqual(etr.er.requestType, RequestType.REQUEST_EXTENSION_ENVIRONMENT_SETTINGS.value) + self.assertEqual(etr.er.identity.uniqueId, "")