Commit ead6cb9
support usage statistics in streaming for meta reference
dineshyv committed Jan 29, 2025
1 parent 813c8a6 commit ead6cb9
Showing 2 changed files with 19 additions and 8 deletions.
18 changes: 18 additions & 0 deletions llama_stack/providers/inline/inference/meta_reference/inference.py
@@ -169,8 +169,14 @@ async def completion(
     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
         def impl():
             stop_reason = None
+            input_token_count = 0
+            output_token_count = 0
+            usage_statistics = None

             for token_result in self.generator.completion(request):
+                if input_token_count == 0:
+                    input_token_count = token_result.input_token_count
+                output_token_count += len(token_result.token)
                 if token_result.text == "<|eot_id|>":
                     stop_reason = StopReason.end_of_turn
                     text = ""
@@ -192,17 +198,29 @@ def impl():
                             }
                         )
                     ]
+                else:
+                    usage_statistics = UsageStatistics(
+                        prompt_tokens=input_token_count,
+                        completion_tokens=output_token_count,
+                        total_tokens=input_token_count + output_token_count,
+                    )

                 yield CompletionResponseStreamChunk(
                     delta=text,
                     stop_reason=stop_reason,
                     logprobs=logprobs if request.logprobs else None,
+                    usage=usage_statistics,
                 )

             if stop_reason is None:
                 yield CompletionResponseStreamChunk(
                     delta="",
                     stop_reason=StopReason.out_of_tokens,
+                    usage=UsageStatistics(
+                        prompt_tokens=input_token_count,
+                        completion_tokens=output_token_count,
+                        total_tokens=input_token_count + output_token_count,
+                    ),
                 )

         if self.config.create_distributed_process_group:
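A minimal sketch of how a caller might consume the new usage field, assuming UsageStatistics is a simple model carrying the three token-count fields shown in the diff; the stream handle and the surrounding client code are illustrative only, not part of this commit.

    # Hypothetical consumer; `stream` stands in for the async generator
    # produced by _stream_completion above.
    async def consume_completion(stream) -> str:
        text = ""
        async for chunk in stream:  # CompletionResponseStreamChunk
            text += chunk.delta
            # Per the diff, `usage` is None on most chunks and populated
            # once generation ends (or the token budget runs out).
            if chunk.usage is not None:
                print(
                    f"prompt={chunk.usage.prompt_tokens} "
                    f"completion={chunk.usage.completion_tokens} "
                    f"total={chunk.usage.total_tokens}"
                )
        return text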
9 changes: 1 addition & 8 deletions llama_stack/providers/utils/inference/openai_compat.py
@@ -179,21 +179,14 @@ def process_chat_completion_response(
     raw_message = formatter.decode_assistant_message_from_content(
         text_from_choice(choice), get_stop_reason(choice.finish_reason)
     )
-    usage_statistics = None
-    if response.usage:
-        usage_statistics = UsageStatistics(
-            prompt_tokens=response.usage.prompt_tokens,
-            completion_tokens=response.usage.completion_tokens,
-            total_tokens=response.usage.total_tokens,
-        )
     return ChatCompletionResponse(
         completion_message=CompletionMessage(
             content=raw_message.content,
             stop_reason=raw_message.stop_reason,
             tool_calls=raw_message.tool_calls,
         ),
         logprobs=None,
-        usage_statistics=usage_statistics,
+        usage=get_usage_statistics(response),
     )


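The hunk above replaces the inline UsageStatistics construction with a call to get_usage_statistics, whose definition is not shown in this hunk. A minimal sketch consistent with the removed inline logic (the actual helper may differ):

    from typing import Optional

    def get_usage_statistics(response) -> Optional[UsageStatistics]:
        # OpenAI-compatible responses carry token counts in `usage`;
        # return None when the provider omits them.
        if response.usage:
            return UsageStatistics(
                prompt_tokens=response.usage.prompt_tokens,
                completion_tokens=response.usage.completion_tokens,
                total_tokens=response.usage.total_tokens,
            )
        return None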
