Skip to content

Commit

Permalink
[Serve] MicroServing Implementation (#3064)
Browse files Browse the repository at this point in the history
This PR introduces the MicroServing API.

MicroServing introduces simple yet effective REST APIs to support
fine-grained sub-request level actions.
A programmable router transforms user requests into sub-request calls,
lifting fine-grained scheduling to the API level, thus enabling the dynamic
reconfiguration of different orchestration patterns.

Co-authored-by: Ruihang Lai <[email protected]>
Co-authored-by: Charlie Ruan <[email protected]>
Co-authored-by: Yingcheng Wang <[email protected]>
  • Loading branch information
4 people authored Dec 16, 2024
1 parent 49dcd4a commit 5c9ebcb
Show file tree
Hide file tree
Showing 49 changed files with 2,497 additions and 198 deletions.
1 change: 1 addition & 0 deletions cpp/metadata/model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ ModelMetadata ModelMetadata::FromJSON(const picojson::object& metadata,
result.tensor_parallel_shards = json::Lookup<int64_t>(metadata, "tensor_parallel_shards");
result.pipeline_parallel_stages =
json::LookupOrDefault<int64_t>(metadata, "pipeline_parallel_stages", 1);
result.disaggregation = json::LookupOrDefault<bool>(metadata, "disaggregation", false);
result.kv_state_kind = KVStateKindFromString(
json::LookupOrDefault<std::string>(metadata, "kv_state_kind", "kv_cache"));
if (result.kv_state_kind != KVStateKind::kNone &&
Expand Down
1 change: 1 addition & 0 deletions cpp/metadata/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ struct ModelMetadata {
int64_t sliding_window_size;
int64_t tensor_parallel_shards;
int64_t pipeline_parallel_stages;
bool disaggregation;
int64_t attention_sink_size;
std::vector<Param> params;
std::unordered_map<std::string, int64_t> memory_usage;
Expand Down
113 changes: 111 additions & 2 deletions cpp/serve/config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#include "../json_ffi/openai_api_protocol.h"
#include "../support/json_parser.h"
#include "../support/utils.h"
#include "data.h"

namespace mlc {
Expand Down Expand Up @@ -62,6 +63,105 @@ picojson::object ResponseFormat::AsJSON() const {
return config;
}

/****************** DisaggConfig ******************/

/*!
 * \brief Parse a disaggregation config from its JSON object form.
 *
 * Recognized keys (all optional):
 *  - "kind": one of "prepare_prefill", "remote_prefill", "start_decode".
 *  - "kv_append_metadata": a base64-encoded JSON array of integers. The array is a
 *    concatenation of groups, each laid out as
 *    [num_segments, v_1, ..., v_{2 * num_segments}]; every group becomes one IntTuple.
 *  - "kv_window_begin", "kv_window_end", "dst_group_offset": integers.
 *
 * \param config The JSON object to parse.
 * \return The parsed config, or an error result describing the first problem found
 * (unknown kind, undecodable or malformed kv_append_metadata).
 */
Result<DisaggConfig> DisaggConfig::FromJSON(const picojson::object& config) {
  using TResult = Result<DisaggConfig>;
  DisaggConfig res;
  std::optional<std::string> kind = json::LookupOptional<std::string>(config, "kind");
  if (kind.has_value()) {
    if (kind.value() == "prepare_prefill") {
      res.kind = DisaggRequestKind::kPreparePrefill;
    } else if (kind.value() == "remote_prefill") {
      res.kind = DisaggRequestKind::kRemotePrefill;
    } else if (kind.value() == "start_decode") {
      res.kind = DisaggRequestKind::kStartDecode;
    } else {
      return TResult::Error("Unknown disaggregation request kind " + kind.value());
    }
  }
  std::optional<std::string> kv_append_metadata_encoded =
      json::LookupOptional<std::string>(config, "kv_append_metadata");
  if (kv_append_metadata_encoded.has_value()) {
    picojson::value parse_result;
    std::string err =
        picojson::parse(parse_result, Base64Decode(kv_append_metadata_encoded.value()));
    if (!err.empty()) {
      return TResult::Error("kv_append_metadata parse error: " + err);
    }
    if (!parse_result.is<picojson::array>()) {
      return TResult::Error("kv_append_metadata is not array of integer.");
    }
    // Bind by reference: the parsed array is only read, so there is no need to copy it.
    const picojson::array& kv_append_metadata_arr = parse_result.get<picojson::array>();
    const size_t arr_size = kv_append_metadata_arr.size();
    std::vector<IntTuple> kv_append_metadata;
    size_t ptr = 0;
    while (ptr < arr_size) {
      if (!kv_append_metadata_arr[ptr].is<int64_t>()) {
        return TResult::Error("Invalid kv append metadata value in kv_append_metadata array");
      }
      const int64_t num_segments = kv_append_metadata_arr[ptr].get<int64_t>();
      // The group needs 1 + 2 * num_segments entries starting at "ptr". The count comes
      // from the request, so reject negative values explicitly and phrase the bound as a
      // division so that a huge count cannot overflow the arithmetic.
      if (num_segments < 0 ||
          static_cast<uint64_t>(num_segments) > (arr_size - ptr - 1) / 2) {
        return TResult::Error("Invalid kv append metadata compression in kv_append_metadata");
      }
      const size_t group_len = static_cast<size_t>(num_segments) * 2 + 1;
      std::vector<int64_t> compressed_kv_append_metadata;
      compressed_kv_append_metadata.reserve(group_len);
      compressed_kv_append_metadata.push_back(num_segments);
      for (size_t i = 1; i < group_len; ++i) {
        if (!kv_append_metadata_arr[ptr + i].is<int64_t>()) {
          return TResult::Error("Invalid kv append metadata value in kv_append_metadata array");
        }
        compressed_kv_append_metadata.push_back(kv_append_metadata_arr[ptr + i].get<int64_t>());
      }
      kv_append_metadata.push_back(IntTuple(std::move(compressed_kv_append_metadata)));
      ptr += group_len;
    }
    res.kv_append_metadata = std::move(kv_append_metadata);
  }
  res.kv_window_begin = json::LookupOptional<int64_t>(config, "kv_window_begin");
  res.kv_window_end = json::LookupOptional<int64_t>(config, "kv_window_end");
  res.dst_group_offset = json::LookupOptional<int64_t>(config, "dst_group_offset");
  return TResult::Ok(res);
}

/*!
 * \brief Serialize this disaggregation config into a JSON object.
 *
 * A field is emitted only when it is set: "kind" is omitted unless it is one of the three
 * concrete request kinds, "kv_append_metadata" is omitted when empty, and each optional
 * integer is omitted when it holds no value.
 *
 * \return The JSON object form of this config.
 */
picojson::object DisaggConfig::AsJSON() const {
  picojson::object json_obj;

  // Map the request kind to its wire string; any other value leaves "kind" out.
  if (kind == DisaggRequestKind::kPreparePrefill) {
    json_obj["kind"] = picojson::value("prepare_prefill");
  } else if (kind == DisaggRequestKind::kRemotePrefill) {
    json_obj["kind"] = picojson::value("remote_prefill");
  } else if (kind == DisaggRequestKind::kStartDecode) {
    json_obj["kind"] = picojson::value("start_decode");
  }

  // Flatten all tuples into one integer array, then ship it as base64-encoded JSON text.
  if (!kv_append_metadata.empty()) {
    picojson::array flattened;
    for (const IntTuple& tuple : kv_append_metadata) {
      for (int64_t entry : tuple) {
        flattened.push_back(picojson::value(entry));
      }
    }
    const std::string serialized = picojson::value(flattened).serialize();
    json_obj["kv_append_metadata"] = picojson::value(Base64Encode(serialized));
  }

  // Write an optional integer field under "key" when it holds a value.
  auto put_if_set = [&json_obj](const char* key, const std::optional<int>& field) {
    if (field.has_value()) {
      json_obj[key] = picojson::value(static_cast<int64_t>(field.value()));
    }
  };
  put_if_set("kv_window_begin", kv_window_begin);
  put_if_set("kv_window_end", kv_window_end);
  put_if_set("dst_group_offset", dst_group_offset);

  return json_obj;
}

/****************** DebugConfig ******************/

Result<DebugConfig> DebugConfig::FromJSON(const picojson::object& config) {
Expand All @@ -74,7 +174,7 @@ Result<DebugConfig> DebugConfig::FromJSON(const picojson::object& config) {
if (special_request == "query_engine_metrics") {
res.special_request = SpecialRequestKind::kQueryEngineMetrics;
} else {
return TResult::Error("Uknown special request " + special_request);
return TResult::Error("Unknown special request " + special_request);
}
}
std::string grammar_execution_mode =
Expand All @@ -84,8 +184,14 @@ Result<DebugConfig> DebugConfig::FromJSON(const picojson::object& config) {
} else if (grammar_execution_mode == "constraint") {
res.grammar_execution_mode = GrammarExecutionMode::kConstraint;
} else {
return TResult::Error("Uknown grammar execution mode " + grammar_execution_mode);
return TResult::Error("Unknown grammar execution mode " + grammar_execution_mode);
}
Result<DisaggConfig> disagg_config =
DisaggConfig::FromJSON(json::Lookup<picojson::object>(config, "disagg_config"));
if (disagg_config.IsErr()) {
return TResult::Error(disagg_config.UnwrapErr());
}
res.disagg_config = disagg_config.Unwrap();
return TResult::Ok(res);
}

Expand Down Expand Up @@ -114,6 +220,9 @@ picojson::object DebugConfig::AsJSON() const {
break;
}
}
if (disagg_config.kind != DisaggRequestKind::kNone) {
config["disagg_config"] = picojson::value(disagg_config.AsJSON());
}
return config;
}

Expand Down
30 changes: 30 additions & 0 deletions cpp/serve/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ enum class SpecialRequestKind : int {
kQueryEngineMetrics = 1,
};

enum class DisaggRequestKind : int {
kNone = 0,
kPreparePrefill = 1,
kRemotePrefill = 2,
kStartDecode = 3,
};

/*! \brief Controls the behavior of inference with grammar constraint. */
enum class GrammarExecutionMode : int {
/*! \brief If grammar is provided for a request, use the grammar to constrain the output token. */
Expand All @@ -55,6 +62,28 @@ enum class GrammarExecutionMode : int {
kJumpForward = 1,
};

/*! \brief The config for disaggregation requests. */
class DisaggConfig {
public:
/*! \brief The kind of the disaggregation request. kNone means "not a disaggregation request". */
DisaggRequestKind kind = DisaggRequestKind::kNone;
/*!
 * \brief The KV append metadata, one IntTuple per group.
 * Each tuple is laid out as [num_segments, followed by 2 * num_segments values].
 * In JSON this field is a base64-encoded JSON array holding all tuples concatenated.
 * NOTE(review): the meaning of the per-segment value pairs is not visible here -- confirm.
 */
std::vector<IntTuple> kv_append_metadata;
// "kv_window_begin" and "kv_window_end" denote the KV interval of interest.
// "kv_window_end" supports Python style negative indexing.
// The concrete meaning varies for different disaggregation request kinds:
// - For "prepare_prefill", the begin is always 0, and "[0:end]" denotes
// the KV range to prefill on a prefill instance.
// - For "remote_prefill", "[begin:end]" means the KV range to compute prefill
// and send to the decode instance.
// - For "start_decode", the end is always nullopt, and "[begin:]" denotes
// the KV range to prefill locally on the decode instance.
std::optional<int> kv_window_begin = std::nullopt;
std::optional<int> kv_window_end = std::nullopt;
// Optional integer carried under the JSON key "dst_group_offset".
// NOTE(review): its exact semantics are not visible in this file -- confirm with the users.
std::optional<int> dst_group_offset = std::nullopt;

/*!
 * \brief Create a disaggregation config from JSON.
 * \param config_json The JSON object to parse.
 * \return The parsed config, or an error result when parsing fails.
 */
static Result<DisaggConfig> FromJSON(const picojson::object& config_json);
/*!
 * \brief Convert this config to a JSON object.
 * \return The JSON object form of this config.
 */
picojson::object AsJSON() const;
};

/*! \brief The debug configuration of a request. */
class DebugConfig {
public:
Expand All @@ -63,6 +92,7 @@ class DebugConfig {
SpecialRequestKind special_request = SpecialRequestKind::kNone;
/*! \brief The grammar execution mode. */
GrammarExecutionMode grammar_execution_mode = GrammarExecutionMode::kJumpForward;
DisaggConfig disagg_config;

/*!
* \brief Create debug config from JSON.
Expand Down
39 changes: 39 additions & 0 deletions cpp/serve/data.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,45 @@ namespace serve {

TVM_REGISTER_OBJECT_TYPE(DataNode);

/*!
 * \brief Split the given data array into two arrays at the "split_pos" position.
 *
 * Data items are moved from the back of the input to the right-hand side until the
 * left-hand side has length "split_pos". When the split point falls inside an item, that
 * item is cut into two TokenData pieces; only TokenData supports such a partial split.
 *
 * \param original_data The data array to split.
 * \param total_length The total length of "original_data"; must be >= split_pos.
 * \param split_pos The length of the left-hand side after the split; must be >= 0.
 * \return A pair (lhs, rhs): lhs covers the first "split_pos" positions and rhs the
 * remainder, both in their original order.
 */
std::pair<Array<Data>, Array<Data>> SplitData(const Array<Data>& original_data, int total_length,
                                              int split_pos) {
  CHECK_GE(split_pos, 0);
  CHECK_GE(total_length, split_pos)
      << "Cannot truncate when the current length is already less than the target length";
  std::vector<Data> lhs(original_data.begin(), original_data.end());
  // Collected back-to-front while peeling items off "lhs"; reversed once at the end.
  std::vector<Data> rhs;
  while (total_length > split_pos) {
    ICHECK(!lhs.empty());
    Data last_data = lhs.back();
    int last_data_length = last_data->GetLength();
    ICHECK_GE(total_length - last_data_length, 0);
    if (total_length - last_data_length >= split_pos) {
      // Pop the entire last data.
      rhs.push_back(lhs.back());
      lhs.pop_back();
      total_length -= last_data_length;
      continue;
    }
    // Partially truncate the last data.
    const auto* token_data = last_data.as<TokenDataNode>();
    CHECK(token_data != nullptr) << "Only TokenData supports partial truncation.";
    int length_to_truncate = total_length - split_pos;
    CHECK_GT(length_to_truncate, 0);
    CHECK_LT(length_to_truncate, last_data_length);
    TokenData lhs_token_data(
        IntTuple{token_data->token_ids.begin(), token_data->token_ids.end() - length_to_truncate});
    TokenData rhs_token_data(
        IntTuple{token_data->token_ids.end() - length_to_truncate, token_data->token_ids.end()});
    CHECK_EQ(total_length - last_data_length + lhs_token_data->GetLength(), split_pos);
    lhs.pop_back();
    lhs.push_back(lhs_token_data);
    rhs.push_back(rhs_token_data);
    total_length = split_pos;
  }
  // Restore the original order. This must happen on every exit path of the loop, including
  // the case where the split lands exactly on an item boundary and no partial split occurs.
  std::reverse(rhs.begin(), rhs.end());
  return {lhs, rhs};
}

/****************** TextData ******************/

TVM_REGISTER_OBJECT_TYPE(TextDataNode);
Expand Down
4 changes: 4 additions & 0 deletions cpp/serve/data.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ class Data : public ObjectRef {
TVM_DEFINE_OBJECT_REF_METHODS(Data, ObjectRef, DataNode);
};

/*!
 * \brief Split the given data array into two arrays at the "split_pos" position.
 * \param original_data The data array to split.
 * \param total_length The total length of the data in "original_data". It must not be
 * smaller than "split_pos".
 * \param split_pos The position to split at, i.e. the length of the first returned array.
 * It must be non-negative.
 * \return The pair of data arrays before and after the split position.
 * \note When the split position falls inside a data item, that item must be a TokenData,
 * since only TokenData supports partial truncation.
 */
std::pair<Array<Data>, Array<Data>> SplitData(const Array<Data>& original_data, int total_length,
int split_pos);

/****************** TextDataNode ******************/

/*! \brief The class of text data, containing a text string. */
Expand Down
Loading

0 comments on commit 5c9ebcb

Please sign in to comment.