Skip to content

Commit

Permalink
Refactor as_json in asyncio client (#370)
Browse files Browse the repository at this point in the history
* Refactor as_json in asyncio client

* Add unit testing for asyncio client

* Remove the testing since it already exists

* Fix up
  • Loading branch information
Tabrizian authored Jul 28, 2023
1 parent 5eaca7a commit f7c45d3
Showing 1 changed file with 17 additions and 66 deletions.
83 changes: 17 additions & 66 deletions src/python/library/tritonclient/grpc/aio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,12 @@ def __init__(
self._client_stub = service_pb2_grpc.GRPCInferenceServiceStub(self._channel)
self._verbose = verbose

def _return_response(self, response, as_json):
if as_json:
return json.loads(MessageToJson(response, preserving_proto_field_name=True))
else:
return response

async def __aenter__(self):
return self

Expand Down Expand Up @@ -198,12 +204,7 @@ async def get_server_metadata(self, headers=None, as_json=False):
)
if self._verbose:
print(response)
if as_json:
return json.loads(
MessageToJson(response, preserving_proto_field_name=True)
)
else:
return response
return self._return_response(response, as_json)
except grpc.RpcError as rpc_error:
raise_error_grpc(rpc_error)

Expand All @@ -225,12 +226,7 @@ async def get_model_metadata(
)
if self._verbose:
print(response)
if as_json:
return json.loads(
MessageToJson(response, preserving_proto_field_name=True)
)
else:
return response
return self._return_response(response, as_json)
except grpc.RpcError as rpc_error:
raise_error_grpc(rpc_error)

Expand All @@ -252,12 +248,7 @@ async def get_model_config(
)
if self._verbose:
print(response)
if as_json:
return json.loads(
MessageToJson(response, preserving_proto_field_name=True)
)
else:
return response
return self._return_response(response, as_json)
except grpc.RpcError as rpc_error:
raise_error_grpc(rpc_error)

Expand All @@ -277,12 +268,7 @@ async def get_model_repository_index(self, headers=None, as_json=False):
)
if self._verbose:
print(response)
if as_json:
return json.loads(
MessageToJson(response, preserving_proto_field_name=True)
)
else:
return response
return self._return_response(response, as_json)
except grpc.RpcError as rpc_error:
raise_error_grpc(rpc_error)

Expand Down Expand Up @@ -349,12 +335,7 @@ async def get_inference_statistics(
)
if self._verbose:
print(response)
if as_json:
return json.loads(
MessageToJson(response, preserving_proto_field_name=True)
)
else:
return response
return self._return_response(response, as_json)
except grpc.RpcError as rpc_error:
raise_error_grpc(rpc_error)

Expand Down Expand Up @@ -384,12 +365,7 @@ async def update_trace_settings(
)
if self._verbose:
print(response)
if as_json:
return json.loads(
MessageToJson(response, preserving_proto_field_name=True)
)
else:
return response
return self._return_response(response, as_json)
except grpc.RpcError as rpc_error:
raise_error_grpc(rpc_error)

Expand All @@ -407,12 +383,7 @@ async def get_trace_settings(self, model_name=None, headers=None, as_json=False)
)
if self._verbose:
print(response)
if as_json:
return json.loads(
MessageToJson(response, preserving_proto_field_name=True)
)
else:
return response
return self._return_response(response, as_json)
except grpc.RpcError as rpc_error:
raise_error_grpc(rpc_error)

Expand All @@ -439,12 +410,7 @@ async def update_log_settings(self, settings, headers=None, as_json=False):
)
if self._verbose:
print(response)
if as_json:
return json.loads(
MessageToJson(response, preserving_proto_field_name=True)
)
else:
return response
return self._return_response(response, as_json)
except grpc.RpcError as rpc_error:
raise_error_grpc(rpc_error)

Expand All @@ -460,12 +426,7 @@ async def get_log_settings(self, headers=None, as_json=False):
)
if self._verbose:
print(response)
if as_json:
return json.loads(
MessageToJson(response, preserving_proto_field_name=True)
)
else:
return response
return self._return_response(response, as_json)
except grpc.RpcError as rpc_error:
raise_error_grpc(rpc_error)

Expand All @@ -487,12 +448,7 @@ async def get_system_shared_memory_status(
)
if self._verbose:
print(response)
if as_json:
return json.loads(
MessageToJson(response, preserving_proto_field_name=True)
)
else:
return response
return self._return_response(response, as_json)
except grpc.RpcError as rpc_error:
raise_error_grpc(rpc_error)

Expand Down Expand Up @@ -562,12 +518,7 @@ async def get_cuda_shared_memory_status(
)
if self._verbose:
print(response)
if as_json:
return json.loads(
MessageToJson(response, preserving_proto_field_name=True)
)
else:
return response
return self._return_response(response, as_json)
except grpc.RpcError as rpc_error:
raise_error_grpc(rpc_error)

Expand Down

0 comments on commit f7c45d3

Please sign in to comment.