add usage to chat completion stream response
ispobock committed Jul 25, 2024
1 parent 1401e01 commit 1e960c3
Showing 1 changed file with 41 additions and 8 deletions.
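
With this change, the server attaches token accounting to streamed chat completions: every chunk is built with an optional usage field, which stays unset (and, thanks to exclude_none=True, absent from the JSON) until the chunk whose finish_reason is set. Below is a sketch of the final SSE chunk a client might now see; all values are invented, and prefix_cached_tokens is an lmdeploy-specific extension to the OpenAI usage schema.

# Hypothetical final chunk of a "data: ..." stream after this commit.
# Earlier chunks carry no "usage" key at all (exclude_none=True).
final_chunk = {
    'id': '1',
    'object': 'chat.completion.chunk',
    'created': 1721900000,
    'model': 'internlm2-chat-7b',     # made-up model name
    'choices': [{
        'index': 0,
        'delta': {'role': 'assistant', 'content': ''},
        'finish_reason': 'stop',
    }],
    'usage': {
        'prompt_tokens': 12,          # input_token_len
        'completion_tokens': 48,      # generate_token_len
        'total_tokens': 92,           # history (32) + input (12) + generated (48)
        'prefix_cached_tokens': 24,   # lmdeploy extension: tokens reused from the prefix cache
    },
}
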
lmdeploy/serve/openai/api_server.py (41 additions & 8 deletions)
--- a/lmdeploy/serve/openai/api_server.py
+++ b/lmdeploy/serve/openai/api_server.py
@@ -298,6 +298,7 @@ def create_stream_response_json(
         index: int,
         text: str,
         finish_reason: Optional[str] = None,
+        usage: Optional[UsageInfo] = None,
     ) -> str:
         choice_data = ChatCompletionResponseStreamChoice(
             index=index,
@@ -309,16 +310,31 @@ def create_stream_response_json(
             created=created_time,
             model=model_name,
             choices=[choice_data],
+            usage=usage,
         )
-        response_json = response.model_dump_json()
+        response_json = response.model_dump_json(exclude_none=True)
 
         return response_json
 
     async def completion_stream_generator() -> AsyncGenerator[str, None]:
         async for res in result_generator:
+            usage = None
+            if res.finish_reason is not None:
+                final_res = res
+                total_tokens = sum([
+                    final_res.history_token_len, final_res.input_token_len,
+                    final_res.generate_token_len
+                ])
+                usage = UsageInfo(
+                    prompt_tokens=final_res.input_token_len,
+                    completion_tokens=final_res.generate_token_len,
+                    total_tokens=total_tokens,
+                    prefix_cached_tokens=final_res.prefix_cached_token_len,
+                )
             response_json = create_stream_response_json(
                 index=0,
                 text=res.response,
+                usage=usage,
             )
             yield f'data: {response_json}\n\n'
         yield 'data: [DONE]\n\n'
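
The block above computes usage only once, on the chunk that carries a finish_reason, and folds history_token_len into the total, so multi-turn sessions report cumulative prompt-side usage. A minimal sketch of that arithmetic, assuming the three counters mean what their names suggest:

# Sketch of the usage arithmetic above (counter names assumed to mean
# session history, current prompt, and generated tokens respectively).
def build_usage(history_token_len: int, input_token_len: int,
                generate_token_len: int) -> dict:
    return {
        'prompt_tokens': input_token_len,
        'completion_tokens': generate_token_len,
        # total spans the whole session, not just this request's prompt
        'total_tokens': history_token_len + input_token_len + generate_token_len,
    }

assert build_usage(32, 12, 48)['total_tokens'] == 92
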
@@ -480,10 +496,12 @@ async def chat_completions_v1(request: ChatCompletionRequest,
     )
 
     def create_stream_response_json(
-            index: int,
-            text: str,
-            finish_reason: Optional[str] = None,
-            logprobs: Optional[LogProbs] = None) -> str:
+        index: int,
+        text: str,
+        finish_reason: Optional[str] = None,
+        logprobs: Optional[LogProbs] = None,
+        usage: Optional[UsageInfo] = None,
+    ) -> str:
         choice_data = ChatCompletionResponseStreamChoice(
             index=index,
             delta=DeltaMessage(role='assistant', content=text),
@@ -494,8 +512,9 @@ def create_stream_response_json(
             created=created_time,
             model=model_name,
             choices=[choice_data],
+            usage=usage,
         )
-        response_json = response.model_dump_json()
+        response_json = response.model_dump_json(exclude_none=True)
 
         return response_json
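
The switch to model_dump_json(exclude_none=True) is what keeps intermediate chunks free of a "usage": null field (and drops any other unset optional fields). A standalone pydantic v2 sketch of the effect, not lmdeploy's actual models:

from typing import Optional

from pydantic import BaseModel


class Chunk(BaseModel):        # stand-in for ChatCompletionStreamResponse
    text: str
    usage: Optional[dict] = None


print(Chunk(text='hi').model_dump_json())                   # {"text":"hi","usage":null}
print(Chunk(text='hi').model_dump_json(exclude_none=True))  # {"text":"hi"}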

@@ -506,12 +525,26 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
                 logprobs = _create_chat_completion_logprobs(
                     VariableInterface.async_engine.tokenizer, res.token_ids,
                     res.logprobs)
-
+            usage = None
+            if res.finish_reason is not None:
+                final_res = res
+                total_tokens = sum([
+                    final_res.history_token_len, final_res.input_token_len,
+                    final_res.generate_token_len
+                ])
+                usage = UsageInfo(
+                    prompt_tokens=final_res.input_token_len,
+                    completion_tokens=final_res.generate_token_len,
+                    total_tokens=total_tokens,
+                    prefix_cached_tokens=final_res.prefix_cached_token_len,
+                )
             response_json = create_stream_response_json(
                 index=0,
                 text=res.response,
                 finish_reason=res.finish_reason,
-                logprobs=logprobs)
+                logprobs=logprobs,
+                usage=usage,
+            )
             yield f'data: {response_json}\n\n'
         yield 'data: [DONE]\n\n'
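
An end-to-end way to see the new field: stream a completion from a running api_server and keep the last usage value observed. Host, port, and model name below are assumptions (23333 is lmdeploy's usual default port); only the final chunk should contain usage.

import json

import requests

resp = requests.post(
    'http://localhost:23333/v1/chat/completions',  # assumed local server
    json={
        'model': 'internlm2-chat-7b',              # made-up; use the served model's name
        'messages': [{'role': 'user', 'content': 'Hello!'}],
        'stream': True,
    },
    stream=True,
)

usage = None
for line in resp.iter_lines():
    if not line.startswith(b'data: '):
        continue
    payload = line[len(b'data: '):]
    if payload == b'[DONE]':
        break
    chunk = json.loads(payload)
    usage = chunk.get('usage', usage)  # only the final chunk carries it

print(usage)  # e.g. {'prompt_tokens': 12, 'completion_tokens': 48, ...}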
