add prefix cache stats to usage
ispobock committed Jul 24, 2024 · 1 parent 12b59b4 · commit bdb3082
Showing 10 changed files with 49 additions and 15 deletions.
2 changes: 2 additions & 0 deletions lmdeploy/messages.py
@@ -281,13 +281,15 @@ class EngineOutput:
token_ids (List[int]): the output token ids.
num_token (int): the number of output tokens; for turbomind, num_token
may not equal the length of token_ids
num_prefix_cached_token (int): the number of prefix-cached tokens
logprobs (List[Dict[int, float]]): the top logprobs for each output
position.
"""
status: ResponseType
token_ids: List[int]
num_token: int
logprobs: List[Dict[int, float]] = None
num_prefix_cached_token: int = None


@dataclass
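The added num_prefix_cached_token field defaults to None, so existing callers of EngineOutput are unaffected. A minimal, self-contained sketch of how the extended output can be read (stand-in class and values, not the real lmdeploy imports):

```python
# Sketch only: mirrors the fields of EngineOutput after this commit, with a
# plain string standing in for ResponseType.
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class EngineOutputSketch:
    status: str
    token_ids: List[int]
    num_token: int
    logprobs: Optional[List[Dict[int, float]]] = None
    num_prefix_cached_token: Optional[int] = None  # added by this commit


out = EngineOutputSketch(status='SUCCESS', token_ids=[11, 12, 13], num_token=3,
                         num_prefix_cached_token=256)
if out.num_prefix_cached_token is not None:
    print(f'{out.num_prefix_cached_token} prompt tokens came from the prefix cache')
```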
5 changes: 4 additions & 1 deletion lmdeploy/serve/async_engine.py
@@ -69,6 +69,7 @@ class GenOut:
finish_reason: Optional[Literal['stop', 'length']] = None
token_ids: List[int] = None
logprobs: List[Dict[int, float]] = None
prefix_cached_token_len: int = None


class Session:
@@ -616,6 +617,7 @@ async def generate(
state = DetokenizeState(len(input_ids))
start_ids_offset = state.ids_offset
response = ''
prefix_cached_token_len = None
async for outputs in generator.async_stream_infer(
session_id=session_id,
**prompt_input,
@@ -627,6 +629,7 @@
step=self.id2step[str(session_id)]):
# decode res
res, tokens = input_ids + outputs.token_ids, outputs.num_token # noqa
prefix_cached_token_len = outputs.num_prefix_cached_token
if len(res) <= state.ids_offset:
continue

@@ -655,7 +658,7 @@ async def generate(
if not response.endswith('�'):
response = '' # avoid returning the last response twice
yield GenOut(response, self.id2step[str(session_id)],
len(input_ids), tokens, finish_reason)
len(input_ids), tokens, finish_reason, prefix_cached_token_len=prefix_cached_token_len)
# update step
self.id2step[str(session_id)] += len(input_ids) + tokens
if sequence_end:
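The generate() loop above captures outputs.num_prefix_cached_token from each streamed chunk and forwards it on every GenOut as prefix_cached_token_len. A runnable sketch of that forwarding pattern with stand-in classes (not the real lmdeploy types):

```python
# Sketch only: stand-in classes showing how the streaming loop threads the
# engine's num_prefix_cached_token into each GenOut as prefix_cached_token_len.
import asyncio
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeEngineOutput:            # stand-in for EngineOutput
    num_token: int
    num_prefix_cached_token: Optional[int] = None


@dataclass
class FakeGenOut:                  # stand-in for GenOut
    response: str
    prefix_cached_token_len: Optional[int] = None


async def fake_engine_stream():
    for out in (FakeEngineOutput(1, 256), FakeEngineOutput(2, 256)):
        yield out


async def generate():
    prefix_cached_token_len = None
    async for outputs in fake_engine_stream():
        prefix_cached_token_len = outputs.num_prefix_cached_token
        yield FakeGenOut('...', prefix_cached_token_len=prefix_cached_token_len)


async def main():
    async for gen_out in generate():
        print(gen_out.prefix_cached_token_len)   # 256 on every chunk


asyncio.run(main())
```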
6 changes: 6 additions & 0 deletions lmdeploy/serve/openai/api_server.py
@@ -357,6 +357,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
prompt_tokens=final_res.input_token_len,
completion_tokens=final_res.generate_token_len,
total_tokens=total_tokens,
prefix_cached_tokens=final_res.prefix_cached_token_len
)
response = ChatCompletionResponse(
id=request_id,
@@ -590,6 +591,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
prompt_tokens=final_res.input_token_len,
completion_tokens=final_res.generate_token_len,
total_tokens=total_tokens,
prefix_cached_tokens=final_res.prefix_cached_token_len,
)
response = ChatCompletionResponse(
id=request_id,
@@ -693,6 +695,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
prompt_tokens=final_res.input_token_len,
completion_tokens=final_res.generate_token_len,
total_tokens=total_tokens,
prefix_cached_tokens=final_res.prefix_cached_token_len,
)
response_json = create_stream_response_json(
index=0,
@@ -738,6 +741,7 @@ async def _inner_call(i, generator):
usage.prompt_tokens += final_res.input_token_len
usage.completion_tokens += final_res.generate_token_len
usage.total_tokens += total_tokens
usage.prefix_cached_tokens = final_res.prefix_cached_token_len

await asyncio.gather(
*[_inner_call(i, generators[i]) for i in range(len(generators))])
@@ -882,6 +886,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
prompt_tokens=final_res.input_token_len,
completion_tokens=final_res.generate_token_len,
total_tokens=total_tokens,
prefix_cached_tokens=final_res.prefix_cached_token_len,
)
response_json = create_stream_response_json(
index=0,
@@ -942,6 +947,7 @@ async def _inner_call(i, generator):
usage.prompt_tokens += final_res.input_token_len
usage.completion_tokens += final_res.generate_token_len
usage.total_tokens += total_tokens
usage.prefix_cached_tokens = final_res.prefix_cached_token_len

await asyncio.gather(
*[_inner_call(i, generators[i]) for i in range(len(generators))])
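With the usage block extended in all four code paths above, any OpenAI-style client can read the new counter. An illustrative snippet, assuming an lmdeploy api_server is already running locally with prefix caching enabled; the URL, port, and model name are placeholders:

```python
# Illustrative client call against an OpenAI-compatible endpoint; the URL,
# port, and model name are placeholders, not values defined by this commit.
import requests

resp = requests.post(
    'http://localhost:23333/v1/chat/completions',
    json={
        'model': 'placeholder-model',
        'messages': [{'role': 'user', 'content': 'Hello'}],
    },
    timeout=60,
)
usage = resp.json()['usage']
# prefix_cached_tokens is null unless the engine reports a prefix-cache count
print(usage.get('prefix_cached_tokens'))
```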
1 change: 1 addition & 0 deletions lmdeploy/serve/openai/protocol.py
@@ -53,6 +53,7 @@ class UsageInfo(BaseModel):
prompt_tokens: int = 0
total_tokens: int = 0
completion_tokens: Optional[int] = 0
prefix_cached_tokens: Optional[int] = None


class ChatCompletionRequestQos(BaseModel):
8 changes: 8 additions & 0 deletions lmdeploy/turbomind/turbomind.py
@@ -724,6 +724,7 @@ async def async_stream_infer(self,

output_ids = outputs['output_ids'][:, 0, :]
sequence_length = outputs['sequence_length'].long()[:, 0]
prefix_cache_length = outputs['prefix_cache_length'].long()[:, 0].item()
output_ids = [
output_id[s:l] for output_id, s, l in zip(
output_ids, seq_start, sequence_length)
@@ -763,6 +764,9 @@ async def async_stream_infer(self,
if out_logprobs:
output_token_len = len(outputs.token_ids)
outputs.logprobs = out_logprobs[:output_token_len]

if self.tm_model.config.enable_prefix_caching:
outputs.num_prefix_cached_token = prefix_cache_length

yield outputs

@@ -839,6 +843,7 @@ def stream_infer(self,

output_ids = outputs['output_ids'][:, 0, :]
sequence_length = outputs['sequence_length'].long()[:, 0]
prefix_cache_length = outputs['prefix_cache_length'].long()[:, 0].item()
output_ids = [
output_id[s:l] for output_id, s, l in zip(
output_ids, seq_start, sequence_length)
@@ -876,6 +881,9 @@ def stream_infer(self,
if out_logprobs:
output_token_len = len(outputs.token_ids)
outputs.logprobs = out_logprobs[:output_token_len]

if self.tm_model.config.enable_prefix_caching:
outputs.num_prefix_cached_token = prefix_cache_length

yield outputs

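Note the gating: num_prefix_cached_token is only set when enable_prefix_caching is on in the TurboMind config, so None means the feature is disabled, while 0 means it is enabled but no cached prefix was reused. A small, hypothetical helper (not part of this commit) that makes the distinction explicit:

```python
# Hypothetical helper, not part of this commit: interpret the reported count.
from typing import Optional


def describe_prefix_cache(num_prefix_cached_token: Optional[int]) -> str:
    if num_prefix_cached_token is None:
        return 'prefix caching disabled'
    if num_prefix_cached_token == 0:
        return 'prefix caching enabled, no cached prefix reused'
    return f'{num_prefix_cached_token} prompt tokens reused from the prefix cache'


print(describe_prefix_cache(None))   # prefix caching disabled
print(describe_prefix_cache(0))      # prefix caching enabled, no cached prefix reused
print(describe_prefix_cache(256))    # 256 prompt tokens reused from the prefix cache
```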
10 changes: 6 additions & 4 deletions src/turbomind/models/llama/LlamaBatch.cc
@@ -1275,12 +1275,14 @@ auto LlamaBatch<T>::Finish(GenerationState& g) -> std::vector<Signal>
int* output_ptr = h_output_ids_;
for (int i = 0; i < batch_size - g.partial; ++i) {
if (state_->requests[i] && (state_->requests[i]->stream_cb || state_->h_finished[i])) {
auto output_ids = state_->requests[i]->outputs[rank_].getPtr<int>("output_ids");
auto output_len = state_->requests[i]->outputs[rank_].getPtr<int>("sequence_length");
const int count = state_->h_context_length[i];
auto output_ids = state_->requests[i]->outputs[rank_].getPtr<int>("output_ids");
auto output_len = state_->requests[i]->outputs[rank_].getPtr<int>("sequence_length");
auto prefix_cache_len = state_->requests[i]->outputs[rank_].getPtr<int>("prefix_cache_length");
const int count = state_->h_context_length[i];
// TODO: sync history output tokens at when receiving the request and copy the last token here
std::copy(output_ptr, output_ptr + count, output_ids);
*output_len = count;
*output_len = count;
*prefix_cache_len = state_->sequences[i]->prefix_cache_len;
}
output_ptr += session_len_;
}
3 changes: 2 additions & 1 deletion src/turbomind/models/llama/SequenceManager.cc
@@ -408,7 +408,8 @@ auto SequenceManager::Materialize(Sequences sequences,
if (!sequences[i]->prompt.empty() && sequences[i]->blocks.empty()) {
auto& seq = const_cast<Sequence&>(*sequences[i]);
block_trie_->match(seq);
seq.cache_len = seq.blocks.size() * block_seq_len_;
seq.cache_len = seq.blocks.size() * block_seq_len_;
seq.prefix_cache_len = seq.cache_len;
}
}
}
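Because the block trie matches whole KV blocks, the reported prefix_cache_len is always a multiple of the block sequence length. A quick numeric sketch of the assignment above, with illustrative values:

```python
# Illustrative arithmetic mirroring SequenceManager::Materialize above: the
# prefix-cache length is the number of trie-matched blocks times the per-block
# sequence length, and it also seeds cache_len.
block_seq_len = 64        # example block size; configured per engine
matched_blocks = 4        # e.g. blocks returned by block_trie_->match(seq)

cache_len = matched_blocks * block_seq_len
prefix_cache_len = cache_len
print(prefix_cache_len)   # 256 prompt tokens already present in the KV cache
```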
6 changes: 3 additions & 3 deletions src/turbomind/models/llama/SequenceManager.h
@@ -10,8 +10,7 @@ namespace turbomind {

struct Sequence {

enum Status
{
enum Status {
kCached = 0,
kLocked,
kActive
@@ -29,7 +28,8 @@ struct Sequence {

mutable std::vector<int> tokens; // update by user

mutable int cache_len = 0;
mutable int cache_len = 0;
mutable int prefix_cache_len = 0;

// additional data kept round-to-round
mutable std::vector<std::byte> random_state; // update by user
12 changes: 11 additions & 1 deletion src/turbomind/triton_backend/llama/LlamaTritonModelInstance.cc
@@ -161,7 +161,14 @@ LlamaTritonModelInstance<T>::forward(std::shared_ptr<std::unordered_map<std::str
ft::Tensor{ft::MEMORY_CPU,
ft::TYPE_UINT32,
std::vector<size_t>{request_batch_size, beam_width},
d_sequence_lengths_}}};
d_sequence_lengths_}},
{"prefix_cache_length",
ft::Tensor{ft::MEMORY_CPU,
ft::TYPE_UINT32,
std::vector<size_t>{request_batch_size, beam_width},
d_prefix_cache_lengths_}}

};

if (input_tensors->count("is_return_log_probs") && *((bool*)input_tensors->at("is_return_log_probs").data)) {
output_tensors.insert({"output_log_probs",
@@ -249,6 +256,8 @@ void LlamaTritonModelInstance<T>::allocateBuffer(const size_t request_batch_size
{
d_output_ids_ = (int*)std::realloc(d_output_ids_, sizeof(int) * request_batch_size * beam_width * session_len);
d_sequence_lengths_ = (int*)std::realloc(d_sequence_lengths_, sizeof(int) * request_batch_size * beam_width);
d_prefix_cache_lengths_ =
(int*)std::realloc(d_prefix_cache_lengths_, sizeof(int) * request_batch_size * beam_width);

// d_output_log_probs_ = (float*)(allocator_->reMalloc(
// d_output_log_probs_, sizeof(float) * request_batch_size * beam_width * session_len, false));
@@ -265,6 +274,7 @@ void LlamaTritonModelInstance<T>::freeBuffer()
{
std::free(d_output_ids_);
std::free(d_sequence_lengths_);
std::free(d_prefix_cache_lengths_);
allocator_->free((void**)(&d_output_log_probs_));
allocator_->free((void**)(&d_cum_log_probs_));
std::free(h_total_output_lengths_);
11 changes: 6 additions & 5 deletions src/turbomind/triton_backend/llama/LlamaTritonModelInstance.h
@@ -86,11 +86,12 @@ struct LlamaTritonModelInstance: AbstractTransformerModelInstance {
float* d_top_p_min_ = nullptr;
int* d_top_p_reset_ids_ = nullptr;

int* d_output_ids_ = nullptr;
int* d_sequence_lengths_ = nullptr;
float* d_output_log_probs_ = nullptr;
float* d_cum_log_probs_ = nullptr;
float* d_output_logits_ = nullptr;
int* d_output_ids_ = nullptr;
int* d_sequence_lengths_ = nullptr;
int* d_prefix_cache_lengths_ = nullptr;
float* d_output_log_probs_ = nullptr;
float* d_cum_log_probs_ = nullptr;
float* d_output_logits_ = nullptr;

float* h_logprob_vals_ = nullptr;
uint32_t* h_logprob_indexes_ = nullptr;
