diff --git a/src/c++/library/common.h b/src/c++/library/common.h index 675dc960f..1f0494000 100644 --- a/src/c++/library/common.h +++ b/src/c++/library/common.h @@ -161,7 +161,7 @@ struct InferOptions { : model_name_(model_name), model_version_(""), request_id_(""), sequence_id_(0), sequence_id_str_(""), sequence_start_(false), sequence_end_(false), priority_(0), server_timeout_(0), - client_timeout_(0) + client_timeout_(0), triton_enable_empty_final_response_(false) { } /// The name of the model to run inference. @@ -219,6 +219,8 @@ struct InferOptions { // timeout < 1000 microseconds will be rounded down to 0 milliseconds and have // no effect. uint64_t client_timeout_; + /// Whether to tell Triton to enable an empty final response. + bool triton_enable_empty_final_response_; }; //============================================================================== @@ -511,6 +513,10 @@ class InferResult { const std::string& output_name, const uint8_t** buf, size_t* byte_size) const = 0; + /// Get final response bool of the request which generated this response. + /// \return Error object indicating the success or failure. + virtual Error IsFinalResponse(bool* is_final_response) const = 0; + /// Get the result data as a vector of strings. The vector will /// receive a copy of result data. An error will be generated if /// the datatype of output is not 'BYTES'. 
diff --git a/src/c++/library/grpc_client.cc b/src/c++/library/grpc_client.cc index 9e45620ac..eba18a33e 100644 --- a/src/c++/library/grpc_client.cc +++ b/src/c++/library/grpc_client.cc @@ -189,6 +189,7 @@ class InferResultGrpc : public InferResult { Error RawData( const std::string& output_name, const uint8_t** buf, size_t* byte_size) const override; + Error IsFinalResponse(bool* is_final_response) const override; Error StringData( const std::string& output_name, std::vector<std::string>* string_result) const override; @@ -209,6 +210,7 @@ class InferResultGrpc : public InferResult { std::shared_ptr<inference::ModelInferResponse> response_; std::shared_ptr<inference::ModelStreamInferResponse> stream_response_; Error request_status_; + bool is_final_response_{true}; }; Error @@ -310,6 +312,16 @@ InferResultGrpc::RawData( return Error::Success; } +Error +InferResultGrpc::IsFinalResponse(bool* is_final_response) const +{ + if (is_final_response == nullptr) { + return Error("is_final_response cannot be nullptr"); + } + *is_final_response = is_final_response_; + return Error::Success; +} + Error InferResultGrpc::StringData( const std::string& output_name, @@ -367,6 +379,11 @@ InferResultGrpc::InferResultGrpc( std::make_pair(output.name(), std::make_pair(buf, byte_size))); index++; } + const auto& is_final_response_itr{ + response_->parameters().find("triton_final_response")}; + if (is_final_response_itr != response_->parameters().end()) { + is_final_response_ = is_final_response_itr->second.bool_param(); + } } InferResultGrpc::InferResultGrpc( @@ -387,6 +404,11 @@ InferResultGrpc::InferResultGrpc( std::make_pair(output.name(), std::make_pair(buf, byte_size))); index++; } + const auto& is_final_response_itr{ + response_->parameters().find("triton_final_response")}; + if (is_final_response_itr != response_->parameters().end()) { + is_final_response_ = is_final_response_itr->second.bool_param(); + } } //============================================================================== @@ -1347,6 +1369,8 @@ InferenceServerGrpcClient::PreRunProcessing( 
infer_request_.set_id(options.request_id_); infer_request_.mutable_parameters()->clear(); + (*infer_request_.mutable_parameters())["triton_enable_empty_final_response"] + .set_bool_param(options.triton_enable_empty_final_response_); if ((options.sequence_id_ != 0) || (options.sequence_id_str_ != "")) { if (options.sequence_id_ != 0) { (*infer_request_.mutable_parameters())["sequence_id"].set_int64_param( diff --git a/src/c++/library/http_client.cc b/src/c++/library/http_client.cc index 232aa24fa..4258ead12 100644 --- a/src/c++/library/http_client.cc +++ b/src/c++/library/http_client.cc @@ -739,6 +739,7 @@ class InferResultHttp : public InferResult { Error RawData( const std::string& output_name, const uint8_t** buf, size_t* byte_size) const override; + Error IsFinalResponse(bool* is_final_response) const override; Error StringData( const std::string& output_name, std::vector<std::string>* string_result) const override; @@ -767,6 +768,7 @@ class InferResultHttp : public InferResult { std::shared_ptr<HttpInferRequest> infer_request_; bool binary_data_{true}; + bool is_final_response_{true}; }; void @@ -939,6 +941,16 @@ InferResultHttp::RawData( return Error::Success; } +Error +InferResultHttp::IsFinalResponse(bool* is_final_response) const +{ + if (is_final_response == nullptr) { + return Error("is_final_response cannot be nullptr"); + } + *is_final_response = is_final_response_; + return Error::Success; +} + Error InferResultHttp::StringData( const std::string& output_name, diff --git a/src/c++/perf_analyzer/client_backend/client_backend.h b/src/c++/perf_analyzer/client_backend/client_backend.h index fc58fb87c..4903da61d 100644 --- a/src/c++/perf_analyzer/client_backend/client_backend.h +++ b/src/c++/perf_analyzer/client_backend/client_backend.h @@ -199,7 +199,7 @@ struct InferOptions { explicit InferOptions(const std::string& model_name) : model_name_(model_name), model_version_(""), request_id_(""), sequence_id_(0), sequence_id_str_(""), sequence_start_(false), - sequence_end_(false) + 
sequence_end_(false), triton_enable_empty_final_response_(true) { } /// The name of the model to run inference. @@ -228,6 +228,8 @@ struct InferOptions { /// sequence. Default value is False. This argument is ignored if /// 'sequence_id' is 0. bool sequence_end_; + /// Whether to tell Triton to enable an empty final response. + bool triton_enable_empty_final_response_; }; struct SslOptionsBase { @@ -614,6 +616,13 @@ class InferResult { virtual Error RawData( const std::string& output_name, const uint8_t** buf, size_t* byte_size) const = 0; + + /// Get final response bool of the request which generated this response. + /// \return Error object indicating the success or failure. + virtual Error IsFinalResponse(bool* is_final_response) const + { + return Error("InferResult::IsFinalResponse() not implemented"); + }; }; }}} // namespace triton::perfanalyzer::clientbackend diff --git a/src/c++/perf_analyzer/client_backend/mock_client_backend.h b/src/c++/perf_analyzer/client_backend/mock_client_backend.h index a9ddef9be..28a568eb0 100644 --- a/src/c++/perf_analyzer/client_backend/mock_client_backend.h +++ b/src/c++/perf_analyzer/client_backend/mock_client_backend.h @@ -1,4 +1,4 @@ -// Copyright 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
// // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -118,6 +118,15 @@ class MockInferResult : public InferResult { return Error::Success; } + Error IsFinalResponse(bool* is_final_response) const override + { + if (is_final_response == nullptr) { + return Error("is_final_response cannot be nullptr"); + } + *is_final_response = true; + return Error::Success; + } + private: std::string req_id_; }; diff --git a/src/c++/perf_analyzer/client_backend/triton/triton_client_backend.cc b/src/c++/perf_analyzer/client_backend/triton/triton_client_backend.cc index 45241c13b..7ac2188a6 100644 --- a/src/c++/perf_analyzer/client_backend/triton/triton_client_backend.cc +++ b/src/c++/perf_analyzer/client_backend/triton/triton_client_backend.cc @@ -588,6 +588,8 @@ TritonClientBackend::ParseInferOptionsToTriton( triton_options->sequence_start_ = options.sequence_start_; triton_options->sequence_end_ = options.sequence_end_; } + triton_options->triton_enable_empty_final_response_ = + options.triton_enable_empty_final_response_; } @@ -818,6 +820,13 @@ TritonInferResult::RawData( return Error::Success; } +Error +TritonInferResult::IsFinalResponse(bool* is_final_response) const +{ + RETURN_IF_TRITON_ERROR(result_->IsFinalResponse(is_final_response)); + return Error::Success; +} + //============================================================================== }}}} // namespace triton::perfanalyzer::clientbackend::tritonremote diff --git a/src/c++/perf_analyzer/client_backend/triton/triton_client_backend.h b/src/c++/perf_analyzer/client_backend/triton/triton_client_backend.h index 80cbf9a17..5b25e8297 100644 --- a/src/c++/perf_analyzer/client_backend/triton/triton_client_backend.h +++ b/src/c++/perf_analyzer/client_backend/triton/triton_client_backend.h @@ -329,6 +329,8 @@ class TritonInferResult : public InferResult { Error RawData( const std::string& output_name, const uint8_t** buf, size_t* 
byte_size) const override; + /// See InferResult::IsFinalResponse() + Error IsFinalResponse(bool* is_final_response) const override; private: std::unique_ptr<tc::InferResult> result_; diff --git a/src/c++/perf_analyzer/infer_context.cc b/src/c++/perf_analyzer/infer_context.cc index 4b5e5a14d..7c4a36944 100644 --- a/src/c++/perf_analyzer/infer_context.cc +++ b/src/c++/perf_analyzer/infer_context.cc @@ -237,8 +237,6 @@ InferContext::AsyncCallbackFuncImpl(cb::InferResult* result) { std::shared_ptr<cb::InferResult> result_ptr(result); if (thread_stat_->cb_status_.IsOk()) { - // TODO TMA-1257 use final response parameter from grpc client - bool final_response = true; // Add the request timestamp to thread Timestamp vector with // proper locking std::lock_guard<std::mutex> lock(thread_stat_->mu_); @@ -251,7 +249,13 @@ InferContext::AsyncCallbackFuncImpl(cb::InferResult* result) const auto& it = async_req_map_.find(request_id); if (it != async_req_map_.end()) { it->second.end_times.push_back(end_time_async); - if (final_response) { + bool is_final_response{false}; + thread_stat_->cb_status_ = + result_ptr->IsFinalResponse(&is_final_response); + if (thread_stat_->cb_status_.IsOk() == false) { + return; + } + if (is_final_response) { thread_stat_->request_timestamps_.emplace_back(std::make_tuple( it->second.start_time_, it->second.end_times, it->second.sequence_end_, it->second.delayed_));