diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html
index a2f92b6e42..8e66839316 100644
--- a/docs/resources/llama-stack-spec.html
+++ b/docs/resources/llama-stack-spec.html
@@ -21,7 +21,7 @@
"info": {
"title": "[DRAFT] Llama Stack Specification",
"version": "0.0.1",
- "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-10-10 15:29:56.831109"
+ "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-10-18 20:48:17.730988"
},
"servers": [
{
@@ -2830,8 +2830,11 @@
"CompletionResponse": {
"type": "object",
"properties": {
- "completion_message": {
- "$ref": "#/components/schemas/CompletionMessage"
+ "content": {
+ "type": "string"
+ },
+ "stop_reason": {
+ "$ref": "#/components/schemas/StopReason"
},
"logprobs": {
"type": "array",
@@ -2842,7 +2845,8 @@
},
"additionalProperties": false,
"required": [
- "completion_message"
+ "content",
+ "stop_reason"
],
"title": "Completion response."
},
@@ -6075,49 +6079,49 @@
],
"tags": [
{
- "name": "Evaluations"
- },
- {
- "name": "Inspect"
+ "name": "Models"
},
{
"name": "RewardScoring"
},
{
- "name": "Datasets"
+ "name": "MemoryBanks"
},
{
- "name": "Models"
+ "name": "Shields"
},
{
- "name": "Telemetry"
+ "name": "SyntheticDataGeneration"
},
{
- "name": "PostTraining"
+ "name": "Inference"
},
{
- "name": "SyntheticDataGeneration"
+ "name": "Inspect"
},
{
"name": "BatchInference"
},
{
- "name": "Inference"
+ "name": "Memory"
+ },
+ {
+ "name": "Datasets"
},
{
"name": "Agents"
},
{
- "name": "Memory"
+ "name": "PostTraining"
},
{
- "name": "Safety"
+ "name": "Telemetry"
},
{
- "name": "Shields"
+ "name": "Safety"
},
{
- "name": "MemoryBanks"
+ "name": "Evaluations"
},
{
"name": "BuiltinTool",
diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml
index c9822d6ca9..906d3934a7 100644
--- a/docs/resources/llama-stack-spec.yaml
+++ b/docs/resources/llama-stack-spec.yaml
@@ -501,14 +501,17 @@ components:
CompletionResponse:
additionalProperties: false
properties:
- completion_message:
- $ref: '#/components/schemas/CompletionMessage'
+ content:
+ type: string
logprobs:
items:
$ref: '#/components/schemas/TokenLogProbs'
type: array
+ stop_reason:
+ $ref: '#/components/schemas/StopReason'
required:
- - completion_message
+ - content
+ - stop_reason
title: Completion response.
type: object
CompletionResponseStreamChunk:
@@ -2507,7 +2510,7 @@ info:
description: "This is the specification of the llama stack that provides\n \
\ a set of endpoints and their corresponding interfaces that are tailored\
\ to\n best leverage Llama Models. The specification is still in\
- \ draft and subject to change.\n Generated at 2024-10-10 15:29:56.831109"
+ \ draft and subject to change.\n Generated at 2024-10-18 20:48:17.730988"
title: '[DRAFT] Llama Stack Specification'
version: 0.0.1
jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema
@@ -3712,21 +3715,21 @@ security:
servers:
- url: http://any-hosted-llama-stack.com
tags:
-- name: Evaluations
-- name: Inspect
-- name: RewardScoring
-- name: Datasets
- name: Models
-- name: Telemetry
-- name: PostTraining
+- name: RewardScoring
+- name: MemoryBanks
+- name: Shields
- name: SyntheticDataGeneration
-- name: BatchInference
- name: Inference
-- name: Agents
+- name: Inspect
+- name: BatchInference
- name: Memory
+- name: Datasets
+- name: Agents
+- name: PostTraining
+- name: Telemetry
- name: Safety
-- name: Shields
-- name: MemoryBanks
+- name: Evaluations
- description:
name: BuiltinTool
diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py
--- a/llama_stack/apis/agents/agents.py
+++ b/llama_stack/apis/agents/agents.py
- # This method is not `async def` because it can result in either an
- # `AsyncGenerator` or a `AgentTurnCreateResponse` depending on the value of `stream`.
@webmethod(route="/agents/turn/create")
- def create_agent_turn(
+ async def create_agent_turn(
self,
agent_id: str,
session_id: str,
diff --git a/llama_stack/apis/agents/client.py b/llama_stack/apis/agents/client.py
index 32bc9abdd5..b454473289 100644
--- a/llama_stack/apis/agents/client.py
+++ b/llama_stack/apis/agents/client.py
@@ -67,14 +67,14 @@ async def create_agent_session(
response.raise_for_status()
return AgentSessionCreateResponse(**response.json())
- def create_agent_turn(
+ async def create_agent_turn(
self,
request: AgentTurnCreateRequest,
) -> AsyncGenerator:
if request.stream:
return self._stream_agent_turn(request)
else:
- return self._nonstream_agent_turn(request)
+ return await self._nonstream_agent_turn(request)
async def _stream_agent_turn(
self, request: AgentTurnCreateRequest
@@ -126,7 +126,7 @@ async def _run_agent(
for content in user_prompts:
cprint(f"User> {content}", color="white", attrs=["bold"])
- iterator = api.create_agent_turn(
+ iterator = await api.create_agent_turn(
AgentTurnCreateRequest(
agent_id=create_response.agent_id,
session_id=session_response.session_id,
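The client now mirrors the server-side convention: `create_agent_turn` is a coroutine, and when `stream=True` awaiting it yields the async generator of turn events. A minimal caller sketch, assuming `AgentTurnCreateRequest` is importable from `llama_stack.apis.agents` and that `api`, the ids, and `messages` come from an earlier create-agent/create-session exchange:

```python
from llama_stack.apis.agents import AgentTurnCreateRequest


async def run_turn(api, agent_id: str, session_id: str, messages) -> None:
    request = AgentTurnCreateRequest(
        agent_id=agent_id,
        session_id=session_id,
        messages=messages,
        stream=True,
    )
    # Await the coroutine first, then iterate the async generator it returns.
    async for chunk in await api.create_agent_turn(request):
        print(chunk)
```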
diff --git a/llama_stack/apis/inference/client.py b/llama_stack/apis/inference/client.py
index 79d2cc02ca..90636fa363 100644
--- a/llama_stack/apis/inference/client.py
+++ b/llama_stack/apis/inference/client.py
@@ -42,10 +42,10 @@ async def initialize(self) -> None:
async def shutdown(self) -> None:
pass
- def completion(self, request: CompletionRequest) -> AsyncGenerator:
+ async def completion(self, request: CompletionRequest) -> AsyncGenerator:
raise NotImplementedError()
- def chat_completion(
+ async def chat_completion(
self,
model: str,
messages: List[Message],
@@ -139,7 +139,8 @@ async def run_main(
else:
logprobs_config = None
- iterator = client.chat_completion(
+ assert stream, "Non streaming not supported here"
+ iterator = await client.chat_completion(
model=model,
messages=[message],
stream=stream,
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index 588dd37caa..5895e528e5 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -88,7 +88,8 @@ class CompletionRequest(BaseModel):
class CompletionResponse(BaseModel):
"""Completion response."""
- completion_message: CompletionMessage
+ content: str
+ stop_reason: StopReason
logprobs: Optional[List[TokenLogProbs]] = None
@@ -113,7 +114,7 @@ class BatchCompletionRequest(BaseModel):
class BatchCompletionResponse(BaseModel):
"""Batch completion response."""
- completion_message_batch: List[CompletionMessage]
+ batch: List[CompletionResponse]
@json_schema_type
@@ -165,7 +166,7 @@ class BatchChatCompletionRequest(BaseModel):
@json_schema_type
class BatchChatCompletionResponse(BaseModel):
- completion_message_batch: List[CompletionMessage]
+ batch: List[ChatCompletionResponse]
@json_schema_type
@@ -181,10 +182,8 @@ def get_model(self, identifier: str) -> ModelDef: ...
class Inference(Protocol):
model_store: ModelStore
- # This method is not `async def` because it can result in either an
- # `AsyncGenerator` or a `CompletionResponse` depending on the value of `stream`.
@webmethod(route="/inference/completion")
- def completion(
+ async def completion(
self,
model: str,
content: InterleavedTextMedia,
@@ -196,7 +195,7 @@ def completion(
# This method is not `async def` because it can result in either an
# `AsyncGenerator` or a `ChatCompletionResponse` depending on the value of `stream`.
@webmethod(route="/inference/chat_completion")
- def chat_completion(
+ async def chat_completion(
self,
model: str,
messages: List[Message],
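For callers of the Inference protocol, both methods are now awaited unconditionally; the awaited value is either the final response or, when `stream=True`, an async generator of stream chunks. A usage sketch (the model identifier is a placeholder, and `SamplingParams` is assumed to be re-exported from `llama_stack.apis.inference`):

```python
from llama_stack.apis.inference import SamplingParams


async def demo(inference) -> None:
    model = "Llama3.1-8B-Instruct"  # placeholder identifier

    # Non-streaming: the awaited coroutine returns a CompletionResponse with
    # the new flat fields.
    response = await inference.completion(
        model=model,
        content="Roses are red,",
        sampling_params=SamplingParams(max_tokens=50),
        stream=False,
    )
    print(response.content, response.stop_reason)

    # Streaming: the awaited call resolves to an async generator of
    # CompletionResponseStreamChunk objects.
    async for chunk in await inference.completion(
        model=model,
        content="Roses are red,",
        sampling_params=SamplingParams(max_tokens=50),
        stream=True,
    ):
        print(chunk.delta, end="")
```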
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index cf62da1d01..a78e808d08 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -70,7 +70,7 @@ async def shutdown(self) -> None:
async def register_model(self, model: ModelDef) -> None:
await self.routing_table.register_model(model)
- def chat_completion(
+ async def chat_completion(
self,
model: str,
messages: List[Message],
@@ -93,11 +93,11 @@ def chat_completion(
)
provider = self.routing_table.get_provider_impl(model)
if stream:
- return (chunk async for chunk in provider.chat_completion(**params))
+ return (chunk async for chunk in await provider.chat_completion(**params))
else:
- return provider.chat_completion(**params)
+ return await provider.chat_completion(**params)
- def completion(
+ async def completion(
self,
model: str,
content: InterleavedTextMedia,
@@ -114,9 +114,9 @@ def completion(
logprobs=logprobs,
)
if stream:
- return (chunk async for chunk in provider.completion(**params))
+ return (chunk async for chunk in await provider.completion(**params))
else:
- return provider.completion(**params)
+ return await provider.completion(**params)
async def embeddings(
self,
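The streaming branch above is easy to misread: because the provider methods are now coroutines, `provider.chat_completion(**params)` returns a coroutine that resolves to an async generator, so the router awaits it and re-yields the chunks through a fresh async generator expression. A stripped-down, self-contained sketch of the same pattern (names and payloads are illustrative):

```python
import asyncio


async def provider_chat_completion(stream: bool = False):
    # Stand-in for a provider adapter: a coroutine that resolves either to an
    # async generator of chunks (stream=True) or to a final result.
    async def chunks():
        for piece in ("Hello", ", ", "world"):
            yield piece

    return chunks() if stream else "Hello, world"


async def routed_chat_completion(stream: bool = False):
    if stream:
        # Await the provider coroutine to get its async generator, then
        # re-yield its chunks lazily to the caller.
        return (chunk async for chunk in await provider_chat_completion(stream=True))
    return await provider_chat_completion(stream=False)


async def main() -> None:
    print(await routed_chat_completion())  # Hello, world
    async for chunk in await routed_chat_completion(stream=True):
        print(chunk, end="")  # prints "Hello, world" piece by piece


asyncio.run(main())
```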
diff --git a/llama_stack/providers/adapters/inference/bedrock/bedrock.py b/llama_stack/providers/adapters/inference/bedrock/bedrock.py
index 22f87ef6bd..8440ecc205 100644
--- a/llama_stack/providers/adapters/inference/bedrock/bedrock.py
+++ b/llama_stack/providers/adapters/inference/bedrock/bedrock.py
@@ -47,7 +47,7 @@ async def initialize(self) -> None:
async def shutdown(self) -> None:
self.client.close()
- def completion(
+ async def completion(
self,
model: str,
content: InterleavedTextMedia,
@@ -283,7 +283,7 @@ def _tools_to_tool_config(
)
return tool_config
- def chat_completion(
+ async def chat_completion(
self,
model: str,
messages: List[Message],
diff --git a/llama_stack/providers/adapters/inference/databricks/databricks.py b/llama_stack/providers/adapters/inference/databricks/databricks.py
index 1410511864..9f50ad227d 100644
--- a/llama_stack/providers/adapters/inference/databricks/databricks.py
+++ b/llama_stack/providers/adapters/inference/databricks/databricks.py
@@ -48,7 +48,7 @@ async def initialize(self) -> None:
async def shutdown(self) -> None:
pass
- def completion(
+ async def completion(
self,
model: str,
content: InterleavedTextMedia,
@@ -58,7 +58,7 @@ def completion(
) -> AsyncGenerator:
raise NotImplementedError()
- def chat_completion(
+ async def chat_completion(
self,
model: str,
messages: List[Message],
@@ -84,7 +84,7 @@ def chat_completion(
if stream:
return self._stream_chat_completion(request, client)
else:
- return self._nonstream_chat_completion(request, client)
+ return await self._nonstream_chat_completion(request, client)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, client: OpenAI
diff --git a/llama_stack/providers/adapters/inference/fireworks/fireworks.py b/llama_stack/providers/adapters/inference/fireworks/fireworks.py
index c82012cba6..537f3a6b4f 100644
--- a/llama_stack/providers/adapters/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/adapters/inference/fireworks/fireworks.py
@@ -51,7 +51,7 @@ async def initialize(self) -> None:
async def shutdown(self) -> None:
pass
- def completion(
+ async def completion(
self,
model: str,
content: InterleavedTextMedia,
@@ -61,7 +61,7 @@ def completion(
) -> AsyncGenerator:
raise NotImplementedError()
- def chat_completion(
+ async def chat_completion(
self,
model: str,
messages: List[Message],
@@ -87,7 +87,7 @@ def chat_completion(
if stream:
return self._stream_chat_completion(request, client)
else:
- return self._nonstream_chat_completion(request, client)
+ return await self._nonstream_chat_completion(request, client)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, client: Fireworks
diff --git a/llama_stack/providers/adapters/inference/ollama/ollama.py b/llama_stack/providers/adapters/inference/ollama/ollama.py
index c50c869fd5..3a3e4b4516 100644
--- a/llama_stack/providers/adapters/inference/ollama/ollama.py
+++ b/llama_stack/providers/adapters/inference/ollama/ollama.py
@@ -84,7 +84,7 @@ async def list_models(self) -> List[ModelDef]:
return ret
- def completion(
+ async def completion(
self,
model: str,
content: InterleavedTextMedia,
@@ -94,7 +94,7 @@ def completion(
) -> AsyncGenerator:
raise NotImplementedError()
- def chat_completion(
+ async def chat_completion(
self,
model: str,
messages: List[Message],
@@ -118,7 +118,7 @@ def chat_completion(
if stream:
return self._stream_chat_completion(request)
else:
- return self._nonstream_chat_completion(request)
+ return await self._nonstream_chat_completion(request)
def _get_params(self, request: ChatCompletionRequest) -> dict:
return {
diff --git a/llama_stack/providers/adapters/inference/tgi/tgi.py b/llama_stack/providers/adapters/inference/tgi/tgi.py
index cd0afad0cc..3c610099cd 100644
--- a/llama_stack/providers/adapters/inference/tgi/tgi.py
+++ b/llama_stack/providers/adapters/inference/tgi/tgi.py
@@ -66,7 +66,7 @@ async def list_models(self) -> List[ModelDef]:
async def shutdown(self) -> None:
pass
- def completion(
+ async def completion(
self,
model: str,
content: InterleavedTextMedia,
@@ -76,7 +76,7 @@ def completion(
) -> AsyncGenerator:
raise NotImplementedError()
- def chat_completion(
+ async def chat_completion(
self,
model: str,
messages: List[Message],
@@ -101,7 +101,7 @@ def chat_completion(
if stream:
return self._stream_chat_completion(request)
else:
- return self._nonstream_chat_completion(request)
+ return await self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest
diff --git a/llama_stack/providers/adapters/inference/together/together.py b/llama_stack/providers/adapters/inference/together/together.py
index 750ca126e2..8c73d75ecc 100644
--- a/llama_stack/providers/adapters/inference/together/together.py
+++ b/llama_stack/providers/adapters/inference/together/together.py
@@ -64,7 +64,7 @@ async def completion(
) -> AsyncGenerator:
raise NotImplementedError()
- def chat_completion(
+ async def chat_completion(
self,
model: str,
messages: List[Message],
@@ -101,7 +101,7 @@ def chat_completion(
if stream:
return self._stream_chat_completion(request, client)
else:
- return self._nonstream_chat_completion(request, client)
+ return await self._nonstream_chat_completion(request, client)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, client: Together
diff --git a/llama_stack/providers/impls/meta_reference/agents/agent_instance.py b/llama_stack/providers/impls/meta_reference/agents/agent_instance.py
index 0d334fdad7..cbc7490fd8 100644
--- a/llama_stack/providers/impls/meta_reference/agents/agent_instance.py
+++ b/llama_stack/providers/impls/meta_reference/agents/agent_instance.py
@@ -424,7 +424,7 @@ async def _run(
stop_reason = None
with tracing.span("inference"):
- async for chunk in self.inference_api.chat_completion(
+ async for chunk in await self.inference_api.chat_completion(
self.agent_config.model,
input_messages,
tools=self._get_tools(),
diff --git a/llama_stack/providers/impls/meta_reference/agents/agents.py b/llama_stack/providers/impls/meta_reference/agents/agents.py
index 5a209d0b76..8b3ece978f 100644
--- a/llama_stack/providers/impls/meta_reference/agents/agents.py
+++ b/llama_stack/providers/impls/meta_reference/agents/agents.py
@@ -105,7 +105,7 @@ async def create_agent_session(
session_id=session_id,
)
- def create_agent_turn(
+ async def create_agent_turn(
self,
agent_id: str,
session_id: str,
diff --git a/llama_stack/providers/impls/meta_reference/inference/generation.py b/llama_stack/providers/impls/meta_reference/inference/generation.py
index 20a8addc7f..9ca1281768 100644
--- a/llama_stack/providers/impls/meta_reference/inference/generation.py
+++ b/llama_stack/providers/impls/meta_reference/inference/generation.py
@@ -23,11 +23,6 @@
)
from llama_models.llama3.api.args import ModelArgs
from llama_models.llama3.api.chat_format import ChatFormat, ModelInput
-from llama_models.llama3.api.datatypes import (
- InterleavedTextMedia,
- Message,
- ToolPromptFormat,
-)
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.llama3.reference_impl.model import Transformer
from llama_models.llama3.reference_impl.multimodal.model import (
@@ -38,7 +33,11 @@
from pydantic import BaseModel
from termcolor import cprint
+from llama_stack.apis.inference import * # noqa: F403
from llama_stack.distribution.utils.model_utils import model_local_dir
+from llama_stack.providers.utils.inference.prompt_adapter import (
+ chat_completion_request_to_messages,
+)
from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
@@ -297,15 +296,12 @@ def generate(
if all(eos_reached):
break
- def text_completion(
+ def completion(
self,
- content: InterleavedTextMedia,
- temperature: float = 0.6,
- top_p: float = 0.9,
- max_gen_len: Optional[int] = None,
- logprobs: bool = False,
- echo: bool = False,
+ request: CompletionRequest,
) -> Generator:
+ sampling_params = request.sampling_params
+ max_gen_len = sampling_params.max_tokens
if (
max_gen_len is None
or max_gen_len == 0
@@ -313,26 +309,25 @@ def text_completion(
):
max_gen_len = self.model.params.max_seq_len - 1
- model_input = self.formatter.encode_content(content)
-
+ model_input = self.formatter.encode_content(request.content)
yield from self.generate(
model_input=model_input,
max_gen_len=max_gen_len,
- temperature=temperature,
- top_p=top_p,
- logprobs=logprobs,
- echo=echo,
+ temperature=sampling_params.temperature,
+ top_p=sampling_params.top_p,
+ logprobs=bool(request.logprobs),
+ include_stop_token=True,
+ echo=False,
)
def chat_completion(
self,
- messages: List[Message],
- temperature: float = 0.6,
- top_p: float = 0.9,
- max_gen_len: Optional[int] = None,
- logprobs: bool = False,
- tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
+ request: ChatCompletionRequest,
) -> Generator:
+ messages = chat_completion_request_to_messages(request)
+
+ sampling_params = request.sampling_params
+ max_gen_len = sampling_params.max_tokens
if (
max_gen_len is None
or max_gen_len == 0
@@ -343,12 +338,12 @@ def chat_completion(
yield from self.generate(
model_input=self.formatter.encode_dialog_prompt(
messages,
- tool_prompt_format,
+ request.tool_prompt_format,
),
max_gen_len=max_gen_len,
- temperature=temperature,
- top_p=top_p,
- logprobs=logprobs,
+ temperature=sampling_params.temperature,
+ top_p=sampling_params.top_p,
+ logprobs=bool(request.logprobs),
include_stop_token=True,
)
diff --git a/llama_stack/providers/impls/meta_reference/inference/inference.py b/llama_stack/providers/impls/meta_reference/inference/inference.py
index 7edc279d03..34053343e6 100644
--- a/llama_stack/providers/impls/meta_reference/inference/inference.py
+++ b/llama_stack/providers/impls/meta_reference/inference/inference.py
@@ -13,9 +13,6 @@
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate
-from llama_stack.providers.utils.inference.prompt_adapter import (
- chat_completion_request_to_messages,
-)
from .config import MetaReferenceInferenceConfig
from .generation import Llama
@@ -58,7 +55,18 @@ async def shutdown(self) -> None:
if self.config.create_distributed_process_group:
self.generator.stop()
- def completion(
+ def check_model(self, request) -> None:
+ model = resolve_model(request.model)
+ if model is None:
+ raise RuntimeError(
+ f"Unknown model: {request.model}, Run `llama model list`"
+ )
+ elif model.descriptor() != self.model.descriptor():
+ raise RuntimeError(
+ f"Model mismatch: {request.model} != {self.model.descriptor()}"
+ )
+
+ async def completion(
self,
model: str,
content: InterleavedTextMedia,
@@ -66,9 +74,114 @@ def completion(
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
- raise NotImplementedError()
+ if logprobs:
+ assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
+
+ request = CompletionRequest(
+ model=model,
+ content=content,
+ sampling_params=sampling_params,
+ stream=stream,
+ logprobs=logprobs,
+ )
+ self.check_model(request)
+
+ if request.stream:
+ return self._stream_completion(request)
+ else:
+ return await self._nonstream_completion(request)
+
+ async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
+ def impl():
+ stop_reason = None
+
+ for token_result in self.generator.completion(request):
+ if token_result.text == "<|eot_id|>":
+ stop_reason = StopReason.end_of_turn
+ text = ""
+ elif token_result.text == "<|eom_id|>":
+ stop_reason = StopReason.end_of_message
+ text = ""
+ else:
+ text = token_result.text
+
+ logprobs = None
+ if stop_reason is None:
+ if request.logprobs:
+ assert len(token_result.logprobs) == 1
+
+ logprobs = [
+ TokenLogProbs(
+ logprobs_by_token={
+ token_result.text: token_result.logprobs[0]
+ }
+ )
+ ]
- def chat_completion(
+ yield CompletionResponseStreamChunk(
+ delta=text,
+ stop_reason=stop_reason,
+ logprobs=logprobs if request.logprobs else None,
+ )
+
+ if stop_reason is None:
+ yield CompletionResponseStreamChunk(
+ delta="",
+ stop_reason=StopReason.out_of_tokens,
+ )
+
+ if self.config.create_distributed_process_group:
+ async with SEMAPHORE:
+ for x in impl():
+ yield x
+ else:
+ for x in impl():
+ yield x
+
+ async def _nonstream_completion(
+ self, request: CompletionRequest
+ ) -> CompletionResponse:
+ def impl():
+ tokens = []
+ logprobs = []
+ stop_reason = None
+
+ tokenizer = self.generator.formatter.tokenizer
+ for token_result in self.generator.completion(request):
+ tokens.append(token_result.token)
+
+ if token_result.token in tokenizer.stop_tokens:
+ # not quite right semantically
+ stop_reason = StopReason.end_of_turn
+
+ if request.logprobs:
+ assert len(token_result.logprobs) == 1
+
+ logprobs.append(
+ TokenLogProbs(
+ logprobs_by_token={
+ token_result.text: token_result.logprobs[0]
+ }
+ )
+ )
+
+ if stop_reason is None:
+ stop_reason = StopReason.out_of_tokens
+
+ content = self.generator.formatter.tokenizer.decode(tokens)
+ return CompletionResponse(
+ content=content,
+ stop_reason=stop_reason,
+ logprobs=logprobs if request.logprobs else None,
+ )
+
+ if self.config.create_distributed_process_group:
+ async with SEMAPHORE:
+ return impl()
+ else:
+ return impl()
+
+ async def chat_completion(
self,
model: str,
messages: List[Message],
@@ -93,16 +206,7 @@ def chat_completion(
stream=stream,
logprobs=logprobs,
)
-
- model = resolve_model(request.model)
- if model is None:
- raise RuntimeError(
- f"Unknown model: {request.model}, Run `llama model list`"
- )
- elif model.descriptor() != self.model.descriptor():
- raise RuntimeError(
- f"Model mismatch: {request.model} != {self.model.descriptor()}"
- )
+ self.check_model(request)
if self.config.create_distributed_process_group:
if SEMAPHORE.locked():
@@ -111,26 +215,17 @@ def chat_completion(
if request.stream:
return self._stream_chat_completion(request)
else:
- return self._nonstream_chat_completion(request)
+ return await self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest
) -> ChatCompletionResponse:
def impl():
- messages = chat_completion_request_to_messages(request)
-
tokens = []
logprobs = []
stop_reason = None
- for token_result in self.generator.chat_completion(
- messages=messages,
- temperature=request.sampling_params.temperature,
- top_p=request.sampling_params.top_p,
- max_gen_len=request.sampling_params.max_tokens,
- logprobs=request.logprobs,
- tool_prompt_format=request.tool_prompt_format,
- ):
+ for token_result in self.generator.chat_completion(request):
tokens.append(token_result.token)
if token_result.text == "<|eot_id|>":
@@ -170,8 +265,6 @@ async def _stream_chat_completion(
self, request: ChatCompletionRequest
) -> AsyncGenerator:
def impl():
- messages = chat_completion_request_to_messages(request)
-
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
@@ -184,14 +277,7 @@ def impl():
stop_reason = None
ipython = False
- for token_result in self.generator.chat_completion(
- messages=messages,
- temperature=request.sampling_params.temperature,
- top_p=request.sampling_params.top_p,
- max_gen_len=request.sampling_params.max_tokens,
- logprobs=request.logprobs,
- tool_prompt_format=request.tool_prompt_format,
- ):
+ for token_result in self.generator.chat_completion(request):
tokens.append(token_result.token)
if not ipython and token_result.text.startswith("<|python_tag|>"):
diff --git a/llama_stack/providers/impls/meta_reference/inference/model_parallel.py b/llama_stack/providers/impls/meta_reference/inference/model_parallel.py
index e8f483f300..7e7831185b 100644
--- a/llama_stack/providers/impls/meta_reference/inference/model_parallel.py
+++ b/llama_stack/providers/impls/meta_reference/inference/model_parallel.py
@@ -7,16 +7,17 @@
import os
from copy import deepcopy
from functools import partial
-from typing import Generator, List, Optional
+from typing import Any, Generator
from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
+from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
+
from .config import MetaReferenceInferenceConfig
from .generation import Llama, model_checkpoint_dir
-from .parallel_utils import InferenceArgs, ModelParallelProcessGroup
+from .parallel_utils import ModelParallelProcessGroup
class ModelRunner:
@@ -24,15 +25,13 @@ def __init__(self, llama):
self.llama = llama
# the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()`
- def __call__(self, task: InferenceArgs):
- return self.llama.chat_completion(
- task.messages,
- task.temperature,
- task.top_p,
- task.max_gen_len,
- task.logprobs,
- task.tool_prompt_format,
- )
+ def __call__(self, req: Any):
+ if isinstance(req, ChatCompletionRequest):
+ return self.llama.chat_completion(req)
+ elif isinstance(req, CompletionRequest):
+ return self.llama.completion(req)
+ else:
+ raise ValueError(f"Unexpected task type {type(req)}")
def init_model_cb(config: MetaReferenceInferenceConfig):
@@ -77,23 +76,18 @@ def __enter__(self):
def __exit__(self, exc_type, exc_value, exc_traceback):
self.group.stop()
- def chat_completion(
+ def completion(
self,
- messages: List[Message],
- temperature: float = 0.6,
- top_p: float = 0.9,
- max_gen_len: Optional[int] = None,
- logprobs: bool = False,
- tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
+ request: CompletionRequest,
) -> Generator:
- req_obj = InferenceArgs(
- messages=deepcopy(messages),
- temperature=temperature,
- top_p=top_p,
- max_gen_len=max_gen_len,
- logprobs=logprobs or False,
- tool_prompt_format=tool_prompt_format,
- )
+ req_obj = deepcopy(request)
+ gen = self.group.run_inference(req_obj)
+ yield from gen
+ def chat_completion(
+ self,
+ request: ChatCompletionRequest,
+ ) -> Generator:
+ req_obj = deepcopy(request)
gen = self.group.run_inference(req_obj)
yield from gen
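With `InferenceArgs` gone, the model-parallel path is driven directly by the request objects, and `ModelRunner` dispatches on their type. A rough usage sketch under a couple of assumptions: the generator class in this file is `LlamaModelParallelGenerator` (its name is not shown in the hunks above), `SamplingParams` is re-exported from `llama_stack.apis.inference`, and `config` is the provider's `MetaReferenceInferenceConfig`:

```python
from llama_stack.apis.inference import CompletionRequest, SamplingParams
from llama_stack.providers.impls.meta_reference.inference.model_parallel import (
    LlamaModelParallelGenerator,  # assumed class name
)


def run_completion(config, model_id: str) -> str:
    request = CompletionRequest(
        model=model_id,
        content="Roses are red,",
        sampling_params=SamplingParams(max_tokens=50),
    )
    text = ""
    with LlamaModelParallelGenerator(config) as generator:
        # The same worker channel now serves both completion() and
        # chat_completion() tasks, selected by the request type.
        for token_result in generator.completion(request):
            text += token_result.text
    return text
```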
diff --git a/llama_stack/providers/impls/meta_reference/inference/parallel_utils.py b/llama_stack/providers/impls/meta_reference/inference/parallel_utils.py
index 7dbedd0f0a..62eeefaacb 100644
--- a/llama_stack/providers/impls/meta_reference/inference/parallel_utils.py
+++ b/llama_stack/providers/impls/meta_reference/inference/parallel_utils.py
@@ -4,6 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
import json
import multiprocessing
import os
@@ -11,10 +17,9 @@
import time
import uuid
from enum import Enum
-from typing import Callable, Generator, List, Literal, Optional, Union
+from typing import Callable, Generator, Literal, Optional, Union
import torch
-
import zmq
from fairscale.nn.model_parallel.initialize import (
@@ -23,23 +28,14 @@
get_model_parallel_src_rank,
)
-from llama_models.llama3.api.datatypes import Message, ToolPromptFormat
-
from pydantic import BaseModel, Field
from torch.distributed.launcher.api import elastic_launch, LaunchConfig
from typing_extensions import Annotated
-from .generation import TokenResult
+from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
-
-class InferenceArgs(BaseModel):
- messages: List[Message]
- temperature: float
- top_p: float
- max_gen_len: int
- logprobs: bool
- tool_prompt_format: ToolPromptFormat
+from .generation import TokenResult
class ProcessingMessageName(str, Enum):
@@ -80,7 +76,7 @@ class TaskRequest(BaseModel):
type: Literal[ProcessingMessageName.task_request] = (
ProcessingMessageName.task_request
)
- task: InferenceArgs
+ task: Union[CompletionRequest, ChatCompletionRequest]
class TaskResponse(BaseModel):
@@ -349,11 +345,13 @@ def stop(self):
self.process.join()
self.started = False
- def run_inference(self, inference_args: InferenceArgs) -> Generator:
+ def run_inference(
+ self, req: Union[CompletionRequest, ChatCompletionRequest]
+ ) -> Generator:
assert not self.running, "inference already running"
self.running = True
- self.request_socket.send(encode_msg(TaskRequest(task=inference_args)))
+ self.request_socket.send(encode_msg(TaskRequest(task=req)))
try:
while True:
obj_json = self.request_socket.recv()
diff --git a/llama_stack/providers/impls/meta_reference/safety/llama_guard.py b/llama_stack/providers/impls/meta_reference/safety/llama_guard.py
index a6f450fae5..99b1c29be3 100644
--- a/llama_stack/providers/impls/meta_reference/safety/llama_guard.py
+++ b/llama_stack/providers/impls/meta_reference/safety/llama_guard.py
@@ -184,7 +184,7 @@ async def run(self, messages: List[Message]) -> ShieldResponse:
# TODO: llama-stack inference protocol has issues with non-streaming inference code
content = ""
- async for chunk in self.inference_api.chat_completion(
+ async for chunk in await self.inference_api.chat_completion(
model=self.model,
messages=[shield_input_message],
stream=True,
diff --git a/llama_stack/providers/impls/vllm/vllm.py b/llama_stack/providers/impls/vllm/vllm.py
index 5cdb1a2ab5..c977c738d3 100644
--- a/llama_stack/providers/impls/vllm/vllm.py
+++ b/llama_stack/providers/impls/vllm/vllm.py
@@ -134,7 +134,7 @@ async def shutdown(self):
if self.engine:
self.engine.shutdown_background_loop()
- def completion(
+ async def completion(
self,
model: str,
content: InterleavedTextMedia,
@@ -152,7 +152,7 @@ def completion(
logprobs=logprobs,
)
- def chat_completion(
+ async def chat_completion(
self,
model: str,
messages: list[Message],
@@ -189,7 +189,7 @@ def chat_completion(
if stream:
return self._stream_chat_completion(request, results_generator)
else:
- return self._nonstream_chat_completion(request, results_generator)
+ return await self._nonstream_chat_completion(request, results_generator)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, results_generator: AsyncGenerator
diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py
index 6774d3f1fc..9c34c3a28e 100644
--- a/llama_stack/providers/tests/agents/test_agents.py
+++ b/llama_stack/providers/tests/agents/test_agents.py
@@ -116,7 +116,7 @@ async def test_create_agent_turn(agents_settings, sample_messages):
)
turn_response = [
- chunk async for chunk in agents_impl.create_agent_turn(**turn_request)
+ chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
]
assert len(turn_response) > 0
@@ -204,7 +204,7 @@ async def test_rag_agent_as_attachments(
)
turn_response = [
- chunk async for chunk in agents_impl.create_agent_turn(**turn_request)
+ chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
]
assert len(turn_response) > 0
@@ -218,7 +218,7 @@ async def test_rag_agent_as_attachments(
)
turn_response = [
- chunk async for chunk in agents_impl.create_agent_turn(**turn_request)
+ chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
]
assert len(turn_response) > 0
@@ -270,7 +270,7 @@ async def test_create_agent_turn_with_brave_search(
)
turn_response = [
- chunk async for chunk in agents_impl.create_agent_turn(**turn_request)
+ chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
]
assert len(turn_response) > 0
diff --git a/llama_stack/providers/tests/inference/test_inference.py b/llama_stack/providers/tests/inference/test_inference.py
index 581a0d4288..09d6a69db6 100644
--- a/llama_stack/providers/tests/inference/test_inference.py
+++ b/llama_stack/providers/tests/inference/test_inference.py
@@ -126,6 +126,45 @@ async def test_model_list(inference_settings):
assert model_def.identifier == params["model"]
+@pytest.mark.asyncio
+async def test_completion(inference_settings):
+ inference_impl = inference_settings["impl"]
+ params = inference_settings["common_params"]
+
+ provider = inference_impl.routing_table.get_provider_impl(params["model"])
+ if provider.__provider_id__ != "meta-reference":
+ pytest.skip("Other inference providers don't support completion() yet")
+
+ response = await inference_impl.completion(
+ content="Roses are red,",
+ stream=False,
+ model=params["model"],
+ sampling_params=SamplingParams(
+ max_tokens=50,
+ ),
+ )
+
+ assert isinstance(response, CompletionResponse)
+ assert "violets are blue" in response.content
+
+ chunks = [
+ r
+ async for r in await inference_impl.completion(
+ content="Roses are red,",
+ stream=True,
+ model=params["model"],
+ sampling_params=SamplingParams(
+ max_tokens=50,
+ ),
+ )
+ ]
+
+ assert all(isinstance(chunk, CompletionResponseStreamChunk) for chunk in chunks)
+ assert len(chunks) == 51
+ last = chunks[-1]
+ assert last.stop_reason == StopReason.out_of_tokens
+
+
@pytest.mark.asyncio
async def test_chat_completion_non_streaming(inference_settings, sample_messages):
inference_impl = inference_settings["impl"]
@@ -146,7 +185,7 @@ async def test_chat_completion_streaming(inference_settings, sample_messages):
inference_impl = inference_settings["impl"]
response = [
r
- async for r in inference_impl.chat_completion(
+ async for r in await inference_impl.chat_completion(
messages=sample_messages,
stream=True,
**inference_settings["common_params"],
@@ -217,7 +256,7 @@ async def test_chat_completion_with_tool_calling_streaming(
response = [
r
- async for r in inference_impl.chat_completion(
+ async for r in await inference_impl.chat_completion(
messages=messages,
tools=[sample_tool_definition],
stream=True,