feat: Ensemble async callback execution (rework) #438

Open · wants to merge 2 commits into base: main
1 change: 1 addition & 0 deletions src/constants.h
@@ -62,6 +62,7 @@ constexpr char kPythonBackend[] = "python";

#ifdef TRITON_ENABLE_ENSEMBLE
constexpr char kEnsemblePlatform[] = "ensemble";
+constexpr uint64_t ENSEMBLE_CB_POOL_SIZE = 8u;
#endif // TRITON_ENABLE_ENSEMBLE

constexpr char kTensorRTExecutionAccelerator[] = "tensorrt";
124 changes: 96 additions & 28 deletions src/ensemble_scheduler/ensemble_scheduler.cc
@@ -1,4 +1,4 @@
-// Copyright 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright 2019-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
@@ -45,17 +45,45 @@ class EnsembleContext;

using IterationCount = size_t;

+// Check if the model is configured to preserve the order of responses.
+// This is critical for async execution of ResponseComplete callbacks.
+inline bool
+preserve_responses_order(const inference::ModelConfig& config)
+{
+  uint64_t total_instance_groups = 0;
+  for (const auto& group : config.instance_group()) {
+    total_instance_groups += group.count();
+  }
+
+  // Case 1: Sequence batching is enabled
+  // Case 2: Dynamic batching is disabled and there is at most one model
+  // instance across all instance groups
+  // Case 3: Dynamic batching is enabled and preserve_ordering is true
+  // Case 4: Model transaction policy is decoupled (breaks the RequestTracker
+  // lifecycle)
+  // Note: Although decoupled models do not preserve the order of responses,
+  // if the final response callback is not executed in the last step, the
+  // RequestTracker object is freed prematurely, leading to a segmentation
+  // fault.
+  return config.has_sequence_batching() ||
+         (!config.has_dynamic_batching() && total_instance_groups <= 1) ||
+         (config.has_dynamic_batching() &&
+          config.dynamic_batching().preserve_ordering()) ||
+         config.model_transaction_policy().decoupled();
+}
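To make the four cases concrete, here is a minimal sketch of how the helper classifies two configurations. It assumes the standard protobuf-generated accessors for inference::ModelConfig; only preserve_responses_order itself comes from this diff.

#include "model_config.pb.h"

void
PreserveOrderSketch()
{
  // Case 3: dynamic batching with preserve_ordering set forces synchronous
  // response callbacks.
  inference::ModelConfig ordered;
  ordered.mutable_dynamic_batching()->set_preserve_ordering(true);
  bool must_preserve = preserve_responses_order(ordered);  // true

  // No sequence/dynamic batching and two model instances: responses may
  // finish out of order, so callbacks are eligible for the thread pool.
  inference::ModelConfig relaxed;
  relaxed.add_instance_group()->set_count(2);
  bool may_relax = !preserve_responses_order(relaxed);  // true
}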

// Request tracker is passed as 'userp' in RequestRelease function and used
// to manage the lifecycle of the ensemble request
class RequestTracker {
public:
explicit RequestTracker(
std::unique_ptr<InferenceRequest>&& request, uint64_t compute_start_ns,
MetricModelReporter* metric_reporter,
-      InferenceStatsAggregator* stats_aggregator)
+      InferenceStatsAggregator* stats_aggregator,
+      triton::common::ThreadPool* callback_pool)
: inflight_request_counter_(1), request_(std::move(request)),
compute_start_ns_(compute_start_ns), metric_reporter_(metric_reporter),
-        stats_aggregator_(stats_aggregator), status_(Status::Success)
+        stats_aggregator_(stats_aggregator), status_(Status::Success),
+        callback_pool_(callback_pool)
{
}

@@ -70,6 +98,8 @@ class RequestTracker {
return context_stats_aggregator_;
}

+  triton::common::ThreadPool* CallbackPool() const { return callback_pool_; }

void IncrementCounter()
{
std::lock_guard<std::mutex> lk(mtx_);
@@ -120,6 +150,7 @@ class RequestTracker {
InferenceStatsAggregator* stats_aggregator_;
InferenceStatsAggregator context_stats_aggregator_;
Status status_;
+  triton::common::ThreadPool* const callback_pool_;
};

// Step is used as 'userp' and keeps ensemble context alive
@@ -129,9 +160,9 @@ class RequestTracker {
struct Step {
Step(
size_t step_idx, const InferenceRequest::SequenceId& correlation_id,
-      uint32_t flags)
+      uint32_t flags, bool preserve_responses_order)
: correlation_id_(correlation_id), flags_(flags), response_flags_(0),
-        step_idx_(step_idx)
+        preserve_responses_order_(preserve_responses_order), step_idx_(step_idx)
{
}

@@ -154,7 +185,7 @@ struct Step {
// returning from the callback.
uint32_t response_flags_;
TRITONSERVER_InferenceResponse* response_;
-
+  const bool preserve_responses_order_;

size_t step_idx_;
};
@@ -237,7 +268,7 @@ class EnsembleContext {
MetricModelReporter* metric_reporter,
InferenceStatsAggregator* stats_aggregator, InferenceServer* is,
EnsembleInfo* info, std::unique_ptr<InferenceRequest>& request,
-      cudaStream_t stream);
+      cudaStream_t stream, triton::common::ThreadPool* callback_pool);

// Perform transition on 'context' state given the information of
// 'completed_step'
@@ -326,6 +357,8 @@ class EnsembleContext {
void CacheEnsembleTopLevelRequest(
std::unique_ptr<InferenceResponse>& response);

+  triton::common::ThreadPool* CallbackPool() const { return callback_pool_; }

InferenceServer* is_;

EnsembleInfo* info_;
@@ -375,20 +408,26 @@ class EnsembleContext {
TRITONSERVER_ResponseAllocator,
decltype(&TRITONSERVER_ResponseAllocatorDelete)>
allocator_;
+
+  // The thread pool used to execute ensemble callbacks and reduce e2e latency.
+  // The thread pool is managed by InferenceServer.
+  triton::common::ThreadPool* const callback_pool_;
};

EnsembleContext::EnsembleContext(
MetricModelReporter* metric_reporter,
InferenceStatsAggregator* stats_aggregator, InferenceServer* is,
EnsembleInfo* info, std::unique_ptr<InferenceRequest>& request,
-    cudaStream_t stream)
+    cudaStream_t stream, triton::common::ThreadPool* callback_pool)
: is_(is), info_(info), stream_(stream), inflight_step_counter_(0),
-      allocator_(nullptr, TRITONSERVER_ResponseAllocatorDelete)
+      allocator_(nullptr, TRITONSERVER_ResponseAllocatorDelete),
+      callback_pool_(callback_pool)
{
uint64_t compute_start_ns = 0;
INFER_STATS_SET_TIMESTAMP(compute_start_ns);
request_tracker_ = new RequestTracker(
-      std::move(request), compute_start_ns, metric_reporter, stats_aggregator);
+      std::move(request), compute_start_ns, metric_reporter, stats_aggregator,
+      callback_pool);

auto& lrequest = request_tracker_->Request();

@@ -603,29 +642,57 @@ void
EnsembleContext::RequestComplete(
TRITONSERVER_InferenceRequest* request, const uint32_t flags, void* userp)
{
-  if ((flags & TRITONSERVER_REQUEST_RELEASE_ALL) != 0) {
-    LOG_TRITONSERVER_ERROR(
-        TRITONSERVER_InferenceRequestDelete(request),
-        "deleting ensemble inference request");
-    auto request_tracker = reinterpret_cast<RequestTracker*>(userp);
-    if (request_tracker->DecrementCounter()) {
-      delete request_tracker;
-    }
-  }
+  auto request_tracker = reinterpret_cast<RequestTracker*>(userp);
+  auto pool = request_tracker->CallbackPool();
+  auto fn = [request, flags, request_tracker]() {
+    if ((flags & TRITONSERVER_REQUEST_RELEASE_ALL) != 0) {
+      LOG_TRITONSERVER_ERROR(
+          TRITONSERVER_InferenceRequestDelete(request),
+          "deleting ensemble inference request");
+      if (request_tracker->DecrementCounter()) {
+        delete request_tracker;
+      }
+    }
+  };
+
+  // Attempt to enqueue the callback. If all workers are busy and the queue
+  // is at capacity, execute the callback immediately in the current thread.
+  if (pool->TaskQueueSize() < pool->Size()) {
+    pool->Enqueue(fn);
+  } else {
+    fn();
+  }
}
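The dispatch above is a bounded enqueue-or-run-inline pattern: the queue is capped at the worker count, and overflow work runs on the caller's thread so the pool can never accumulate an unbounded backlog. A self-contained sketch of the same idea, using only the ThreadPool methods that appear in this diff (the helper name is illustrative, not part of this PR):

#include <functional>
#include <utility>

#include "triton/common/thread_pool.h"

// Enqueue 'fn' while the pool has headroom; otherwise execute it inline so
// the caller provides natural backpressure.
inline void
EnqueueOrRunInline(triton::common::ThreadPool* pool, std::function<void()> fn)
{
  if (pool->TaskQueueSize() < pool->Size()) {
    pool->Enqueue(std::move(fn));
  } else {
    fn();
  }
}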

void
EnsembleContext::ResponseComplete(
TRITONSERVER_InferenceResponse* response, const uint32_t flags, void* userp)
{
-  auto step_ptr = std::unique_ptr<Step>(reinterpret_cast<Step*>(userp));
-  step_ptr->response_flags_ = flags;
-  step_ptr->response_ = response;
-
-  EnsembleContext::Proceed(step_ptr->ctx_, step_ptr);
-  // Expecting more responses
-  if ((flags & TRITONSERVER_RESPONSE_COMPLETE_FINAL) == 0) {
-    step_ptr.release();
-  }
+  auto step_raw_ptr = reinterpret_cast<Step*>(userp);
+  auto pool = step_raw_ptr->ctx_->CallbackPool();
+  auto fn = [response, flags, step_raw_ptr]() {
+    auto step_ptr = std::unique_ptr<Step>(step_raw_ptr);
+    step_ptr->response_flags_ = flags;
+    step_ptr->response_ = response;
+
+    EnsembleContext::Proceed(step_ptr->ctx_, step_ptr);
+    // Expecting more responses
+    if ((flags & TRITONSERVER_RESPONSE_COMPLETE_FINAL) == 0) {
+      step_ptr.release();
+    }
+  };
+
+  // Attempt to enqueue the callback. If all workers are busy and the queue
+  // is at capacity, execute the callback immediately in the current thread.
+  // Note: The async callback optimization does not guarantee the order of
+  // responses and exploits cases where responses may arrive out of order.
+  // For models that must preserve the order of responses, the callbacks are
+  // executed synchronously in the same thread.
+  if (!step_raw_ptr->preserve_responses_order_ &&
+      pool->TaskQueueSize() < pool->Size()) {
+    pool->Enqueue(fn);
+  } else {
+    fn();
+  }
}
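For response callbacks the gate is stricter than for request release: both conditions must hold before work is handed to the pool. Condensed into a predicate, as a sketch (the function name is illustrative; the Step member and ThreadPool methods are the ones from this diff):

// A response callback may run asynchronously only if the model tolerates
// out-of-order responses and the callback pool has headroom.
inline bool
MayRunResponseCallbackAsync(const Step* step, triton::common::ThreadPool* pool)
{
  return !step->preserve_responses_order_ &&
         (pool->TaskQueueSize() < pool->Size());
}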

@@ -971,8 +1038,8 @@ EnsembleContext::InitStep(
for (const auto& pair : istep.output_to_tensor_) {
irequest->AddOriginalRequestedOutput(pair.first);
}
-
-  step->reset(new Step(step_idx, correlation_id, flags));
+  const bool preserve_order = preserve_responses_order(model->Config());
+  step->reset(new Step(step_idx, correlation_id, flags, preserve_order));

irequest->SetId(request_id_);
irequest->SetCorrelationId(correlation_id);
@@ -1448,7 +1515,7 @@ EnsembleScheduler::Enqueue(std::unique_ptr<InferenceRequest>& request)
RETURN_IF_ERROR(request->SetState(InferenceRequest::State::EXECUTING));
std::shared_ptr<EnsembleContext> context(new EnsembleContext(
metric_reporter_.get(), stats_aggregator_, is_, info_.get(), request,
-      stream_));
+      stream_, callback_pool_));
EnsembleContext::Proceed(context);
return Status::Success;
}
@@ -1537,6 +1604,7 @@ EnsembleScheduler::EnsembleScheduler(
info_->tensor_to_prev_step_.emplace(pair.second, step_idx);
}
}
+  callback_pool_ = is_->EnsembleCallbackPool();
}

EnsembleScheduler::~EnsembleScheduler()
9 changes: 8 additions & 1 deletion src/ensemble_scheduler/ensemble_scheduler.h
@@ -1,4 +1,4 @@
-// Copyright 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright 2019-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
@@ -36,6 +36,7 @@
#include "scheduler.h"
#include "scheduler_utils.h"
#include "status.h"
#include "triton/common/thread_pool.h"

#ifdef TRITON_ENABLE_GPU
#include <cuda_runtime_api.h>
@@ -107,6 +108,8 @@ class EnsembleScheduler : public Scheduler {
// \see Scheduler::Stop()
void Stop() override {}

+  triton::common::ThreadPool* CallbackPool() const { return callback_pool_; }

private:
EnsembleScheduler(
InferenceStatsAggregator* const stats_aggregator,
@@ -128,6 +131,10 @@
cudaStream_t stream_;

std::atomic<size_t> inflight_count_;
+
+  // Fixed-size thread pool to run callbacks at the end of each ensemble step.
+  // Managed by the server.
+  triton::common::ThreadPool* callback_pool_;
};

}} // namespace triton::core
9 changes: 8 additions & 1 deletion src/server.cc
@@ -1,4 +1,4 @@
-// Copyright 2018-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright 2018-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
@@ -117,6 +117,13 @@ InferenceServer::InferenceServer()
#endif // TRITON_ENABLE_GPU

inflight_request_counter_ = 0;
+
+#ifdef TRITON_ENABLE_ENSEMBLE
+  // TODO: Scale the thread pool size more intelligently, e.g. based on the
+  // instance_group counts of the composing models.
+  ensemble_cb_pool_.reset(
+      new triton::common::ThreadPool(ENSEMBLE_CB_POOL_SIZE));
+#endif  // TRITON_ENABLE_ENSEMBLE
}
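One possible shape for the sizing TODO above, sketched under the assumption that the composing models' configurations are available when the pool is created; PoolSizeFor and its argument are hypothetical, not part of this PR:

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

#include "model_config.pb.h"

// Hypothetical heuristic: keep the fixed default as a floor, but grow the
// pool with the total instance count of the composing models so callbacks
// from concurrently executing instances do not queue behind one another.
size_t
PoolSizeFor(const std::vector<const inference::ModelConfig*>& composing_models)
{
  uint64_t total_instances = 0;
  for (const auto* config : composing_models) {
    for (const auto& group : config->instance_group()) {
      total_instances += group.count();
    }
  }
  return std::max<uint64_t>(ENSEMBLE_CB_POOL_SIZE, total_instances);
}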

Status
15 changes: 14 additions & 1 deletion src/server.h
@@ -1,4 +1,4 @@
-// Copyright 2018-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright 2018-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
@@ -332,6 +332,13 @@ class InferenceServer {
return cache_manager_;
}

+#ifdef TRITON_ENABLE_ENSEMBLE
+  triton::common::ThreadPool* EnsembleCallbackPool() const
+  {
+    return ensemble_cb_pool_.get();
+  }
+#endif  // TRITON_ENABLE_ENSEMBLE

private:
const std::string version_;
std::string id_;
@@ -375,6 +382,12 @@
std::unique_ptr<ModelRepositoryManager> model_repository_manager_;
std::shared_ptr<TritonBackendManager> backend_manager_;
std::shared_ptr<TritonCacheManager> cache_manager_;
+
+#ifdef TRITON_ENABLE_ENSEMBLE
+  // The thread pool for all ensemble models to execute callbacks
+  // asynchronously.
+  std::unique_ptr<triton::common::ThreadPool> ensemble_cb_pool_;
+#endif  // TRITON_ENABLE_ENSEMBLE
};

}} // namespace triton::core