Update PA async callback to only run with final response (#351)
* Update PA async callback to only run with final response

* Address feedback

* Address feedback

* Address feedback

* Fix bug
matthewkotila committed Aug 8, 2023
1 parent 96e67f2 commit c156dec
Showing 8 changed files with 81 additions and 6 deletions.
8 changes: 7 additions & 1 deletion src/c++/library/common.h
@@ -161,7 +161,7 @@ struct InferOptions {
: model_name_(model_name), model_version_(""), request_id_(""),
sequence_id_(0), sequence_id_str_(""), sequence_start_(false),
sequence_end_(false), priority_(0), server_timeout_(0),
client_timeout_(0)
client_timeout_(0), triton_enable_empty_final_response_(false)
{
}
/// The name of the model to run inference.
@@ -219,6 +219,8 @@ struct InferOptions {
// timeout < 1000 microseconds will be rounded down to 0 milliseconds and have
// no effect.
uint64_t client_timeout_;
/// Whether to tell Triton to enable an empty final response.
bool triton_enable_empty_final_response_;
};

//==============================================================================
@@ -511,6 +513,10 @@ class InferResult {
const std::string& output_name, const uint8_t** buf,
size_t* byte_size) const = 0;

/// Get final response bool of the request which generated this response.
/// \return Error object indicating the success or failure.
virtual Error IsFinalResponse(bool* is_final_response) const = 0;

/// Get the result data as a vector of strings. The vector will
/// receive a copy of result data. An error will be generated if
/// the datatype of output is not 'BYTES'.
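Taken together, common.h now exposes both ends of the feature: a request-side option (`triton_enable_empty_final_response_`) and a response-side accessor (`IsFinalResponse()`). Below is a minimal usage sketch that is not part of this commit — the model name, include path, and the elided stream setup are assumptions, and the callback only mirrors the completion-callback shape used by the streaming gRPC client:

```cpp
#include <iostream>
#include <memory>

#include "grpc_client.h"  // include path depends on how the client library is installed

namespace tc = triton::client;

// Intended to be registered as the stream completion callback
// (e.g. when starting a gRPC stream for a decoupled model).
void
OnStreamResponse(tc::InferResult* result)
{
  std::shared_ptr<tc::InferResult> result_ptr(result);
  bool is_final_response = false;
  tc::Error err = result_ptr->IsFinalResponse(&is_final_response);
  if (!err.IsOk()) {
    std::cerr << "IsFinalResponse failed: " << err.Message() << std::endl;
    return;
  }
  if (is_final_response) {
    // Only the final response marks the request as complete, so per-request
    // bookkeeping (latency, throughput) belongs here.
    std::cout << "request complete" << std::endl;
  }
}

int
main()
{
  tc::InferOptions options("my_decoupled_model");  // hypothetical model name
  // Ask Triton to send a (possibly empty) final response so the client can
  // reliably detect when a decoupled request has finished.
  options.triton_enable_empty_final_response_ = true;
  // ... create the client, start the stream with OnStreamResponse as the
  // callback, then issue async stream infers with `options` ...
  return 0;
}
```

Note that while the library defaults the new option to false, the Perf Analyzer backend (client_backend.h below) defaults it to true, so PA-issued requests always end with an identifiable final response.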
24 changes: 24 additions & 0 deletions src/c++/library/grpc_client.cc
@@ -189,6 +189,7 @@ class InferResultGrpc : public InferResult {
Error RawData(
const std::string& output_name, const uint8_t** buf,
size_t* byte_size) const override;
Error IsFinalResponse(bool* is_final_response) const override;
Error StringData(
const std::string& output_name,
std::vector<std::string>* string_result) const override;
@@ -209,6 +210,7 @@ class InferResultGrpc : public InferResult {
std::shared_ptr<inference::ModelInferResponse> response_;
std::shared_ptr<inference::ModelStreamInferResponse> stream_response_;
Error request_status_;
bool is_final_response_{true};
};

Error
@@ -310,6 +312,16 @@ InferResultGrpc::RawData(
return Error::Success;
}

Error
InferResultGrpc::IsFinalResponse(bool* is_final_response) const
{
if (is_final_response == nullptr) {
return Error("is_final_response cannot be nullptr");
}
*is_final_response = is_final_response_;
return Error::Success;
}

Error
InferResultGrpc::StringData(
const std::string& output_name,
@@ -367,6 +379,11 @@ InferResultGrpc::InferResultGrpc(
std::make_pair(output.name(), std::make_pair(buf, byte_size)));
index++;
}
const auto& is_final_response_itr{
response_->parameters().find("triton_final_response")};
if (is_final_response_itr != response_->parameters().end()) {
is_final_response_ = is_final_response_itr->second.bool_param();
}
}

InferResultGrpc::InferResultGrpc(
@@ -387,6 +404,11 @@ InferResultGrpc::InferResultGrpc(
std::make_pair(output.name(), std::make_pair(buf, byte_size)));
index++;
}
const auto& is_final_response_itr{
response_->parameters().find("triton_final_response")};
if (is_final_response_itr != response_->parameters().end()) {
is_final_response_ = is_final_response_itr->second.bool_param();
}
}

//==============================================================================
@@ -1347,6 +1369,8 @@ InferenceServerGrpcClient::PreRunProcessing(
infer_request_.set_id(options.request_id_);

infer_request_.mutable_parameters()->clear();
(*infer_request_.mutable_parameters())["triton_enable_empty_final_response"]
.set_bool_param(options.triton_enable_empty_final_response_);
if ((options.sequence_id_ != 0) || (options.sequence_id_str_ != "")) {
if (options.sequence_id_ != 0) {
(*infer_request_.mutable_parameters())["sequence_id"].set_int64_param(
12 changes: 12 additions & 0 deletions src/c++/library/http_client.cc
@@ -739,6 +739,7 @@ class InferResultHttp : public InferResult {
Error RawData(
const std::string& output_name, const uint8_t** buf,
size_t* byte_size) const override;
Error IsFinalResponse(bool* is_final_response) const override;
Error StringData(
const std::string& output_name,
std::vector<std::string>* string_result) const override;
@@ -767,6 +768,7 @@ class InferResultHttp : public InferResult {
std::shared_ptr<HttpInferRequest> infer_request_;

bool binary_data_{true};
bool is_final_response_{true};
};

void
@@ -939,6 +941,16 @@ InferResultHttp::RawData(
return Error::Success;
}

Error
InferResultHttp::IsFinalResponse(bool* is_final_response) const
{
if (is_final_response == nullptr) {
return Error("is_final_response cannot be nullptr");
}
*is_final_response = is_final_response_;
return Error::Success;
}

Error
InferResultHttp::StringData(
const std::string& output_name,
11 changes: 10 additions & 1 deletion src/c++/perf_analyzer/client_backend/client_backend.h
@@ -199,7 +199,7 @@ struct InferOptions {
explicit InferOptions(const std::string& model_name)
: model_name_(model_name), model_version_(""), request_id_(""),
sequence_id_(0), sequence_id_str_(""), sequence_start_(false),
sequence_end_(false)
sequence_end_(false), triton_enable_empty_final_response_(true)
{
}
/// The name of the model to run inference.
@@ -228,6 +228,8 @@ struct InferOptions {
/// sequence. Default value is False. This argument is ignored if
/// 'sequence_id' is 0.
bool sequence_end_;
/// Whether to tell Triton to enable an empty final response.
bool triton_enable_empty_final_response_;
};

struct SslOptionsBase {
@@ -614,6 +616,13 @@ class InferResult {
virtual Error RawData(
const std::string& output_name, const uint8_t** buf,
size_t* byte_size) const = 0;

/// Get final response bool of the request which generated this response.
/// \return Error object indicating the success or failure.
virtual Error IsFinalResponse(bool* is_final_response) const
{
return Error("InferResult::IsFinalResponse() not implemented");
};
};

}}} // namespace triton::perfanalyzer::clientbackend
11 changes: 10 additions & 1 deletion src/c++/perf_analyzer/client_backend/mock_client_backend.h
@@ -1,4 +1,4 @@
// Copyright 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
@@ -118,6 +118,15 @@ class MockInferResult : public InferResult {
return Error::Success;
}

Error IsFinalResponse(bool* is_final_response) const override
{
if (is_final_response == nullptr) {
return Error("is_final_response cannot be nullptr");
}
*is_final_response = true;
return Error::Success;
}

private:
std::string req_id_;
};
@@ -588,6 +588,8 @@ TritonClientBackend::ParseInferOptionsToTriton(
triton_options->sequence_start_ = options.sequence_start_;
triton_options->sequence_end_ = options.sequence_end_;
}
triton_options->triton_enable_empty_final_response_ =
options.triton_enable_empty_final_response_;
}


@@ -818,6 +820,13 @@ TritonInferResult::RawData(
return Error::Success;
}

Error
TritonInferResult::IsFinalResponse(bool* is_final_response) const
{
RETURN_IF_TRITON_ERROR(result_->IsFinalResponse(is_final_response));
return Error::Success;
}

//==============================================================================

}}}} // namespace triton::perfanalyzer::clientbackend::tritonremote
@@ -329,6 +329,8 @@ class TritonInferResult : public InferResult {
Error RawData(
const std::string& output_name, const uint8_t** buf,
size_t* byte_size) const override;
/// See InferResult::IsFinalResponse()
Error IsFinalResponse(bool* is_final_response) const override;

private:
std::unique_ptr<tc::InferResult> result_;
10 changes: 7 additions & 3 deletions src/c++/perf_analyzer/infer_context.cc
@@ -237,8 +237,6 @@ InferContext::AsyncCallbackFuncImpl(cb::InferResult* result)
{
std::shared_ptr<cb::InferResult> result_ptr(result);
if (thread_stat_->cb_status_.IsOk()) {
// TODO TMA-1257 use final response parameter from grpc client
bool final_response = true;
// Add the request timestamp to thread Timestamp vector with
// proper locking
std::lock_guard<std::mutex> lock(thread_stat_->mu_);
@@ -251,7 +249,13 @@ InferContext::AsyncCallbackFuncImpl(cb::InferResult* result)
const auto& it = async_req_map_.find(request_id);
if (it != async_req_map_.end()) {
it->second.end_times.push_back(end_time_async);
if (final_response) {
bool is_final_response{false};
thread_stat_->cb_status_ =
result_ptr->IsFinalResponse(&is_final_response);
if (thread_stat_->cb_status_.IsOk() == false) {
return;
}
if (is_final_response) {
thread_stat_->request_timestamps_.emplace_back(std::make_tuple(
it->second.start_time_, it->second.end_times,
it->second.sequence_end_, it->second.delayed_));
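The change above is the core of the commit: instead of hard-coding `final_response = true` (the old TODO for TMA-1257), the callback now asks the result whether it is the final response and only then records the request's timestamps. A condensed, standalone sketch of that bookkeeping pattern — the types and names here are illustrative, not PA's actual ones:

```cpp
// Standalone sketch (not PA code): every response appends an end time for
// its request, but the request only counts toward the collected statistics
// once the final response is observed.
#include <chrono>
#include <iostream>
#include <vector>

struct RequestRecord {
  std::chrono::time_point<std::chrono::system_clock> start_time;
  std::vector<std::chrono::time_point<std::chrono::system_clock>> end_times;
};

// `is_final_response` stands in for the value reported by
// InferResult::IsFinalResponse() in the real callback.
void
OnResponse(RequestRecord& record, bool is_final_response)
{
  record.end_times.push_back(std::chrono::system_clock::now());
  if (is_final_response) {
    // Only now does the request contribute to latency/throughput results.
    std::cout << "request finished after " << record.end_times.size()
              << " response(s)" << std::endl;
  }
}

int
main()
{
  RequestRecord record{std::chrono::system_clock::now(), {}};
  OnResponse(record, /*is_final_response=*/false);  // intermediate response
  OnResponse(record, /*is_final_response=*/true);   // final response
  return 0;
}
```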
