Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backend Shape and Datatype Access API #247

Merged
merged 11 commits into from
Sep 6, 2023
40 changes: 39 additions & 1 deletion include/triton/core/tritonbackend.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ struct TRITONBACKEND_Batcher;
/// }
///
#define TRITONBACKEND_API_VERSION_MAJOR 1
#define TRITONBACKEND_API_VERSION_MINOR 15
#define TRITONBACKEND_API_VERSION_MINOR 16

/// Get the TRITONBACKEND API version supported by Triton. This value
/// can be compared against the TRITONBACKEND_API_VERSION_MAJOR and
Expand Down Expand Up @@ -1569,6 +1569,44 @@ TRITONBACKEND_ISPEC TRITONSERVER_Error* TRITONBACKEND_ModelBatchInitialize(
TRITONBACKEND_ISPEC TRITONSERVER_Error* TRITONBACKEND_ModelBatchFinalize(
void* userp);

/// Get all information about an output tensor by its name. The caller does
/// not own any of the referenced return values and must not modify or delete
/// them. The lifetime of all returned values extends until 'response' is
/// deleted.
///
/// \param response The response object.
/// \param name The name of the output to look up.
/// \param datatype Returns the type of the output.
/// \param shape Returns the shape of the output.
/// \param dim_count Returns the number of dimensions of the returned
/// shape.
/// \return a TRITONSERVER_Error indicating success or failure. A
/// TRITONSERVER_ERROR_NOT_FOUND error is returned when 'response' has no
/// output called 'name'.
// NOTE(review): this declaration uses TRITONBACKEND_ISPEC, the sibling
// TRITONBACKEND_InferenceResponseOutput declaration uses
// TRITONSERVER_DECLSPEC, and the definition in src/backend_model.cc uses
// TRITONAPI_DECLSPEC -- confirm which export macro is intended for an API
// that is implemented by Triton rather than by the backend.
TRITONBACKEND_ISPEC TRITONSERVER_Error*
TRITONBACKEND_InferenceResponseOutputByName(
TRITONBACKEND_Response* response, const char* name,
TRITONSERVER_DataType* datatype, const int64_t** shape,
uint64_t* dim_count);

/// Get all information about an output tensor by its index. The caller does
/// not own any of the referenced return values and must not modify or delete
/// them. The lifetime of all returned values extends until 'response' is
/// deleted.
///
/// \param response The response object.
/// \param index The index of the output tensor, must be 0 <= index <
/// count, where 'count' is the value returned by
/// TRITONSERVER_InferenceResponseOutputCount.
/// \param name Returns the name of the output.
/// \param datatype Returns the type of the output.
/// \param shape Returns the shape of the output.
/// \param dim_count Returns the number of dimensions of the returned
/// shape.
/// \return a TRITONSERVER_Error indicating success or failure. A
/// TRITONSERVER_ERROR_INVALID_ARG error is returned when 'index' is not
/// less than the number of outputs in 'response'.
// NOTE(review): 'count' above refers to a TRITONSERVER_* function while this
// API takes a TRITONBACKEND_Response -- confirm how a backend is expected to
// obtain the output count. Also, this declaration uses TRITONSERVER_DECLSPEC
// whereas the by-name sibling uses TRITONBACKEND_ISPEC -- confirm the
// intended export macro.
TRITONSERVER_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_InferenceResponseOutput(
TRITONBACKEND_Response* response, const uint32_t index, const char** name,
TRITONSERVER_DataType* datatype, const int64_t** shape,
uint64_t* dim_count);

#ifdef __cplusplus
}
#endif
54 changes: 54 additions & 0 deletions src/backend_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1750,6 +1750,60 @@ TRITONBACKEND_BackendAttributeSetParallelModelInstanceLoading(
return nullptr;
}

// Look up an output of 'response' by name and report its datatype and shape.
//
// 'response' is reinterpreted as an InferenceResponse. On success '*shape'
// points into the shape vector held by the matching output (nothing is
// copied) and '*dim_count' is that vector's size.
//
// Returns nullptr on success, TRITONSERVER_ERROR_INVALID_ARG when 'name' is
// null, and TRITONSERVER_ERROR_NOT_FOUND when no output has the given name.
TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONBACKEND_InferenceResponseOutputByName(
    TRITONBACKEND_Response* response, const char* name,
    TRITONSERVER_DataType* datatype, const int64_t** shape, uint64_t* dim_count)
{
  // Constructing a std::string from a null pointer is undefined behaviour, so
  // reject it up front instead of crashing inside the lookup.
  if (name == nullptr) {
    return TRITONSERVER_ErrorNew(
        TRITONSERVER_ERROR_INVALID_ARG, "output name must not be null");
  }

  InferenceResponse* tr = reinterpret_cast<InferenceResponse*>(response);
  const std::string output_name(name);

  for (const auto& output : tr->Outputs()) {
    if (output.Name() == output_name) {
      *datatype = DataTypeToTriton(output.DType());
      const std::vector<int64_t>& oshape = output.Shape();
      // data() is well-defined for an empty shape, unlike &oshape[0].
      *shape = oshape.data();
      *dim_count = oshape.size();
      return nullptr;  // success
    }
  }

  return TRITONSERVER_ErrorNew(
      TRITONSERVER_ERROR_NOT_FOUND,
      ("Output name " + output_name + " not found.").c_str());
}

// Report the name, datatype and shape of the output at position 'index' in
// 'response'.
//
// 'response' is reinterpreted as an InferenceResponse. On success '*name'
// points at the output's name string and '*shape' points into the output's
// shape vector (nothing is copied); '*dim_count' is the shape vector's size.
//
// Returns nullptr on success and TRITONSERVER_ERROR_INVALID_ARG when 'index'
// is not less than the number of outputs in the response.
TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONBACKEND_InferenceResponseOutput(
    TRITONBACKEND_Response* response, const uint32_t index, const char** name,
    TRITONSERVER_DataType* datatype, const int64_t** shape, uint64_t* dim_count)
{
  InferenceResponse* tr = reinterpret_cast<InferenceResponse*>(response);

  const auto& outputs = tr->Outputs();
  if (index >= outputs.size()) {
    return TRITONSERVER_ErrorNew(
        TRITONSERVER_ERROR_INVALID_ARG,
        ("out of bounds index " + std::to_string(index) + ": response has " +
         std::to_string(outputs.size()) + " outputs")
            .c_str());
  }

  const InferenceResponse::Output& output = outputs[index];

  *name = output.Name().c_str();
  *datatype = DataTypeToTriton(output.DType());

  const std::vector<int64_t>& oshape = output.Shape();
  // data() is well-defined for an empty shape, unlike &oshape[0].
  *shape = oshape.data();
  *dim_count = oshape.size();

  return nullptr;  // success
}

} // extern C

}} // namespace triton::core
40 changes: 40 additions & 0 deletions src/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -534,3 +534,43 @@ install(
TARGETS register_api_test
RUNTIME DESTINATION bin
)

#
# Backend Output Detail Unittest
#
# Builds the backend_output_detail_test executable from a single source file.
add_executable(
backend_output_detail_test
backend_output_detail_test.cc
)

# Keep the binary free of embedded RPATH entries: the build-tree RPATH is
# skipped, the (empty) install RPATH is used at build time, and linker search
# paths are not appended to it.
set_target_properties(
backend_output_detail_test
PROPERTIES
SKIP_BUILD_RPATH TRUE
BUILD_WITH_INSTALL_RPATH TRUE
INSTALL_RPATH_USE_LINK_PATH FALSE
INSTALL_RPATH ""
)

# Header search paths: the parent source directory, the include directory two
# levels up, and the GoogleTest headers.
target_include_directories(
backend_output_detail_test
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/..
${CMAKE_CURRENT_SOURCE_DIR}/../../include
${GTEST_INCLUDE_DIRS}
)

# Link against the common error/logging libraries, triton-core, and
# GoogleTest (gtest_main supplies the test entry point).
target_link_libraries(
backend_output_detail_test
PRIVATE
triton-common-error # from repo-common
triton-common-logging # from repo-common
triton-core
GTest::gtest
GTest::gtest_main
)

# Install the test executable into the bin directory.
install(
TARGETS backend_output_detail_test
RUNTIME DESTINATION bin
)
Loading