Skip to content

Commit

Permalink
Add inference request cancellation APIs (#249)
Browse files Browse the repository at this point in the history
* Fix state transitions for re-running requests (#251)

* Add backend/server APIs

* Implement the cancellation APIs

* Only store the state in response factory

* Add unit testing for request cancellation

* Add test

* Add cancellation status

* Add testing for cancelling a request after release

* Handle request re-use

* Enable request reuse test

* Add staged changes

* Add temporary fix for the request state bug

---------

Co-authored-by: Ryan McCormick <[email protected]>
  • Loading branch information
Tabrizian and rmccorm4 committed Sep 13, 2023
1 parent 80fef04 commit 2e2a617
Show file tree
Hide file tree
Showing 11 changed files with 477 additions and 12 deletions.
25 changes: 25 additions & 0 deletions include/triton/core/tritonbackend.h
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,31 @@ TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_OutputBufferAttributes(
TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_RequestId(
TRITONBACKEND_Request* request, const char** id);

/// Query whether the request is cancelled or not.
///
/// If possible the backend should terminate any processing and
/// send an error response with cancelled status.
///
/// \param request The inference request.
/// \param is_cancelled Returns true if the request is cancelled otherwise it
/// would return false.
/// \return a TRITONSERVER_Error indicating success or failure.
TRITONBACKEND_DECLSPEC TRITONSERVER_Error* TRITONBACKEND_RequestIsCancelled(
TRITONBACKEND_Request* request, bool* is_cancelled);

/// Query whether the response factory is cancelled or not.
///
/// If possible the backend should terminate any processing and
/// send an error response with cancelled status.
///
/// \param factory The response factory
/// \param is_cancelled Returns true if the request is cancelled otherwise it
/// would return false.
/// \return a TRITONSERVER_Error indicating success or failure.
TRITONBACKEND_DECLSPEC TRITONSERVER_Error*
TRITONBACKEND_ResponseFactoryIsCancelled(
TRITONBACKEND_ResponseFactory* factory, bool* is_cancelled);

/// Get the correlation ID of the request if it is an unsigned integer.
/// Zero indicates that the request does not have a correlation ID.
/// Returns failure if correlation ID for given request is not an unsigned
Expand Down
33 changes: 31 additions & 2 deletions include/triton/core/tritonserver.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ struct TRITONSERVER_MetricFamily;
/// }
///
#define TRITONSERVER_API_VERSION_MAJOR 1
#define TRITONSERVER_API_VERSION_MINOR 24
#define TRITONSERVER_API_VERSION_MINOR 25

/// Get the TRITONBACKEND API version supported by the Triton shared
/// library. This value can be compared against the
Expand Down Expand Up @@ -308,7 +308,8 @@ typedef enum TRITONSERVER_errorcode_enum {
TRITONSERVER_ERROR_INVALID_ARG,
TRITONSERVER_ERROR_UNAVAILABLE,
TRITONSERVER_ERROR_UNSUPPORTED,
TRITONSERVER_ERROR_ALREADY_EXISTS
TRITONSERVER_ERROR_ALREADY_EXISTS,
TRITONSERVER_ERROR_CANCELLED
} TRITONSERVER_Error_Code;

/// Create a new error object. The caller takes ownership of the
Expand Down Expand Up @@ -1091,6 +1092,34 @@ TRITONSERVER_InferenceRequestSetCorrelationIdString(
struct TRITONSERVER_InferenceRequest* inference_request,
const char* correlation_id);

/// Cancel an inference request. Requests are canceled on a best
/// effort basis and no guarantee is provided that cancelling a
/// request will result in early termination. Note that the
/// inference request cancellation status will be reset after
/// TRITONSERVER_InferAsync is run. This means that if you cancel
/// the request before calling TRITONSERVER_InferAsync
/// the request will not be cancelled.
///
/// \param inference_request The request object.
/// \return a TRITONSERVER_Error indicating success or failure.
TRITONSERVER_DECLSPEC struct TRITONSERVER_Error*
TRITONSERVER_InferenceRequestCancel(
struct TRITONSERVER_InferenceRequest* inference_request);

/// Query whether the request is cancelled or not.
///
/// If possible the backend should terminate any processing and
/// send an error response with cancelled status.
///
/// \param inference_request The request object.
/// \param is_cancelled Returns whether the inference request is cancelled or
/// not.
/// \return a TRITONSERVER_Error indicating success or failure.
TRITONSERVER_DECLSPEC struct TRITONSERVER_Error*
TRITONSERVER_InferenceRequestIsCancelled(
struct TRITONSERVER_InferenceRequest* inference_request,
bool* is_cancelled);

/// Deprecated. See TRITONSERVER_InferenceRequestPriorityUInt64 instead.
///
/// Get the priority for a request. The default is 0 indicating that
Expand Down
21 changes: 21 additions & 0 deletions src/backend_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1017,6 +1017,16 @@ TRITONBACKEND_RequestFlags(TRITONBACKEND_Request* request, uint32_t* flags)
return nullptr; // success
}

TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONBACKEND_RequestIsCancelled(
TRITONBACKEND_Request* request, bool* is_cancelled)
{
InferenceRequest* tr = reinterpret_cast<InferenceRequest*>(request);
*is_cancelled = tr->IsCancelled();
return nullptr;
}


TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONBACKEND_RequestCorrelationIdString(
TRITONBACKEND_Request* request, const char** id)
Expand Down Expand Up @@ -1365,6 +1375,17 @@ TRITONBACKEND_ResponseFactorySendFlags(
return nullptr; // success
}

TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONBACKEND_ResponseFactoryIsCancelled(
TRITONBACKEND_ResponseFactory* factory, bool* is_cancelled)
{
std::shared_ptr<InferenceResponseFactory>* response_factory =
reinterpret_cast<std::shared_ptr<InferenceResponseFactory>*>(factory);
*is_cancelled = (*response_factory)->IsCancelled();
return nullptr; // success
}


///
/// TRITONBACKEND_Response
///
Expand Down
1 change: 1 addition & 0 deletions src/infer_request.cc
Original file line number Diff line number Diff line change
Expand Up @@ -828,6 +828,7 @@ InferenceRequest::PrepareForInference()
// inference execution.
inputs_.clear();
override_inputs_.clear();
ResetCancel();

// Renormalize if anything has changed in the inference request in a
// way that could impact renormalization.
Expand Down
7 changes: 6 additions & 1 deletion src/infer_request.h
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,11 @@ class InferenceRequest {
secondary_stats_aggregator_ = secondary_stats_aggregator;
}

void Cancel() { response_factory_->Cancel(); }
void ResetCancel() { response_factory_->ResetCancel(); }

bool IsCancelled() { return response_factory_->IsCancelled(); }

#endif // TRITON_ENABLE_STATS

private:
Expand Down Expand Up @@ -795,7 +800,7 @@ class InferenceRequest {
std::shared_ptr<SequenceStates> sequence_states_;

// The state of the request.
InferenceRequest::State state_;
std::atomic<InferenceRequest::State> state_;
// Whether this is a null request used for direct sequence batch padding or
// not.
bool null_request_;
Expand Down
9 changes: 8 additions & 1 deletion src/infer_response.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,16 @@ class InferenceResponseFactory {
std::unique_ptr<InferenceResponse>&&, const uint32_t)>& delegator)
: model_(model), id_(id), allocator_(allocator),
alloc_userp_(alloc_userp), response_fn_(response_fn),
response_userp_(response_userp), response_delegator_(delegator)
response_userp_(response_userp), response_delegator_(delegator),
is_cancelled_(false)
{
}

void Cancel() { is_cancelled_ = true; }
void ResetCancel() { is_cancelled_ = false; }

bool IsCancelled() { return is_cancelled_; }

const ResponseAllocator* Allocator() { return allocator_; }
void* AllocatorUserp() { return alloc_userp_; }

Expand Down Expand Up @@ -118,6 +124,7 @@ class InferenceResponseFactory {
std::function<void(std::unique_ptr<InferenceResponse>&&, const uint32_t)>
response_delegator_;

std::atomic<bool> is_cancelled_;

#ifdef TRITON_ENABLE_TRACING
// Inference trace associated with this response.
Expand Down
6 changes: 4 additions & 2 deletions src/status.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ TritonCodeToStatusCode(TRITONSERVER_Error_Code code)
return Status::Code::UNSUPPORTED;
case TRITONSERVER_ERROR_ALREADY_EXISTS:
return Status::Code::ALREADY_EXISTS;

case TRITONSERVER_ERROR_CANCELLED:
return Status::Code::CANCELLED;
default:
break;
}
Expand All @@ -74,7 +75,8 @@ StatusCodeToTritonCode(Status::Code status_code)
return TRITONSERVER_ERROR_UNSUPPORTED;
case Status::Code::ALREADY_EXISTS:
return TRITONSERVER_ERROR_ALREADY_EXISTS;

case Status::Code::CANCELLED:
return TRITONSERVER_ERROR_CANCELLED;
default:
break;
}
Expand Down
44 changes: 38 additions & 6 deletions src/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,14 @@ add_executable(
backend_output_detail_test.cc
)

target_include_directories(
backend_output_detail_test
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/..
${CMAKE_CURRENT_SOURCE_DIR}/../../include
${GTEST_INCLUDE_DIRS}
)

set_target_properties(
backend_output_detail_test
PROPERTIES
Expand All @@ -552,16 +560,40 @@ set_target_properties(
INSTALL_RPATH ""
)

target_include_directories(
target_link_libraries(
backend_output_detail_test
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/..
${CMAKE_CURRENT_SOURCE_DIR}/../../include
${GTEST_INCLUDE_DIRS}
triton-common-error # from repo-common
triton-common-logging # from repo-common
triton-core
GTest::gtest
GTest::gtest_main
)

install(
TARGETS backend_output_detail_test
RUNTIME DESTINATION bin
)

#
# Request Cancellation Unittest
#
add_executable(
request_cancellation_test
request_cancellation_test.cc
)

set_target_properties(
request_cancellation_test
PROPERTIES
SKIP_BUILD_RPATH TRUE
BUILD_WITH_INSTALL_RPATH TRUE
INSTALL_RPATH_USE_LINK_PATH FALSE
INSTALL_RPATH ""
)

target_link_libraries(
backend_output_detail_test
request_cancellation_test
PRIVATE
triton-common-error # from repo-common
triton-common-logging # from repo-common
Expand All @@ -571,6 +603,6 @@ target_link_libraries(
)

install(
TARGETS backend_output_detail_test
TARGETS request_cancellation_test
RUNTIME DESTINATION bin
)
Loading

0 comments on commit 2e2a617

Please sign in to comment.