diff --git a/src/constants.h b/src/constants.h
index 119d1e9d2..208def668 100644
--- a/src/constants.h
+++ b/src/constants.h
@@ -62,6 +62,7 @@ constexpr char kPythonBackend[] = "python";
 
 #ifdef TRITON_ENABLE_ENSEMBLE
 constexpr char kEnsemblePlatform[] = "ensemble";
+constexpr uint64_t ENSEMBLE_CB_POOL_SIZE = 8u;
 #endif  // TRITON_ENABLE_ENSEMBLE
 
 constexpr char kTensorRTExecutionAccelerator[] = "tensorrt";
diff --git a/src/ensemble_scheduler/ensemble_scheduler.cc b/src/ensemble_scheduler/ensemble_scheduler.cc
index 90f1e0be8..aebf101d1 100644
--- a/src/ensemble_scheduler/ensemble_scheduler.cc
+++ b/src/ensemble_scheduler/ensemble_scheduler.cc
@@ -1,4 +1,4 @@
-// Copyright 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright 2019-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Redistribution and use in source and binary forms, with or without
 // modification, are permitted provided that the following conditions
@@ -45,6 +45,32 @@ class EnsembleContext;
 
 using IterationCount = size_t;
 
+// Check whether the model is configured to preserve the order of responses.
+// This is critical for the async execution of ResponseComplete callbacks.
+inline bool
+preserve_responses_order(const inference::ModelConfig& config)
+{
+  uint64_t total_instance_groups = 0;
+  for (const auto& group : config.instance_group()) {
+    total_instance_groups += group.count();
+  }
+
+  // Case 1: Sequence batching is enabled.
+  // Case 2: Dynamic batching is disabled and there is at most one model
+  // instance in total.
+  // Case 3: Dynamic batching is enabled and preserve_ordering is true.
+  // Case 4: The model transaction policy is decoupled (breaks the
+  // RequestTracker lifecycle).
+  // Note: Although decoupled models do not preserve the order of responses,
+  // if the final response callback is not executed in the last step, the
+  // RequestTracker object is freed prematurely, causing a segmentation fault.
+  return config.has_sequence_batching() ||
+         (!config.has_dynamic_batching() && total_instance_groups <= 1) ||
+         (config.has_dynamic_batching() &&
+          config.dynamic_batching().preserve_ordering()) ||
+         config.model_transaction_policy().decoupled();
+}
+
 // Request tracker is passed as 'userp' in RequestRelease function and used
 // to manage the lifecycle of the ensemble request
 class RequestTracker {
@@ -52,10 +78,12 @@ class RequestTracker {
   explicit RequestTracker(
       std::unique_ptr<InferenceRequest>&& request, uint64_t compute_start_ns,
       MetricModelReporter* metric_reporter,
-      InferenceStatsAggregator* stats_aggregator)
+      InferenceStatsAggregator* stats_aggregator,
+      triton::common::ThreadPool* callback_pool)
       : inflight_request_counter_(1), request_(std::move(request)),
         compute_start_ns_(compute_start_ns), metric_reporter_(metric_reporter),
-        stats_aggregator_(stats_aggregator), status_(Status::Success)
+        stats_aggregator_(stats_aggregator), status_(Status::Success),
+        callback_pool_(callback_pool)
   {
   }
 
@@ -70,6 +98,8 @@ class RequestTracker {
     return context_stats_aggregator_;
   }
 
+  triton::common::ThreadPool* CallbackPool() const { return callback_pool_; }
+
   void IncrementCounter()
   {
     std::lock_guard<std::mutex> lk(mtx_);
@@ -120,6 +150,7 @@ class RequestTracker {
   InferenceStatsAggregator* stats_aggregator_;
   InferenceStatsAggregator context_stats_aggregator_;
   Status status_;
+  triton::common::ThreadPool* const callback_pool_;
 };
 
 // Step is used as 'userp' and keeps ensemble context alive
@@ -129,9 +160,9 @@ class RequestTracker {
 struct Step {
   Step(
       size_t step_idx, const InferenceRequest::SequenceId& correlation_id,
-      uint32_t flags)
+      uint32_t flags, bool preserve_responses_order)
       : correlation_id_(correlation_id), flags_(flags), response_flags_(0),
-        step_idx_(step_idx)
+        preserve_responses_order_(preserve_responses_order), step_idx_(step_idx)
   {
   }
 
@@ -154,7 +185,7 @@ struct Step {
   // returning from the callback.
   uint32_t response_flags_;
   TRITONSERVER_InferenceResponse* response_;
-
+  const bool preserve_responses_order_;
   size_t step_idx_;
 };
 
@@ -237,7 +268,7 @@ class EnsembleContext {
       MetricModelReporter* metric_reporter,
       InferenceStatsAggregator* stats_aggregator, InferenceServer* is,
       EnsembleInfo* info, std::unique_ptr<InferenceRequest>& request,
-      cudaStream_t stream);
+      cudaStream_t stream, triton::common::ThreadPool* callback_pool);
 
   // Perform transition on 'context' state given the information of
   // 'completed_step'
@@ -326,6 +357,8 @@ class EnsembleContext {
   void CacheEnsembleTopLevelRequest(
       std::unique_ptr<InferenceResponse>& response);
 
+  triton::common::ThreadPool* CallbackPool() const { return callback_pool_; }
+
   InferenceServer* is_;
 
   EnsembleInfo* info_;
@@ -375,20 +408,26 @@
       TRITONSERVER_ResponseAllocator,
       decltype(&TRITONSERVER_ResponseAllocatorDelete)>
       allocator_;
+
+  // The thread pool used to execute ensemble callbacks and reduce e2e latency.
+  // The thread pool is managed by InferenceServer.
+  triton::common::ThreadPool* const callback_pool_;
 };
 
 EnsembleContext::EnsembleContext(
     MetricModelReporter* metric_reporter,
     InferenceStatsAggregator* stats_aggregator, InferenceServer* is,
     EnsembleInfo* info, std::unique_ptr<InferenceRequest>& request,
-    cudaStream_t stream)
+    cudaStream_t stream, triton::common::ThreadPool* callback_pool)
     : is_(is), info_(info), stream_(stream), inflight_step_counter_(0),
-      allocator_(nullptr, TRITONSERVER_ResponseAllocatorDelete)
+      allocator_(nullptr, TRITONSERVER_ResponseAllocatorDelete),
+      callback_pool_(callback_pool)
 {
   uint64_t compute_start_ns = 0;
   INFER_STATS_SET_TIMESTAMP(compute_start_ns);
   request_tracker_ = new RequestTracker(
-      std::move(request), compute_start_ns, metric_reporter, stats_aggregator);
+      std::move(request), compute_start_ns, metric_reporter, stats_aggregator,
+      callback_pool);
 
   auto& lrequest = request_tracker_->Request();
 
@@ -603,14 +642,25 @@
 void
 EnsembleContext::RequestComplete(
     TRITONSERVER_InferenceRequest* request, const uint32_t flags, void* userp)
 {
-  if ((flags & TRITONSERVER_REQUEST_RELEASE_ALL) != 0) {
-    LOG_TRITONSERVER_ERROR(
-        TRITONSERVER_InferenceRequestDelete(request),
-        "deleting ensemble inference request");
-    auto request_tracker = reinterpret_cast<RequestTracker*>(userp);
-    if (request_tracker->DecrementCounter()) {
-      delete request_tracker;
+  auto request_tracker = reinterpret_cast<RequestTracker*>(userp);
+  auto pool = request_tracker->CallbackPool();
+  auto fn = [request, flags, request_tracker]() {
+    if ((flags & TRITONSERVER_REQUEST_RELEASE_ALL) != 0) {
+      LOG_TRITONSERVER_ERROR(
+          TRITONSERVER_InferenceRequestDelete(request),
+          "deleting ensemble inference request");
+      if (request_tracker->DecrementCounter()) {
+        delete request_tracker;
+      }
     }
+  };
+
+  // Attempt to enqueue the callback. If all workers are busy and the queue
+  // is at capacity, execute the callback immediately in the current thread.
+  if (pool->TaskQueueSize() < pool->Size()) {
+    pool->Enqueue(fn);
+  } else {
+    fn();
   }
 }
 
@@ -618,14 +668,31 @@
 void
 EnsembleContext::ResponseComplete(
     TRITONSERVER_InferenceResponse* response, const uint32_t flags, void* userp)
 {
-  auto step_ptr = std::unique_ptr<Step>(reinterpret_cast<Step*>(userp));
-  step_ptr->response_flags_ = flags;
-  step_ptr->response_ = response;
-
-  EnsembleContext::Proceed(step_ptr->ctx_, step_ptr);
-  // Expecting more responses
-  if ((flags & TRITONSERVER_RESPONSE_COMPLETE_FINAL) == 0) {
-    step_ptr.release();
+  auto step_raw_ptr = reinterpret_cast<Step*>(userp);
+  auto pool = step_raw_ptr->ctx_->CallbackPool();
+  auto fn = [response, flags, step_raw_ptr]() {
+    auto step_ptr = std::unique_ptr<Step>(step_raw_ptr);
+    step_ptr->response_flags_ = flags;
+    step_ptr->response_ = response;
+
+    EnsembleContext::Proceed(step_ptr->ctx_, step_ptr);
+    // Expecting more responses
+    if ((flags & TRITONSERVER_RESPONSE_COMPLETE_FINAL) == 0) {
+      step_ptr.release();
+    }
+  };
+
+  // Attempt to enqueue the callback. If all workers are busy and the queue
+  // is at capacity, execute the callback immediately in the current thread.
+  // Note: The async callback optimization does not guarantee the order of
+  // responses; it exploits cases where responses may arrive out of order.
+  // For models that must preserve response order, the response callbacks are
+  // executed synchronously in the calling thread.
+  if (!step_raw_ptr->preserve_responses_order_ &&
+      pool->TaskQueueSize() < pool->Size()) {
+    pool->Enqueue(fn);
+  } else {
+    fn();
   }
 }
 
@@ -971,8 +1038,8 @@ EnsembleContext::InitStep(
   for (const auto& pair : istep.output_to_tensor_) {
     irequest->AddOriginalRequestedOutput(pair.first);
   }
-
-  step->reset(new Step(step_idx, correlation_id, flags));
+  const bool preserve_order = preserve_responses_order(model->Config());
+  step->reset(new Step(step_idx, correlation_id, flags, preserve_order));
 
   irequest->SetId(request_id_);
   irequest->SetCorrelationId(correlation_id);
@@ -1448,7 +1515,7 @@ EnsembleScheduler::Enqueue(std::unique_ptr<InferenceRequest>& request)
   RETURN_IF_ERROR(request->SetState(InferenceRequest::State::EXECUTING));
   std::shared_ptr<EnsembleContext> context(new EnsembleContext(
       metric_reporter_.get(), stats_aggregator_, is_, info_.get(), request,
-      stream_));
+      stream_, callback_pool_));
   EnsembleContext::Proceed(context);
   return Status::Success;
 }
@@ -1537,6 +1604,7 @@ EnsembleScheduler::EnsembleScheduler(
       info_->tensor_to_prev_step_.emplace(pair.second, step_idx);
     }
   }
+  callback_pool_ = is_->EnsembleCallbackPool();
 }
 
 EnsembleScheduler::~EnsembleScheduler()
diff --git a/src/ensemble_scheduler/ensemble_scheduler.h b/src/ensemble_scheduler/ensemble_scheduler.h
index 51473ea71..9527ab802 100644
--- a/src/ensemble_scheduler/ensemble_scheduler.h
+++ b/src/ensemble_scheduler/ensemble_scheduler.h
@@ -1,4 +1,4 @@
-// Copyright 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright 2019-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Redistribution and use in source and binary forms, with or without
 // modification, are permitted provided that the following conditions
@@ -36,6 +36,7 @@
 #include "scheduler.h"
 #include "scheduler_utils.h"
 #include "status.h"
+#include "triton/common/thread_pool.h"
 
 #ifdef TRITON_ENABLE_GPU
 #include <cuda_runtime_api.h>
@@ -107,6 +108,8 @@ class EnsembleScheduler : public Scheduler {
   // \see Scheduler::Stop()
   void Stop() override {}
 
+  triton::common::ThreadPool* CallbackPool() const { return callback_pool_; }
+
  private:
   EnsembleScheduler(
       InferenceStatsAggregator* const stats_aggregator,
@@ -128,6 +131,10 @@ class EnsembleScheduler : public Scheduler {
   cudaStream_t stream_;
 
   std::atomic<size_t> inflight_count_;
+
+  // Fixed-size thread pool used to run callbacks at the end of each ensemble
+  // step. Owned and managed by the server.
+  triton::common::ThreadPool* callback_pool_;
 };
 
 }}  // namespace triton::core
diff --git a/src/server.cc b/src/server.cc
index 68b39954f..fecc5d4a8 100644
--- a/src/server.cc
+++ b/src/server.cc
@@ -1,4 +1,4 @@
-// Copyright 2018-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright 2018-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Redistribution and use in source and binary forms, with or without
 // modification, are permitted provided that the following conditions
@@ -117,6 +117,13 @@ InferenceServer::InferenceServer()
 #endif  // TRITON_ENABLE_GPU
 
   inflight_request_counter_ = 0;
+
+#ifdef TRITON_ENABLE_ENSEMBLE
+  // TODO: Scale the thread pool size more intelligently, e.g. based on the
+  // instance_group count of the composing models.
+  ensemble_cb_pool_.reset(
+      new triton::common::ThreadPool(ENSEMBLE_CB_POOL_SIZE));
+#endif
 }
 
 Status
diff --git a/src/server.h b/src/server.h
index 5c67c6381..8fa4846a5 100644
--- a/src/server.h
+++ b/src/server.h
@@ -1,4 +1,4 @@
-// Copyright 2018-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright 2018-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Redistribution and use in source and binary forms, with or without
 // modification, are permitted provided that the following conditions
@@ -332,6 +332,13 @@ class InferenceServer {
     return cache_manager_;
   }
 
+#ifdef TRITON_ENABLE_ENSEMBLE
+  triton::common::ThreadPool* EnsembleCallbackPool() const
+  {
+    return ensemble_cb_pool_.get();
+  }
+#endif  // TRITON_ENABLE_ENSEMBLE
+
  private:
   const std::string version_;
   std::string id_;
@@ -375,6 +382,12 @@ class InferenceServer {
   std::unique_ptr<ModelRepositoryManager> model_repository_manager_;
   std::shared_ptr<TritonBackendManager> backend_manager_;
   std::shared_ptr<TritonCacheManager> cache_manager_;
+
+#ifdef TRITON_ENABLE_ENSEMBLE
+  // The thread pool shared by all ensemble models to execute callbacks
+  // asynchronously.
+  std::unique_ptr<triton::common::ThreadPool> ensemble_cb_pool_;
+#endif  // TRITON_ENABLE_ENSEMBLE
 };
 
 }}  // namespace triton::core
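
The RequestComplete and ResponseComplete callbacks above share one dispatch rule: hand the work to the shared pool only while the pending-task count is below the worker count, and otherwise run it inline so completion never waits behind a saturated pool. A minimal, self-contained sketch of that rule follows; MiniPool is a hypothetical stand-in written only for this illustration (it is not triton::common::ThreadPool, though it mimics the Enqueue/Size/TaskQueueSize calls the patch relies on).

// Illustrative sketch only (not part of the patch): the "enqueue when a
// worker is free, otherwise run inline" dispatch used by RequestComplete
// and ResponseComplete. MiniPool is a hypothetical stand-in for the real
// triton::common::ThreadPool.
#include <condition_variable>
#include <cstddef>
#include <functional>
#include <iostream>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

class MiniPool {
 public:
  explicit MiniPool(std::size_t worker_count)
  {
    for (std::size_t i = 0; i < worker_count; ++i) {
      workers_.emplace_back([this] {
        for (;;) {
          std::function<void()> task;
          {
            std::unique_lock<std::mutex> lk(mtx_);
            cv_.wait(lk, [this] { return stop_ || !tasks_.empty(); });
            if (stop_ && tasks_.empty()) {
              return;  // Exit only after draining the remaining tasks.
            }
            task = std::move(tasks_.front());
            tasks_.pop();
          }
          task();
        }
      });
    }
  }

  ~MiniPool()
  {
    {
      std::lock_guard<std::mutex> lk(mtx_);
      stop_ = true;
    }
    cv_.notify_all();
    for (auto& worker : workers_) {
      worker.join();
    }
  }

  void Enqueue(std::function<void()> fn)
  {
    {
      std::lock_guard<std::mutex> lk(mtx_);
      tasks_.push(std::move(fn));
    }
    cv_.notify_one();
  }

  std::size_t Size() const { return workers_.size(); }

  std::size_t TaskQueueSize()
  {
    std::lock_guard<std::mutex> lk(mtx_);
    return tasks_.size();
  }

 private:
  std::vector<std::thread> workers_;
  std::queue<std::function<void()>> tasks_;
  std::mutex mtx_;
  std::condition_variable cv_;
  bool stop_ = false;
};

int main()
{
  MiniPool pool(2);  // Plays the role of the ENSEMBLE_CB_POOL_SIZE-sized pool.
  auto callback = [] { std::cout << "callback executed" << std::endl; };

  // Same decision the patch makes: offload the callback while the queue is
  // shorter than the worker count, otherwise fall back to running it in the
  // calling thread so completion is never delayed behind a saturated pool.
  if (pool.TaskQueueSize() < pool.Size()) {
    pool.Enqueue(callback);
  } else {
    callback();
  }
  return 0;
}

With two idle workers the callback is enqueued; once the queue length reaches the worker count, later callbacks would take the fallback branch and execute in the caller's thread, matching the behavior the patch adds (and, for ResponseComplete, the synchronous path that order-preserving models always take).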