Skip to content

Commit

Permalink
Only fast exit for non-shm cases
Browse files Browse the repository at this point in the history
  • Loading branch information
tgerdesnv committed May 30, 2024
1 parent 3604a7d commit aac7521
Show file tree
Hide file tree
Showing 9 changed files with 27 additions and 14 deletions.
2 changes: 1 addition & 1 deletion src/c++/perf_analyzer/infer_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ InferContext::AsyncCallbackFuncImpl(cb::InferResult* result)
// Add the request record to thread request records vector with
// proper locking
std::lock_guard<std::mutex> lock(thread_stat_->mu_);
if (exiting_) {
if (exiting_ && fast_exit_) {
return;
}

Expand Down
7 changes: 6 additions & 1 deletion src/c++/perf_analyzer/infer_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,11 @@ class InferContext {
void Init();

// Signal to the context to stop working and exit
void Exit() { exiting_ = true; }
void Exit(bool fast_exit)
{
exiting_ = true;
fast_exit_ = fast_exit;
}

// Send a single inference request to the server
void SendInferRequest(bool delayed = false);
Expand Down Expand Up @@ -196,6 +200,7 @@ class InferContext {
const uint32_t id_{0};
const size_t thread_id_{0};
bool exiting_{false};
bool fast_exit_{false};

size_t GetNumActiveThreads() { return num_active_threads_; }

Expand Down
2 changes: 1 addition & 1 deletion src/c++/perf_analyzer/iworker.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ namespace triton { namespace perfanalyzer {
class IWorker {
public:
virtual void Infer() = 0;
virtual void Exit() = 0;
virtual void Exit(bool fast_exit) = 0;
};

}} // namespace triton::perfanalyzer
8 changes: 5 additions & 3 deletions src/c++/perf_analyzer/load_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,8 @@ LoadManager::LoadManager(
const std::unordered_map<std::string, cb::RequestParameter>&
request_parameters)
: async_(async), streaming_(streaming), batch_size_(batch_size),
max_threads_(max_threads), parser_(parser), factory_(factory),
using_json_data_(false)
max_threads_(max_threads), shared_memory_type_{shared_memory_type},
parser_(parser), factory_(factory), using_json_data_(false)
{
on_sequence_model_ =
((parser_->SchedulerType() == ModelParser::SEQUENCE) ||
Expand Down Expand Up @@ -248,9 +248,11 @@ LoadManager::InitManagerInputs(
void
LoadManager::StopWorkerThreads()
{
bool fast_exit = shared_memory_type_ == SharedMemoryType::NO_SHARED_MEMORY;

// FIXME do I need to acquire the lock first?
for (auto& worker : workers_) {
worker->Exit();
worker->Exit(fast_exit);
}

{
Expand Down
1 change: 1 addition & 0 deletions src/c++/perf_analyzer/load_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ class LoadManager {
size_t batch_size_;
size_t max_threads_;
bool on_sequence_model_;
SharedMemoryType shared_memory_type_;

std::shared_ptr<ModelParser> parser_;
std::shared_ptr<cb::ClientBackendFactory> factory_;
Expand Down
10 changes: 7 additions & 3 deletions src/c++/perf_analyzer/load_worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,14 @@
namespace triton { namespace perfanalyzer {

void
LoadWorker::Exit()
LoadWorker::Exit(bool fast_exit)
{
for (auto ctx : ctxs_) {
ctx->Exit();
ctx->Exit(fast_exit);
}

exiting_ = true;
fast_exit_ = fast_exit;

{
std::lock_guard<std::mutex> lk(cb_mtx_);
Expand All @@ -67,6 +68,9 @@ LoadWorker::HandleExitConditions()
{
if (ShouldExit()) {
CompleteOngoingSequences();
if (!fast_exit_) {
WaitForOngoingRequests();
}
return true;
}
return false;
Expand All @@ -86,7 +90,7 @@ LoadWorker::CompleteOngoingSequences()
void
LoadWorker::WaitForOngoingRequests()
{
while (GetNumOngoingRequests() != 0) {
while (GetNumOngoingRequests() != 0 && !fast_exit_) {
std::this_thread::sleep_for(std::chrono::milliseconds(50));
}
}
Expand Down
3 changes: 2 additions & 1 deletion src/c++/perf_analyzer/load_worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class LoadWorker : public IWorker {

virtual ~LoadWorker() = default;

virtual void Exit() override;
virtual void Exit(bool fast_exit) override;

protected:
// Return the total number of async requests that have started and not
Expand Down Expand Up @@ -120,6 +120,7 @@ class LoadWorker : public IWorker {
void AsyncCallbackFinalize(uint32_t ctx_id);

bool exiting_ = false;
bool fast_exit_ = false;

uint32_t id_;

Expand Down
4 changes: 2 additions & 2 deletions src/c++/perf_analyzer/test_concurrency_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ TEST_CASE("concurrency_free_ctx_ids")

std::this_thread::sleep_for(std::chrono::milliseconds(15));

worker->Exit();
worker->Exit(false);
infer_future.get();

// The first sequence should only be called two times, once at the very start,
Expand Down Expand Up @@ -590,7 +590,7 @@ TEST_CASE("Concurrency - shared memory infer input calls")

std::this_thread::sleep_for(std::chrono::milliseconds(18));

worker->Exit();
worker->Exit(false);
infer_future.get();

const auto& actual_append_raw_calls{tcm.stats_->num_append_raw_calls};
Expand Down
4 changes: 2 additions & 2 deletions src/c++/perf_analyzer/test_request_rate_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -975,7 +975,7 @@ TEST_CASE("request_rate_streaming: test that streaming-specific logic works")
std::dynamic_pointer_cast<IScheduler>(worker)->SetSchedule(schedule);
std::future<void> infer_future{std::async(&IWorker::Infer, worker)};

worker->Exit();
worker->Exit(false);
infer_future.get();

CHECK(
Expand Down Expand Up @@ -1825,7 +1825,7 @@ TEST_CASE("Request rate - Shared memory infer input calls")

std::this_thread::sleep_for(milliseconds(18));

worker->Exit();
worker->Exit(false);
infer_future.get();

const auto& actual_append_raw_calls{trrm.stats_->num_append_raw_calls};
Expand Down

0 comments on commit aac7521

Please sign in to comment.