Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: triton generate support #675

Draft
wants to merge 19 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nnshah1 what's blocking this PR from being marked ready for review?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's out of date relative to some of our other work now and, at a bare minimum, needs to be ported to the new repository. I really like the additions here, so I would like to see them integrated soon as well.

Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ repos:
- id: isort
additional_dependencies: [toml]
- repo: https://github.com/psf/black
rev: 24.4.0
rev: 23.1.0
hooks:
- id: black
types_or: [python, cython]
Expand Down
3 changes: 0 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,6 @@ project(tritonclient LANGUAGES C CXX)
# Use C++17 standard as Triton's minimum required.
set(TRITON_MIN_CXX_STANDARD 17 CACHE STRING "The minimum C++ standard which features are requested to build this target.")

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

#
# Options
#
Expand Down
6 changes: 0 additions & 6 deletions src/c++/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,6 @@ cmake_minimum_required(VERSION 3.17)

project(cc-clients LANGUAGES C CXX)

# Use C++17 standard as Triton's minimum required.
set(TRITON_MIN_CXX_STANDARD 17 CACHE STRING "The minimum C++ standard which features are requested to build this target.")

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

#
# Options
#
Expand Down
194 changes: 90 additions & 104 deletions src/c++/library/http_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1371,23 +1371,27 @@ InferenceServerHttpClient::InferenceServerHttpClient(

InferenceServerHttpClient::~InferenceServerHttpClient()
{
{
std::lock_guard<std::mutex> lock(mutex_);
exiting_ = true;
}

curl_multi_wakeup(multi_handle_);
exiting_ = true;

// thread not joinable if AsyncInfer() is not called
// (it is default constructed thread before the first AsyncInfer() call)
if (worker_.joinable()) {
cv_.notify_all();
worker_.join();
}

if (easy_handle_ != nullptr) {
curl_easy_cleanup(reinterpret_cast<CURL*>(easy_handle_));
}
curl_multi_cleanup(multi_handle_);

if (multi_handle_ != nullptr) {
for (auto& request : ongoing_async_requests_) {
CURL* easy_handle = reinterpret_cast<CURL*>(request.first);
curl_multi_remove_handle(multi_handle_, easy_handle);
curl_easy_cleanup(easy_handle);
}
curl_multi_cleanup(multi_handle_);
}
}

Error
Expand Down Expand Up @@ -1883,28 +1887,25 @@ InferenceServerHttpClient::AsyncInfer(
{
std::lock_guard<std::mutex> lock(mutex_);

if (exiting_) {
return Error("Client is exiting.");
}

auto insert_result = new_async_requests_.emplace(std::make_pair(
auto insert_result = ongoing_async_requests_.emplace(std::make_pair(
reinterpret_cast<uintptr_t>(multi_easy_handle), async_request));
if (!insert_result.second) {
curl_easy_cleanup(multi_easy_handle);
return Error("Failed to insert new asynchronous request context.");
}
}

async_request->Timer().CaptureTimestamp(RequestTimers::Kind::SEND_START);
curl_multi_wakeup(multi_handle_);
async_request->Timer().CaptureTimestamp(RequestTimers::Kind::SEND_START);
if (async_request->total_input_byte_size_ == 0) {
// Set SEND_END here because CURLOPT_READFUNCTION will not be called if
// content length is 0. In that case, we can't measure SEND_END properly
// (send ends after sending request header).
async_request->Timer().CaptureTimestamp(RequestTimers::Kind::SEND_END);
}

if (async_request->total_input_byte_size_ == 0) {
// Set SEND_END here because CURLOPT_READFUNCTION will not be called if
// content length is 0. In that case, we can't measure SEND_END properly
// (send ends after sending request header).
async_request->Timer().CaptureTimestamp(RequestTimers::Kind::SEND_END);
curl_multi_add_handle(multi_handle_, multi_easy_handle);
}

cv_.notify_all();
return Error::Success;
}

Expand Down Expand Up @@ -2248,103 +2249,88 @@ InferenceServerHttpClient::PreRunProcessing(
void
InferenceServerHttpClient::AsyncTransfer()
{
int messages_in_queue = 0;
int still_running = 0;
int numfds = 0;
int place_holder = 0;
CURLMsg* msg = nullptr;
AsyncReqMap ongoing_async_requests;
do {
// Check for new requests and add them to ongoing requests
{
std::lock_guard<std::mutex> lock(mutex_);

for (auto& pair : new_async_requests_) {
curl_multi_add_handle(
multi_handle_, reinterpret_cast<CURL*>(pair.first));
std::vector<std::shared_ptr<HttpInferRequest>> request_list;

ongoing_async_requests[pair.first] = std::move(pair.second);
// sleep if no work is available
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [this] {
if (this->exiting_) {
return true;
}
new_async_requests_.clear();
}

CURLMcode mc = curl_multi_perform(multi_handle_, &still_running);

if (mc != CURLM_OK) {
std::cerr << "Unexpected error: curl_multi failed. Code:" << mc
// wake up if an async request has been generated
return !this->ongoing_async_requests_.empty();
});

CURLMcode mc = curl_multi_perform(multi_handle_, &place_holder);
int numfds;
if (mc == CURLM_OK) {
// Wait for activity. If there are no descriptors in the multi_handle_
// then curl_multi_wait will return immediately
mc = curl_multi_wait(multi_handle_, NULL, 0, INT_MAX, &numfds);
if (mc == CURLM_OK) {
while ((msg = curl_multi_info_read(multi_handle_, &place_holder))) {
uintptr_t identifier = reinterpret_cast<uintptr_t>(msg->easy_handle);
auto itr = ongoing_async_requests_.find(identifier);
// This shouldn't happen
if (itr == ongoing_async_requests_.end()) {
std::cerr
<< "Unexpected error: received completed request that is not "
"in the list of asynchronous requests"
<< std::endl;
continue;
}
curl_multi_remove_handle(multi_handle_, msg->easy_handle);
curl_easy_cleanup(msg->easy_handle);
continue;
}

while ((msg = curl_multi_info_read(multi_handle_, &messages_in_queue))) {
if (msg->msg != CURLMSG_DONE) {
// Something wrong happened.
std::cerr << "Unexpected error: received CURLMsg=" << msg->msg
<< std::endl;
continue;
}
long http_code = 400;
if (msg->data.result == CURLE_OK) {
curl_easy_getinfo(
msg->easy_handle, CURLINFO_RESPONSE_CODE, &http_code);
} else if (msg->data.result == CURLE_OPERATION_TIMEDOUT) {
http_code = 499;
}

uintptr_t identifier = reinterpret_cast<uintptr_t>(msg->easy_handle);
auto itr = ongoing_async_requests.find(identifier);
// This shouldn't happen
if (itr == ongoing_async_requests.end()) {
std::cerr << "Unexpected error: received completed request that is not "
"in the list of asynchronous requests"
<< std::endl;
curl_multi_remove_handle(multi_handle_, msg->easy_handle);
curl_easy_cleanup(msg->easy_handle);
continue;
}
auto async_request = itr->second;

uint32_t http_code = 400;
if (msg->data.result == CURLE_OK) {
curl_easy_getinfo(msg->easy_handle, CURLINFO_RESPONSE_CODE, &http_code);
async_request->Timer().CaptureTimestamp(
RequestTimers::Kind::REQUEST_END);
Error err = UpdateInferStat(async_request->Timer());
if (!err.IsOk()) {
std::cerr << "Failed to update context stat: " << err << std::endl;
request_list.emplace_back(itr->second);
ongoing_async_requests_.erase(itr);
curl_multi_remove_handle(multi_handle_, msg->easy_handle);
curl_easy_cleanup(msg->easy_handle);

std::shared_ptr<HttpInferRequest> async_request = request_list.back();
async_request->http_code_ = http_code;

if (msg->msg != CURLMSG_DONE) {
// Something wrong happened.
std::cerr << "Unexpected error: received CURLMsg=" << msg->msg
<< std::endl;
} else {
async_request->Timer().CaptureTimestamp(
RequestTimers::Kind::REQUEST_END);
Error err = UpdateInferStat(async_request->Timer());
if (!err.IsOk()) {
std::cerr << "Failed to update context stat: " << err
<< std::endl;
}
}
}
} else if (msg->data.result == CURLE_OPERATION_TIMEDOUT) {
http_code = 499;
} else {
std::cerr << "Unexpected error: curl_multi failed. Code:" << mc
<< std::endl;
}

async_request->http_code_ = http_code;
InferResult* result;
InferResultHttp::Create(&result, async_request);
async_request->callback_(result);
ongoing_async_requests.erase(itr);
curl_multi_remove_handle(multi_handle_, msg->easy_handle);
curl_easy_cleanup(msg->easy_handle);
} else {
std::cerr << "Unexpected error: curl_multi failed. Code:" << mc
<< std::endl;
}
lock.unlock();

// Wait for activity on existing requests or
// explicit curl_multi_wakeup call
//
// If there are no descriptors in the multi_handle_
// then curl_multi_poll will wait until curl_multi_wakeup
// is called
//
// curl_multi_wakeup is called when adding a new request
// or exiting

mc = curl_multi_poll(multi_handle_, NULL, 0, INT_MAX, &numfds);
if (mc != CURLM_OK) {
std::cerr << "Unexpected error: curl_multi_poll failed. Code:" << mc
<< std::endl;
for (auto& this_request : request_list) {
InferResult* result;
InferResultHttp::Create(&result, this_request);
this_request->callback_(result);
}
} while (!exiting_);

for (auto& request : ongoing_async_requests) {
CURL* easy_handle = reinterpret_cast<CURL*>(request.first);
curl_multi_remove_handle(multi_handle_, easy_handle);
curl_easy_cleanup(easy_handle);
}

for (auto& request : new_async_requests_) {
CURL* easy_handle = reinterpret_cast<CURL*>(request.first);
curl_easy_cleanup(easy_handle);
}
}

size_t
Expand Down
4 changes: 2 additions & 2 deletions src/c++/library/http_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -643,9 +643,9 @@ class InferenceServerHttpClient : public InferenceServerClient {
void* easy_handle_;
// curl multi handle for processing asynchronous requests
void* multi_handle_;
// map to record new asynchronous requests with pointer to easy handle
// map to record ongoing asynchronous requests with pointer to easy handle
// or tag id as key
AsyncReqMap new_async_requests_;
AsyncReqMap ongoing_async_requests_;
};

}} // namespace triton::client
Loading
Loading