From a58474193129aca58df8b6123d45bb8f5021a44a Mon Sep 17 00:00:00 2001 From: Yingge He Date: Fri, 2 Aug 2024 14:46:35 -0700 Subject: [PATCH] Workaround with L0_trt_reformat_free by removing shm checks --- .../library/tritonclient/grpc/_infer_input.py | 34 ++++++------------- .../library/tritonclient/http/_infer_input.py | 32 ++++++----------- 2 files changed, 22 insertions(+), 44 deletions(-) diff --git a/src/python/library/tritonclient/grpc/_infer_input.py b/src/python/library/tritonclient/grpc/_infer_input.py index 89d944dff..d0975b31f 100755 --- a/src/python/library/tritonclient/grpc/_infer_input.py +++ b/src/python/library/tritonclient/grpc/_infer_input.py @@ -102,31 +102,19 @@ def validate_data(self): if cnt != 1: return + # Skip due to trt reformat free tensor if "shared_memory_region" in self._input.parameters: - # Using shared memory - if self._input.datatype != "BYTES": - expected_byte_size = num_elements( - self._input.shape - ) * get_data_type_byte_size(self._input.datatype) - data_byte_size = self._input.parameters[ - "shared_memory_byte_size" - ].int64_param - if data_byte_size != expected_byte_size: - raise_error( - "input '{}' got unexpected byte size {}, expected {}".format( - self._input.name, data_byte_size, expected_byte_size - ) - ) - else: - # Not using shared memory - expected_num_elements = num_elements(self._input.shape) - data_num_elements = num_elements(self._data_shape) - if expected_num_elements != data_num_elements: - raise_error( - "input '{}' got unexpected elements count {}, expected {}".format( - self._input.name, data_num_elements, expected_num_elements - ) + return + + # Not using shared memory + expected_num_elements = num_elements(self._input.shape) + data_num_elements = num_elements(self._data_shape) + if expected_num_elements != data_num_elements: + raise_error( + "input '{}' got unexpected elements count {}, expected {}".format( + self._input.name, data_num_elements, expected_num_elements ) + ) return def set_shape(self, shape): diff --git a/src/python/library/tritonclient/http/_infer_input.py b/src/python/library/tritonclient/http/_infer_input.py index af650d3ed..e0d3f19fb 100755 --- a/src/python/library/tritonclient/http/_infer_input.py +++ b/src/python/library/tritonclient/http/_infer_input.py @@ -106,29 +106,19 @@ def validate_data(self): if cnt != 1: return + # Skip due to trt reformat free tensor if "shared_memory_region" in self._parameters: - # Using shared memory - if self._datatype != "BYTES": - expected_byte_size = num_elements( - self._shape - ) * get_data_type_byte_size(self._datatype) - data_byte_size = self._parameters["shared_memory_byte_size"] - if data_byte_size != expected_byte_size: - raise_error( - "input '{}' got unexpected byte size {}, expected {}".format( - self._name, data_byte_size, expected_byte_size - ) - ) - else: - # Not using shared memory - expected_num_elements = num_elements(self._shape) - data_num_elements = num_elements(self._data_shape) - if expected_num_elements != data_num_elements: - raise_error( - "input '{}' got unexpected elements count {}, expected {}".format( - self._name, data_num_elements, expected_num_elements - ) + return + + # Not using shared memory + expected_num_elements = num_elements(self._shape) + data_num_elements = num_elements(self._data_shape) + if expected_num_elements != data_num_elements: + raise_error( + "input '{}' got unexpected elements count {}, expected {}".format( + self._name, data_num_elements, expected_num_elements ) + ) return def set_shape(self, shape):