Skip to content

Commit

Permalink
Add some unit tests
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Widdis <[email protected]>
  • Loading branch information
dbwiddis committed Nov 20, 2023
1 parent f9b4900 commit 92fb6d1
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

# https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/extensions/ExtensionRequest.java

from typing import Optional

from opensearch_sdk_py.protobuf.ExtensionRequestProto_pb2 import ExtensionRequest
from opensearch_sdk_py.transport.request_type import RequestType
Expand All @@ -22,13 +21,12 @@ class ExtensionTransportRequest(TransportRequest):
def __init__(
self,
request_type: "RequestType",
unique_id: Optional[str] = None,
unique_id: str = "",
) -> None:
super().__init__()
self.er = ExtensionRequest()
self.er.requestType = request_type.value
if unique_id is not None:
self.er.identity.uniqueId = unique_id
self.er.identity.uniqueId = unique_id

def read_from(self, input: StreamInput) -> "ExtensionTransportRequest":
super().read_from(input)
Expand Down
36 changes: 36 additions & 0 deletions tests/protobuf/test_extension_request_proto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#
# Copyright OpenSearch Contributors
# SPDX-License-Identifier: Apache-2.0
#
# The OpenSearch Contributors require contributions made to
# this file be licensed under the Apache-2.0 license or a
# compatible open source license.
#

import unittest

from opensearch_sdk_py.protobuf import ExtensionRequestProto_pb2
from opensearch_sdk_py.transport.request_type import RequestType


class TestExtensionRequestProto_pb2(unittest.TestCase):
def test_extension_request(self) -> None:
request = ExtensionRequestProto_pb2.ExtensionRequest()
request.requestType = RequestType.GET_SETTINGS.value
request.identity.uniqueId = "test"
serialized_str = request.SerializeToString()

parsed_request = ExtensionRequestProto_pb2.ExtensionRequest()
parsed_request.ParseFromString(serialized_str)
self.assertEqual(parsed_request.requestType, RequestType.GET_SETTINGS.value)
self.assertEqual(parsed_request.identity.uniqueId, "test")

def test_extension_request_no_id(self) -> None:
request = ExtensionRequestProto_pb2.ExtensionRequest()
request.requestType = RequestType.GET_SETTINGS.value
serialized_str = request.SerializeToString()

parsed_request = ExtensionRequestProto_pb2.ExtensionRequest()
parsed_request.ParseFromString(serialized_str)
self.assertEqual(parsed_request.requestType, RequestType.GET_SETTINGS.value)
self.assertEqual(parsed_request.identity.uniqueId, "")
45 changes: 45 additions & 0 deletions tests/transport/test_extension_transport_request.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#
# Copyright OpenSearch Contributors
# SPDX-License-Identifier: Apache-2.0
#
# The OpenSearch Contributors require contributions made to
# this file be licensed under the Apache-2.0 license or a
# compatible open source license.
#

import unittest

from opensearch_sdk_py.transport.extension_transport_request import ExtensionTransportRequest
from opensearch_sdk_py.transport.request_type import RequestType
from opensearch_sdk_py.transport.stream_input import StreamInput
from opensearch_sdk_py.transport.stream_output import StreamOutput


class TestExtensionTransportRequest(unittest.TestCase):
def test_extension_transport_request(self) -> None:
etr = ExtensionTransportRequest(RequestType.REQUEST_EXTENSION_ENVIRONMENT_SETTINGS, "test")
self.assertEqual(etr.er.requestType, RequestType.REQUEST_EXTENSION_ENVIRONMENT_SETTINGS.value)
self.assertEqual(etr.er.identity.uniqueId, "test")

out = StreamOutput()
etr.write_to(out)
etr = ExtensionTransportRequest(RequestType.GET_SETTINGS)
self.assertEqual(etr.er.requestType, RequestType.GET_SETTINGS.value)
self.assertEqual(etr.er.identity.uniqueId, "")
etr.read_from(input=StreamInput(out.getvalue()))
self.assertEqual(etr.er.requestType, RequestType.REQUEST_EXTENSION_ENVIRONMENT_SETTINGS.value)
self.assertEqual(etr.er.identity.uniqueId, "test")

def test_extension_transport_request_no_id(self) -> None:
etr = ExtensionTransportRequest(RequestType.REQUEST_EXTENSION_ENVIRONMENT_SETTINGS)
self.assertEqual(etr.er.requestType, RequestType.REQUEST_EXTENSION_ENVIRONMENT_SETTINGS.value)
self.assertEqual(etr.er.identity.uniqueId, "")

out = StreamOutput()
etr.write_to(out)
etr = ExtensionTransportRequest(RequestType.GET_SETTINGS)
self.assertEqual(etr.er.requestType, RequestType.GET_SETTINGS.value)
self.assertEqual(etr.er.identity.uniqueId, "")
etr.read_from(input=StreamInput(out.getvalue()))
self.assertEqual(etr.er.requestType, RequestType.REQUEST_EXTENSION_ENVIRONMENT_SETTINGS.value)
self.assertEqual(etr.er.identity.uniqueId, "")

0 comments on commit 92fb6d1

Please sign in to comment.