Add retry on model loading. Expose option to set model retry count (#308)

* Group model repository files

* Expose option to set model retry count
GuanLuo authored Jan 5, 2024
1 parent 82d3371 commit 3b97b2f
Showing 12 changed files with 78 additions and 32 deletions.
11 changes: 10 additions & 1 deletion include/triton/core/tritonserver.h
@@ -91,7 +91,7 @@ struct TRITONSERVER_MetricFamily;
/// }
///
#define TRITONSERVER_API_VERSION_MAJOR 1
#define TRITONSERVER_API_VERSION_MINOR 27
#define TRITONSERVER_API_VERSION_MINOR 28

/// Get the TRITONBACKEND API version supported by the Triton shared
/// library. This value can be compared against the
@@ -1978,6 +1978,15 @@ TRITONSERVER_DECLSPEC struct TRITONSERVER_Error*
TRITONSERVER_ServerOptionsSetModelLoadThreadCount(
struct TRITONSERVER_ServerOptions* options, unsigned int thread_count);

/// Set the number of retries for loading a model in a server options
/// object.
///
/// \param options The server options object.
/// \param retry_count The number of retries.
/// \return a TRITONSERVER_Error indicating success or failure.
TRITONSERVER_DECLSPEC struct TRITONSERVER_Error*
TRITONSERVER_ServerOptionsSetModelLoadRetryCount(
struct TRITONSERVER_ServerOptions* options, unsigned int retry_count);

/// Enable model namespacing to allow serving models with the same name if
/// they are in different namespaces.
///
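For context, a minimal sketch of how an embedding application might pick up the new option through the C API. TRITONSERVER_ServerOptionsNew, TRITONSERVER_ServerOptionsSetModelRepositoryPath, TRITONSERVER_ServerOptionsDelete, TRITONSERVER_ServerNew, TRITONSERVER_ServerDelete, TRITONSERVER_ErrorMessage, and TRITONSERVER_ErrorDelete are existing Triton C API calls; the repository path and the retry value of 2 are illustrative placeholders, not defaults introduced by this commit.

#include <stdio.h>

#include "triton/core/tritonserver.h"

int main()
{
  // Build server options that retry each failed model load up to two
  // more times before the load is considered failed.
  TRITONSERVER_ServerOptions* options = nullptr;
  TRITONSERVER_Error* err = TRITONSERVER_ServerOptionsNew(&options);
  if (err == nullptr) {
    err = TRITONSERVER_ServerOptionsSetModelRepositoryPath(
        options, "/models" /* hypothetical repository path */);
  }
  if (err == nullptr) {
    // New in API version 1.28.
    err = TRITONSERVER_ServerOptionsSetModelLoadRetryCount(options, 2);
  }
  TRITONSERVER_Server* server = nullptr;
  if (err == nullptr) {
    err = TRITONSERVER_ServerNew(&server, options);
  }
  if (options != nullptr) {
    TRITONSERVER_ServerOptionsDelete(options);
  }
  if (err != nullptr) {
    fprintf(stderr, "error: %s\n", TRITONSERVER_ErrorMessage(err));
    TRITONSERVER_ErrorDelete(err);
    return 1;
  }
  TRITONSERVER_ServerDelete(server);
  return 0;
}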
9 changes: 9 additions & 0 deletions python/tritonserver/_c/tritonserver_pybind.cc
@@ -1282,6 +1282,12 @@ class PyServerOptions : public PyWrapper<struct TRITONSERVER_ServerOptions> {
triton_object_, thread_count));
}

void SetModelLoadRetryCount(unsigned int retry_count)
{
ThrowIfError(TRITONSERVER_ServerOptionsSetModelLoadRetryCount(
triton_object_, retry_count));
}

void SetModelNamespacing(bool enable_namespace)
{
ThrowIfError(TRITONSERVER_ServerOptionsSetModelNamespacing(
@@ -2017,6 +2023,9 @@ PYBIND11_MODULE(triton_bindings, m)
.def(
"set_model_load_thread_count",
&PyServerOptions::SetModelLoadThreadCount)
.def(
"set_model_load_retry_count",
&PyServerOptions::SetModelLoadRetryCount)
.def("set_model_namespacing", &PyServerOptions::SetModelNamespacing)
.def("set_log_file", &PyServerOptions::SetLogFile)
.def("set_log_info", &PyServerOptions::SetLogInfo)
8 changes: 4 additions & 4 deletions src/CMakeLists.txt
@@ -140,8 +140,8 @@ set(
metric_family.cc
model.cc
model_config_utils.cc
model_lifecycle.cc
model_repository_manager.cc
model_repository_manager/model_lifecycle.cc
model_repository_manager/model_repository_manager.cc
numa_utils.cc
payload.cc
pinned_memory_manager.cc
@@ -187,8 +187,8 @@ set(
metric_family.h
model_config_utils.h
model.h
model_lifecycle.h
model_repository_manager.h
model_repository_manager/model_lifecycle.h
model_repository_manager/model_repository_manager.h
numa_utils.h
payload.h
pinned_memory_manager.h
2 changes: 1 addition & 1 deletion src/ensemble_scheduler/ensemble_utils.h
@@ -31,7 +31,7 @@
#include <unordered_map>

#include "model_config.pb.h"
#include "model_repository_manager.h"
#include "model_repository_manager/model_repository_manager.h"
#include "status.h"
#include "triton/common/model_config.h"

src/model_repository_manager/model_lifecycle.cc (renamed from src/model_lifecycle.cc)
@@ -513,7 +513,15 @@ ModelLifeCycle::AsyncLoad(
// Load model asynchronously via thread pool
load_pool_->Enqueue([this, model_id, version, model_info, OnComplete,
load_tracker, is_config_provided]() {
CreateModel(model_id, version, model_info, is_config_provided);
for (size_t retry = 0; retry <= options_.load_retry; ++retry) {
model_info->state_ = ModelReadyState::LOADING;
CreateModel(model_id, version, model_info, is_config_provided);
// Model state will be changed away from LOADING if the load fails,
// so the load has succeeded if the state is still LOADING.
if (model_info->state_ == ModelReadyState::LOADING) {
break;
}
}
OnLoadComplete(
model_id, version, model_info, false /* is_update */, OnComplete,
load_tracker);
@@ -540,15 +548,16 @@ ModelLifeCycle::CreateModel(
if (!model_config.backend().empty()) {
std::unique_ptr<TritonModel> model;
status = TritonModel::Create(
server_, model_info->model_path_, cmdline_config_map_, host_policy_map_,
version, model_config, is_config_provided, &model);
server_, model_info->model_path_, options_.backend_cmdline_config_map,
options_.host_policy_map, version, model_config, is_config_provided,
&model);
is.reset(model.release());
} else {
#ifdef TRITON_ENABLE_ENSEMBLE
if (model_info->is_ensemble_) {
status = EnsembleModel::Create(
server_, model_info->model_path_, version, model_config,
is_config_provided, min_compute_capability_, &is);
is_config_provided, options_.min_compute_capability, &is);
// Complete label provider with label information from involved models
// Must be done here because involved models may not be obtainable
// from the server, because this may happen during server
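A note on the retry semantics introduced in AsyncLoad above: options_.load_retry counts additional attempts, so a value of N permits up to N + 1 calls to CreateModel, and the loop exits early once an attempt leaves the state at LOADING. A self-contained sketch of the same pattern, with hypothetical names standing in for the real ModelLifeCycle members:

#include <cstddef>
#include <functional>

enum class ModelReadyState { LOADING, UNAVAILABLE };

// Distillation of the new loop in ModelLifeCycle::AsyncLoad. 'try_load'
// mimics CreateModel: the state stays LOADING on success and is moved
// off LOADING (here, to UNAVAILABLE) on failure.
ModelReadyState
LoadWithRetry(
    size_t load_retry, const std::function<ModelReadyState()>& try_load)
{
  ModelReadyState state = ModelReadyState::UNAVAILABLE;
  // 'load_retry' counts extra attempts: a value of N allows N + 1 tries.
  for (size_t retry = 0; retry <= load_retry; ++retry) {
    state = try_load();
    if (state == ModelReadyState::LOADING) {
      break;  // Still LOADING after the attempt, so the load succeeded.
    }
  }
  return state;
}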
src/model_repository_manager/model_lifecycle.h (renamed from src/model_lifecycle.h)
@@ -45,21 +45,23 @@ struct ModelLifeCycleOptions {
const double min_compute_capability,
const triton::common::BackendCmdlineConfigMap& backend_cmdline_config_map,
const triton::common::HostPolicyCmdlineConfigMap& host_policy_map,
const unsigned int model_load_thread_count)
: min_compute_capability_(min_compute_capability),
backend_cmdline_config_map_(backend_cmdline_config_map),
host_policy_map_(host_policy_map),
model_load_thread_count_(model_load_thread_count)
const unsigned int model_load_thread_count, const size_t load_retry)
: min_compute_capability(min_compute_capability),
backend_cmdline_config_map(backend_cmdline_config_map),
host_policy_map(host_policy_map),
model_load_thread_count(model_load_thread_count), load_retry(load_retry)
{
}
// The minimum supported CUDA compute capability.
const double min_compute_capability_;
const double min_compute_capability;
// The backend configuration settings specified on the command-line
const triton::common::BackendCmdlineConfigMap& backend_cmdline_config_map_;
const triton::common::BackendCmdlineConfigMap& backend_cmdline_config_map;
// The host policy setting used when loading models.
const triton::common::HostPolicyCmdlineConfigMap& host_policy_map_;
const triton::common::HostPolicyCmdlineConfigMap& host_policy_map;
// Number of threads to use for concurrently loading models
const unsigned int model_load_thread_count_;
const unsigned int model_load_thread_count;
// Number of retries on model loading before considering the load failed.
const size_t load_retry{0};
};
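As a usage illustration, a hedged sketch of constructing the revised options struct; the numeric values and the two map variables are placeholders, and the include path assumes the new location introduced by this commit:

#include "model_repository_manager/model_lifecycle.h"

triton::common::BackendCmdlineConfigMap backend_config_map;
triton::common::HostPolicyCmdlineConfigMap host_policy_map;

// 6.0 is a hypothetical minimum CUDA compute capability; load with 4
// threads and allow up to 2 retries (3 attempts) per model load.
const triton::core::ModelLifeCycleOptions options(
    6.0, backend_config_map, host_policy_map,
    4 /* model_load_thread_count */, 2 /* load_retry */);

Because backend_cmdline_config_map and host_policy_map are stored as references, both maps must outlive any ModelLifeCycle that copies these options.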


@@ -283,13 +285,10 @@ class ModelLifeCycle {
};

ModelLifeCycle(InferenceServer* server, const ModelLifeCycleOptions& options)
: server_(server),
min_compute_capability_(options.min_compute_capability_),
cmdline_config_map_(options.backend_cmdline_config_map_),
host_policy_map_(options.host_policy_map_)
: server_(server), options_(options)
{
load_pool_.reset(new triton::common::ThreadPool(
std::max(1u, options.model_load_thread_count_)));
std::max(1u, options_.model_load_thread_count)));
}

// Create a new model; the 'model_id' can be either a new or an existing model.
@@ -327,9 +326,7 @@ class ModelLifeCycle {
std::map<uintptr_t, std::unique_ptr<ModelInfo>> background_models_;

InferenceServer* server_;
const double min_compute_capability_;
const triton::common::BackendCmdlineConfigMap cmdline_config_map_;
const triton::common::HostPolicyCmdlineConfigMap host_policy_map_;
const ModelLifeCycleOptions options_;

// Fixed-size thread pool to load models at specified concurrency
std::unique_ptr<triton::common::ThreadPool> load_pool_;
src/model_repository_manager/model_repository_manager.cc (renamed from src/model_repository_manager.cc)
@@ -416,7 +416,7 @@ ModelRepositoryManager::Create(
std::unique_ptr<ModelRepositoryManager> local_manager(
new ModelRepositoryManager(
repository_paths, !strict_model_config, polling_enabled,
model_control_enabled, life_cycle_options.min_compute_capability_,
model_control_enabled, life_cycle_options.min_compute_capability,
enable_model_namespacing, std::move(life_cycle)));
*model_repository_manager = std::move(local_manager);

src/model_repository_manager/model_repository_manager.h: file renamed without changes (from src/model_repository_manager.h).
4 changes: 2 additions & 2 deletions src/server.cc
@@ -42,7 +42,6 @@
#include "model.h"
#include "model_config.pb.h"
#include "model_config_utils.h"
#include "model_repository_manager.h"
#include "pinned_memory_manager.h"
#include "repo_agent.h"
#include "triton/common/async_work_queue.h"
@@ -108,6 +107,7 @@ InferenceServer::InferenceServer()
pinned_memory_pool_size_ = 1 << 28;
buffer_manager_thread_count_ = 0;
model_load_thread_count_ = 4;
model_load_retry_count_ = 0;
enable_model_namespacing_ = false;

#ifdef TRITON_ENABLE_GPU
@@ -258,7 +258,7 @@ InferenceServer::Init()
(model_control_mode_ == ModelControlMode::MODE_EXPLICIT);
const ModelLifeCycleOptions life_cycle_options(
min_supported_compute_capability_, backend_cmdline_config_map_,
host_policy_map_, model_load_thread_count_);
host_policy_map_, model_load_thread_count_, model_load_retry_count_);
status = ModelRepositoryManager::Create(
this, version_, model_repository_paths_, startup_models_,
strict_model_config_, polling_enabled, model_control_enabled,
5 changes: 4 additions & 1 deletion src/server.h
@@ -38,7 +38,7 @@
#include "cache_manager.h"
#include "infer_parameter.h"
#include "model_config.pb.h"
#include "model_repository_manager.h"
#include "model_repository_manager/model_repository_manager.h"
#include "rate_limiter.h"
#include "status.h"
#include "triton/common/model_config.h"
@@ -257,6 +257,8 @@ class InferenceServer {

void SetModelLoadThreadCount(unsigned int c) { model_load_thread_count_ = c; }

void SetModelLoadRetryCount(unsigned int c) { model_load_retry_count_ = c; }

void SetModelNamespacingEnabled(const bool e)
{
enable_model_namespacing_ = e;
@@ -334,6 +336,7 @@ class InferenceServer {
uint32_t exit_timeout_secs_;
uint32_t buffer_manager_thread_count_;
uint32_t model_load_thread_count_;
uint32_t model_load_retry_count_;
bool enable_model_namespacing_;
uint64_t pinned_memory_pool_size_;
bool response_cache_enabled_;
17 changes: 16 additions & 1 deletion src/tritonserver.cc
@@ -37,7 +37,7 @@
#include "metrics.h"
#include "model.h"
#include "model_config_utils.h"
#include "model_repository_manager.h"
#include "model_repository_manager/model_repository_manager.h"
#include "rate_limiter.h"
#include "response_allocator.h"
#include "server.h"
@@ -280,6 +280,9 @@ class TritonServerOptions {
unsigned int ModelLoadThreadCount() const { return model_load_thread_count_; }
void SetModelLoadThreadCount(unsigned int c) { model_load_thread_count_ = c; }

unsigned int ModelLoadRetryCount() const { return model_load_retry_count_; }
void SetModelLoadRetryCount(unsigned int c) { model_load_retry_count_ = c; }

bool ModelNamespacingEnabled() { return enable_model_namespacing_; }
void SetModelNamespacingEnabled(const bool e)
{
@@ -356,6 +359,7 @@ class TritonServerOptions {
uint64_t pinned_memory_pool_size_;
unsigned int buffer_manager_thread_count_;
unsigned int model_load_thread_count_;
unsigned int model_load_retry_count_;
bool enable_model_namespacing_;
std::map<int, uint64_t> cuda_memory_pool_size_;
double min_compute_capability_;
@@ -1355,6 +1359,16 @@ TRITONSERVER_ServerOptionsSetModelLoadThreadCount(
return nullptr; // Success
}

TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONSERVER_ServerOptionsSetModelLoadRetryCount(
TRITONSERVER_ServerOptions* options, unsigned int retry_count)
{
TritonServerOptions* loptions =
reinterpret_cast<TritonServerOptions*>(options);
loptions->SetModelLoadRetryCount(retry_count);
return nullptr; // Success
}

TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONSERVER_ServerOptionsSetModelNamespacing(
TRITONSERVER_ServerOptions* options, bool enable_namespace)
@@ -2342,6 +2356,7 @@ TRITONSERVER_ServerNew(
lserver->SetRepoAgentDir(loptions->RepoAgentDir());
lserver->SetBufferManagerThreadCount(loptions->BufferManagerThreadCount());
lserver->SetModelLoadThreadCount(loptions->ModelLoadThreadCount());
lserver->SetModelLoadRetryCount(loptions->ModelLoadRetryCount());
lserver->SetModelNamespacingEnabled(loptions->ModelNamespacingEnabled());

// SetBackendCmdlineConfig must be called after all AddBackendConfig calls
4 changes: 4 additions & 0 deletions src/tritonserver_stub.cc
@@ -427,6 +427,10 @@ TRITONSERVER_ServerOptionsSetModelLoadThreadCount()
{
}
TRITONAPI_DECLSPEC void
TRITONSERVER_ServerOptionsSetModelLoadRetryCount()
{
}
TRITONAPI_DECLSPEC void
TRITONSERVER_ServerOptionsSetModelNamespacing()
{
}
