Commit a584741

Workaround with L0_trt_reformat_free by removing shm checks
yinggeh committed Aug 2, 2024
1 parent a9a2c1c commit a584741
Showing 2 changed files with 22 additions and 44 deletions.
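For context, a minimal sketch of the client-side situation this commit works around. The input name, region name, shape, and byte sizes below are illustrative assumptions, not taken from the commit: with TensorRT reformat-free I/O (as exercised by L0_trt_reformat_free), the shared-memory region registered for an input can legitimately be larger than num_elements(shape) * dtype_size, so the old client-side byte-size comparison in validate_data() could reject otherwise valid requests.

import tritonclient.grpc as grpcclient

# Hypothetical input: a TensorRT reformat-free tensor whose device layout is
# vectorized/padded (e.g. CHW4), so its buffer is larger than the dense size.
inp = grpcclient.InferInput("INPUT0", [1, 3, 224, 224], "FP32")

dense_byte_size = 1 * 3 * 224 * 224 * 4   # num_elements(shape) * sizeof(FP32)
padded_byte_size = 1 * 4 * 224 * 224 * 4  # channels padded from 3 to 4

# set_shared_memory() records the region name and byte size in the request
# parameters. Before this commit, validate_data() (run when the request is
# built) compared padded_byte_size against dense_byte_size and raised
# "got unexpected byte size ..."; after it, the client skips that comparison
# for shared-memory inputs.
inp.set_shared_memory("input0_region", padded_byte_size)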
34 changes: 11 additions & 23 deletions src/python/library/tritonclient/grpc/_infer_input.py
@@ -102,31 +102,19 @@ def validate_data(self):
         if cnt != 1:
             return
 
+        # Skip due to trt reformat free tensor
         if "shared_memory_region" in self._input.parameters:
             # Using shared memory
-            if self._input.datatype != "BYTES":
-                expected_byte_size = num_elements(
-                    self._input.shape
-                ) * get_data_type_byte_size(self._input.datatype)
-                data_byte_size = self._input.parameters[
-                    "shared_memory_byte_size"
-                ].int64_param
-                if data_byte_size != expected_byte_size:
-                    raise_error(
-                        "input '{}' got unexpected byte size {}, expected {}".format(
-                            self._input.name, data_byte_size, expected_byte_size
-                        )
-                    )
-        else:
-            # Not using shared memory
-            expected_num_elements = num_elements(self._input.shape)
-            data_num_elements = num_elements(self._data_shape)
-            if expected_num_elements != data_num_elements:
-                raise_error(
-                    "input '{}' got unexpected elements count {}, expected {}".format(
-                        self._input.name, data_num_elements, expected_num_elements
-                    )
-                )
+            return
+
+        # Not using shared memory
+        expected_num_elements = num_elements(self._input.shape)
+        data_num_elements = num_elements(self._data_shape)
+        if expected_num_elements != data_num_elements:
+            raise_error(
+                "input '{}' got unexpected elements count {}, expected {}".format(
+                    self._input.name, data_num_elements, expected_num_elements
+                )
+            )
         return
 
     def set_shape(self, shape):

32 changes: 11 additions & 21 deletions src/python/library/tritonclient/http/_infer_input.py
@@ -106,29 +106,19 @@ def validate_data(self):
         if cnt != 1:
             return
 
+        # Skip due to trt reformat free tensor
         if "shared_memory_region" in self._parameters:
             # Using shared memory
-            if self._datatype != "BYTES":
-                expected_byte_size = num_elements(
-                    self._shape
-                ) * get_data_type_byte_size(self._datatype)
-                data_byte_size = self._parameters["shared_memory_byte_size"]
-                if data_byte_size != expected_byte_size:
-                    raise_error(
-                        "input '{}' got unexpected byte size {}, expected {}".format(
-                            self._name, data_byte_size, expected_byte_size
-                        )
-                    )
-        else:
-            # Not using shared memory
-            expected_num_elements = num_elements(self._shape)
-            data_num_elements = num_elements(self._data_shape)
-            if expected_num_elements != data_num_elements:
-                raise_error(
-                    "input '{}' got unexpected elements count {}, expected {}".format(
-                        self._name, data_num_elements, expected_num_elements
-                    )
-                )
+            return
+
+        # Not using shared memory
+        expected_num_elements = num_elements(self._shape)
+        data_num_elements = num_elements(self._data_shape)
+        if expected_num_elements != data_num_elements:
+            raise_error(
+                "input '{}' got unexpected elements count {}, expected {}".format(
+                    self._name, data_num_elements, expected_num_elements
+                )
+            )
         return
 
     def set_shape(self, shape):
