[Serving][Grammar] Integration of JSON schema generation (#2030)
The previous PR #1983 introduced a transformation from JSON schema
to BNF grammar.

This PR further integrates the JSON-schema-derived grammar into the
generation pipeline, so that the engine now supports JSON schema output.
GrammarStateInitContexts are stored in a cache, so they will not be
created again for the same schema.
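
Conceptually, the cache maps a schema string (or the no-schema plain-JSON case) to its precomputed grammar init context, so repeated requests with the same schema skip the expensive grammar compilation. Below is a minimal Python sketch of that idea under stated assumptions: `create_init_context` is a hypothetical stand-in for the C++ compilation step (`BNFGrammar::FromSchema` plus `GrammarStateMatcher::CreateInitContext`), and the actual cache in this PR is the C++ `GrammarInitContextStorage`, not this class.

```
# Conceptual sketch only; not the actual GrammarInitContextStorage implementation.
from typing import Dict, Optional


def create_init_context(schema: Optional[str], token_table: list) -> tuple:
    # Hypothetical placeholder for the expensive step:
    # JSON schema -> BNF grammar -> token-table-specific init context.
    return ("init_ctx", schema, len(token_table))


class GrammarInitContextCache:
    """Caches one init context per schema; key None means the built-in JSON grammar."""

    def __init__(self, token_table: list) -> None:
        self._token_table = token_table
        self._contexts: Dict[Optional[str], tuple] = {}

    def get(self, schema: Optional[str] = None) -> tuple:
        if schema not in self._contexts:
            self._contexts[schema] = create_init_context(schema, self._token_table)
        return self._contexts[schema]
```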

Interface:

- Python
```
from dataclasses import dataclass
from typing import Literal, Optional

@dataclass
class ResponseFormat:
    type: Literal["text", "json_object"] = "text"
    schema: Optional[str] = None
```
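
For illustration, a `ResponseFormat` can be constructed as below. The toy schema is only an example, and passing the result to the engine through the request's generation config (e.g., a `response_format` field on the generation config) is an assumption based on the C++ changes in this PR rather than a documented Python call.

```
import json

# A toy schema used only for illustration.
person_schema = json.dumps({
    "type": "object",
    "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
    "required": ["name", "age"],
})

# Unconstrained JSON output vs. schema-constrained output.
fmt_json = ResponseFormat(type="json_object")
fmt_schema = ResponseFormat(type="json_object", schema=person_schema)
```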

- Rest API
```
from typing import Literal, Optional

from pydantic import BaseModel, Field

class RequestResponseFormat(BaseModel):
    type: Literal["text", "json_object"] = "text"
    json_schema: Optional[str] = Field(default=None, alias="schema")

class CompletionRequest(BaseModel):
    ...
    response_format: RequestResponseFormat = Field(default_factory=RequestResponseFormat)

class ChatCompletionRequest(BaseModel):
    ...
    response_format: RequestResponseFormat = Field(default_factory=RequestResponseFormat)
```
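
As a usage sketch, a client could request schema-constrained output as below; the server address, port, and the OpenAI-style `/v1/chat/completions` route are assumptions for illustration, not taken from this PR.

```
import json

import requests  # third-party HTTP client, assumed to be installed

payload = {
    "model": "Llama-2-7b-chat-hf-q4f16_1",
    "messages": [{"role": "user", "content": "Describe a person as JSON."}],
    "response_format": {
        "type": "json_object",
        # "schema" is the request-side alias of json_schema in RequestResponseFormat.
        "schema": json.dumps({
            "type": "object",
            "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
        }),
    },
}

# Assumed local server endpoint; adjust host/port to your deployment.
response = requests.post("http://127.0.0.1:8000/v1/chat/completions", json=payload)
print(response.json())
```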

Performance:

We only test single-batch performance for now to show the overhead in latency.

- Model: `Llama-2-7b-chat-hf-q4f16_1`
- GPU: `NVIDIA GeForce RTX 3080`
- CPU: `AMD Ryzen 9 5900X 12-Core Processor`

```
JSON ON Batch=1
Average prefill tokens: 651.0000 tok/req
Average decode tokens: 499.0000 tok/req
Single token prefill latency: 0.3140 ms/tok
Single token decode latency: 8.6831 ms/tok
Prefill token throughput: 3184.8002 tok/s
Decode token throughput: 116.6039 tok/s

JSON OFF Batch=1
Average prefill tokens: 651.0000 tok/req
Average decode tokens: 499.0000 tok/req
Single token prefill latency: 0.3098 ms/tok
Single token decode latency: 8.6823 ms/tok
Prefill token throughput: 3227.8141 tok/s
Decode token throughput: 116.9251 tok/s
```

This PR also includes the following bug fixes / changes:
- Changed the structure of the grammar converted from a schema to avoid
a large number of uncertain tokens, which caused performance degradation
Ubospica authored Mar 27, 2024
1 parent a6d31d7 commit f2518ab
Showing 24 changed files with 734 additions and 139 deletions.
16 changes: 8 additions & 8 deletions cpp/serve/config.cc
@@ -144,12 +144,12 @@ GenerationConfig::GenerationConfig(String config_json_str) {
CHECK(response_format_json["type"].is<std::string>());
response_format.type = response_format_json["type"].get<std::string>();
}
if (response_format_json.count("json_schema")) {
if (response_format_json["json_schema"].is<picojson::null>()) {
response_format.json_schema = NullOpt;
if (response_format_json.count("schema")) {
if (response_format_json["schema"].is<picojson::null>()) {
response_format.schema = NullOpt;
} else {
CHECK(response_format_json["json_schema"].is<std::string>());
response_format.json_schema = response_format_json["json_schema"].get<std::string>();
CHECK(response_format_json["schema"].is<std::string>());
response_format.schema = response_format_json["schema"].get<std::string>();
}
}
n->response_format = response_format;
@@ -194,9 +194,9 @@ String GenerationConfigNode::AsJSONString() const {

picojson::object response_format;
response_format["type"] = picojson::value(this->response_format.type);
response_format["json_schema"] = this->response_format.json_schema
? picojson::value(this->response_format.json_schema.value())
: picojson::value();
response_format["schema"] = this->response_format.schema
? picojson::value(this->response_format.schema.value())
: picojson::value();
config["response_format"] = picojson::value(response_format);

return picojson::value(config).serialize(true);
2 changes: 1 addition & 1 deletion cpp/serve/config.h
@@ -21,7 +21,7 @@ using namespace tvm::runtime;
/*! \brief The response format of a request. */
struct ResponseFormat {
String type = "text";
Optional<String> json_schema = NullOpt;
Optional<String> schema = NullOpt;
};

/*! \brief The generation configuration of a request. */
28 changes: 22 additions & 6 deletions cpp/serve/engine.cc
@@ -12,6 +12,7 @@
#include <tvm/runtime/registry.h>
#include <tvm/runtime/threading_backend.h>

#include <optional>
#include <tuple>
#include <unordered_set>

@@ -61,8 +62,7 @@ class EngineImpl : public Engine {
this->trace_recorder_ = trace_recorder;
this->tokenizer_ = Tokenizer::FromPath(tokenizer_path);
this->token_table_ = tokenizer_->TokenTable();
this->json_grammar_state_init_ctx_ =
GrammarStateMatcher::CreateInitContext(BNFGrammar::GetGrammarOfJSON(), this->token_table_);
this->grammar_init_context_storage_ = GrammarInitContextStorage(this->token_table_);
// Step 2. Initialize each model independently.
// Create the logit processor and sampler.
this->models_.clear();
@@ -160,11 +160,13 @@ class EngineImpl : public Engine {

int n = request->generation_cfg->n;
int rng_seed = request->generation_cfg->seed;
auto grammar_state_init_ctx =
ResponseFormatToGrammarInitContext(request->generation_cfg->response_format);

std::vector<RequestStateEntry> rsentries;
// Create the request state entry for the input.
rsentries.emplace_back(request, models_.size(), estate_->id_manager.GetNewId(), rng_seed,
token_table_, json_grammar_state_init_ctx_);
token_table_, grammar_state_init_ctx);
if (n > 1) {
// Then create a request state entry for each parallel generation branch.
// We add a offset to the rng seed so that to make generations different.
@@ -173,7 +175,7 @@
for (int i = 0; i < n; ++i) {
rsentries[0]->child_indices.push_back(rsentries.size());
rsentries.emplace_back(request, models_.size(), estate_->id_manager.GetNewId(),
rng_seed + i + 1, token_table_, json_grammar_state_init_ctx_,
rng_seed + i + 1, token_table_, grammar_state_init_ctx,
/*parent_idx=*/0);
}
}
@@ -247,6 +249,20 @@ class EngineImpl : public Engine {
std::max(max_concurrency - host_cpu_usage, 1), kv_cache_config_->max_num_sequence));
}

/*! \brief Create a grammar init context according to the response format. If the response format
* is not JSON, return std::nullopt. */
std::optional<std::shared_ptr<GrammarStateInitContext>> ResponseFormatToGrammarInitContext(
const ResponseFormat& response_format) {
if (response_format.type != "json_object") {
return std::nullopt;
} else if (!response_format.schema) {
return grammar_init_context_storage_->GetInitContextForJSON();
} else {
return grammar_init_context_storage_->GetInitContextForJSONSchema(
response_format.schema.value());
}
}

// Engine state, managing requests and request states.
EngineState estate_;
// Configurations and singletons
@@ -255,8 +271,8 @@
int max_single_sequence_length_;
Tokenizer tokenizer_;
std::vector<std::string> token_table_;
// The initial context for the grammar state matching of JSON.
std::shared_ptr<GrammarStateInitContext> json_grammar_state_init_ctx_;
// Helper to get the grammar init context for requests.
GrammarInitContextStorage grammar_init_context_storage_;
// Models
Array<Model> models_;
// Workspace of each model.
2 changes: 2 additions & 0 deletions cpp/serve/engine_actions/action_commons.cc
@@ -72,6 +72,8 @@ void ProcessFinishedRequestStateEntries(std::vector<RequestStateEntry> finished_
for (const RequestStateEntry& entry : rstate->entries) {
estate->stats.total_decode_length += entry->mstates[0]->committed_tokens.size();
}
// For a request, the first token in committed_tokens is generated by prefilling
// and the rest are generated by decoding. So we subtract the first token.
estate->stats.total_decode_length -= rsentry->request->generation_cfg->n;
}
}
2 changes: 1 addition & 1 deletion cpp/serve/engine_actions/batch_decode.cc
@@ -83,7 +83,7 @@ class BatchDecodeActionObj : public EngineActionObj {
// - Compute embeddings.
RECORD_EVENT(trace_recorder_, request_ids, "start embedding");
ObjectRef embeddings =
models_[0]->TokenEmbed({IntTuple{input_tokens.begin(), input_tokens.end()}});
models_[0]->TokenEmbed({IntTuple(input_tokens.begin(), input_tokens.end())});
RECORD_EVENT(trace_recorder_, request_ids, "finish embedding");

// - Invoke model decode.
28 changes: 28 additions & 0 deletions cpp/serve/grammar/grammar.cc
@@ -42,6 +42,34 @@ TVM_REGISTER_GLOBAL("mlc.serve.BNFGrammarFromJSON").set_body_typed([](String jso
return BNFGrammar::FromJSON(json_string);
});

BNFGrammar BNFGrammar::FromSchema(const String& schema, int indent,
Optional<Array<String>> separators, bool strict_mode) {
static const PackedFunc* json_schema_to_ebnf = Registry::Get("mlc.serve.json_schema_to_ebnf");
CHECK(json_schema_to_ebnf != nullptr) << "mlc.serve.json_schema_to_ebnf is not registered.";

String ebnf_string;

// Convert the indent parameter to NullOpt for sending it to the PackedFunc.
if (indent == -1) {
// The conversion from TVMRetValue to String is ambiguous, so we call the conversion function
// explicitly
ebnf_string =
((*json_schema_to_ebnf)(schema, Optional<ObjectRef>(NullOpt), separators, strict_mode)
.
operator String());
} else {
ebnf_string = (*json_schema_to_ebnf)(schema, indent, separators, strict_mode).operator String();
;
}
return FromEBNFString(ebnf_string);
}

TVM_REGISTER_GLOBAL("mlc.serve.BNFGrammarFromSchema")
.set_body_typed([](const String& schema, int indent, Optional<Array<String>> separators,
bool strict_mode) {
return BNFGrammar::FromSchema(schema, indent, separators, strict_mode);
});

const std::string kJSONGrammarString = R"(
main ::= (
"{" ws members_or_embrace |
23 changes: 21 additions & 2 deletions cpp/serve/grammar/grammar.h
@@ -18,6 +18,7 @@ namespace mlc {
namespace llm {
namespace serve {

using namespace tvm;
using namespace tvm::runtime;

/*!
@@ -182,7 +183,7 @@ class BNFGrammar : public ObjectRef {
* \param simplify Whether to simplify the grammar to make matching more efficient. Default: true.
* Not implemented yet.
*/
static BNFGrammar FromEBNFString(const String& ebnf_string, const String& main_rule,
static BNFGrammar FromEBNFString(const String& ebnf_string, const String& main_rule = "main",
bool normalize = true, bool simplify = true);

/*!
@@ -192,7 +193,25 @@
*/
static BNFGrammar FromJSON(const String& json_string);

/*
/*!
* \brief Construct a BNF grammar from the json schema string. The schema string should be in the
* format of the schema of a JSON file. We will parse the schema and generate a BNF grammar.
* \param schema The schema string.
* \param indent The number of spaces for indentation. If -1, the output will be in one line.
* Default: -1.
* \param separators Two separators used in the schema: comma and colon. Examples: {",", ":"},
* {", ", ": "}. If NullOpt, the default separators will be used: {",", ": "} when the indent
* is not -1, and {", ", ": "} otherwise. Default: NullOpt.
* \param strict_mode Whether to use strict mode. In strict mode, the generated grammar will not
* allow unevaluatedProperties and unevaluatedItems, i.e. these will be set to false by default.
* This helps LLM to generate accurate output in the grammar-guided generation with JSON
* schema. Default: true.
*/
static BNFGrammar FromSchema(const String& schema, int indent = -1,
Optional<Array<String>> separators = NullOpt,
bool strict_mode = true);

/*!
* \brief Get the grammar of standard JSON format. We have built-in support for JSON.
*/
static BNFGrammar GetGrammarOfJSON();
98 changes: 80 additions & 18 deletions cpp/serve/grammar/grammar_state_matcher.cc
@@ -40,7 +40,7 @@ namespace serve {
* elements at the end may be popped out, and the last element of the stack will be advanced.
*
* One stack may split since there may be multiple possible next positions. In this case, similar
* stacks with different top elements will be added. When ome stack cannot accept the new character,
* stacks with different top elements will be added. When one stack cannot accept the new character,
* it will be removed from the stacks.
*
* ## Storage of Stacks (see grammar_state_matcher_state.h)
Expand All @@ -59,7 +59,7 @@ namespace serve {
* S ::= "" | [c] [d]
* T ::= [e]
*
* ### Previous step
* ### The previous step
* Previous accepted string: ab
* Previous stack tree:
* A------
Expand All @@ -76,7 +76,7 @@ namespace serve {
* < means the stack top pointers in the previous step.
* The stacks in the previous step is: (A, B, C), (A, D), (A, E)
*
* ### Current step
* ### The current step
* Current accepted string: abc
* Current stack tree:
* A----------------- G<<
@@ -87,7 +87,7 @@
*
* F: (rule S, choice 1, element 1)
* G: (rule main, choice 0, element 2) (means the matching process has finished, and will be deleted
* when next char comes)
* when the next char comes)
* H: (rule R, choice 1, element 2)
* I: (rule T, choice 0, element 0)
* << means the stack top pointers in the current step.
@@ -175,7 +175,7 @@ class GrammarStateMatcherNodeImpl : public GrammarStateMatcherNode, public Gramm
*/
bool AcceptStopToken();

friend IntTuple FindNextRejectedTokens(GrammarStateMatcher matcher);
friend IntTuple FindNextRejectedTokens(GrammarStateMatcher matcher, bool verbose);

std::shared_ptr<GrammarStateInitContext> init_ctx_;
int max_rollback_steps_;
@@ -381,12 +381,12 @@ void GrammarStateMatcherNodeImpl::SetTokenBitmask(DLTensor* next_token_bitmask,
<< "The provied bitmask's shape or dtype is not valid.";

BitsetManager next_token_bitset(reinterpret_cast<uint32_t*>(next_token_bitmask->data),
next_token_bitmask->shape[0]);
next_token_bitmask->shape[0], init_ctx_->vocab_size);

if (rejected_indices.size() == 1 && rejected_indices[0] == -1) {
// If rejected_indices is the universal set, the final accepted token set is just
// accepted_indices
next_token_bitset.Reset(init_ctx_->vocab_size, false);
next_token_bitset.Reset(false);
for (int idx : accepted_indices) {
next_token_bitset.Set(init_ctx_->sorted_token_codepoints[idx].id, true);
}
@@ -399,7 +399,7 @@
}
} else {
// Otherwise, the final rejected token set is (rejected_indices \ accepted_indices)
next_token_bitset.Reset(init_ctx_->vocab_size, true);
next_token_bitset.Reset(true);

auto it_acc = accepted_indices.begin();
for (auto i : rejected_indices) {
@@ -524,34 +524,96 @@ TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherDebugMatchCompleteString")
return MatchCompleteString(matcher, str);
});

/*! \brief Print the accepted and rejected tokens stored in the bitset. For debug purposes. */
void PrintAcceptedRejectedTokens(
const std::shared_ptr<mlc::llm::serve::GrammarStateInitContext>& init_ctx,
const BitsetManager& bitset, int threshold = 500) {
auto vocab_size = init_ctx->vocab_size;
std::vector<int64_t> accepted_ids;
std::vector<int64_t> rejected_ids;
for (int i = 0; i < vocab_size; i++) {
if (bitset[i]) {
accepted_ids.push_back(i);
} else {
rejected_ids.push_back(i);
}
}

if (accepted_ids.size() < threshold) {
std::cerr << "Accepted: ";
for (auto id : accepted_ids) {
std::cerr << "<";
auto token = init_ctx->token_table[id];
if (token.size() == 1 && (static_cast<unsigned char>(token[0]) >= 128 || token[0] == 0)) {
// First cast to unsigned, then cast to int
std::cerr << static_cast<int>(static_cast<unsigned char>(token[0]));
} else {
auto codepoints = Utf8StringToCodepoints(token.c_str());
for (auto c : codepoints) {
std::cerr << CodepointToPrintable(c);
}
}
std::cerr << "> ";
}
std::cerr << "\n";
}

if (rejected_ids.size() < threshold) {
std::cerr << "Rejected: ";
for (auto id : rejected_ids) {
std::cerr << "<";
auto token = init_ctx->token_table[id];
if (token.size() == 1 && ((unsigned char)token[0] >= 128 || token[0] == 0)) {
std::cerr << (int)(unsigned char)token[0];
} else {
auto codepoints = Utf8StringToCodepoints(token.c_str());
for (auto c : codepoints) {
std::cerr << CodepointToPrintable(c);
}
}
std::cerr << "> ";
}
std::cerr << "\n";
}
}

/*!
* \brief Find the ids of the rejected tokens for the next step. For test purposes.
* \brief Find the ids of the rejected tokens for the next step. For debug purposes.
* \param matcher The matcher to test.
* \param verbose Whether to print information about the timing and results to stderr.
* \returns A tuple of rejected token ids.
*/
IntTuple FindNextRejectedTokens(GrammarStateMatcher matcher) {
IntTuple FindNextRejectedTokens(GrammarStateMatcher matcher, bool verbose = false) {
auto init_ctx = matcher.as<GrammarStateMatcherNodeImpl>()->init_ctx_;
auto vocab_size = init_ctx->vocab_size;
auto bitset_size = BitsetManager::GetBitsetSize(vocab_size);
auto bitset_size = BitsetManager::CalculateBufferSize(vocab_size);
auto ndarray = NDArray::Empty(ShapeTuple{static_cast<long>(bitset_size)},
DLDataType{kDLUInt, 32, 1}, DLDevice{kDLCPU, 0});
auto dltensor = const_cast<DLTensor*>(ndarray.operator->());

auto start = std::chrono::high_resolution_clock::now();
std::chrono::time_point<std::chrono::high_resolution_clock> start, end;
if (verbose) {
start = std::chrono::high_resolution_clock::now();
}
matcher->FindNextTokenBitmask(dltensor);
auto end = std::chrono::high_resolution_clock::now();
std::cerr << "FindNextTokenBitmask takes "
<< std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() << "us";
if (verbose) {
end = std::chrono::high_resolution_clock::now();
}

auto bitset = BitsetManager(reinterpret_cast<uint32_t*>(dltensor->data), bitset_size);
auto bitset = BitsetManager(reinterpret_cast<uint32_t*>(dltensor->data), bitset_size, vocab_size);
std::vector<int64_t> rejected_ids;
for (int i = 0; i < vocab_size; i++) {
if (bitset[i] == 0) {
rejected_ids.push_back(i);
}
}

std::cerr << ", found accepted: " << vocab_size - rejected_ids.size()
<< ", rejected: " << rejected_ids.size() << std::endl;
if (verbose) {
std::cerr << "FindNextTokenBitmask takes "
<< std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() << "us"
<< ", found accepted: " << vocab_size - rejected_ids.size()
<< ", rejected: " << rejected_ids.size() << std::endl;
}

auto ret = IntTuple(rejected_ids);
return ret;