diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html index a2f92b6e42..8e66839316 100644 --- a/docs/resources/llama-stack-spec.html +++ b/docs/resources/llama-stack-spec.html @@ -21,7 +21,7 @@ "info": { "title": "[DRAFT] Llama Stack Specification", "version": "0.0.1", - "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-10-10 15:29:56.831109" + "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-10-18 20:48:17.730988" }, "servers": [ { @@ -2830,8 +2830,11 @@ "CompletionResponse": { "type": "object", "properties": { - "completion_message": { - "$ref": "#/components/schemas/CompletionMessage" + "content": { + "type": "string" + }, + "stop_reason": { + "$ref": "#/components/schemas/StopReason" }, "logprobs": { "type": "array", @@ -2842,7 +2845,8 @@ }, "additionalProperties": false, "required": [ - "completion_message" + "content", + "stop_reason" ], "title": "Completion response." }, @@ -6075,49 +6079,49 @@ ], "tags": [ { - "name": "Evaluations" - }, - { - "name": "Inspect" + "name": "Models" }, { "name": "RewardScoring" }, { - "name": "Datasets" + "name": "MemoryBanks" }, { - "name": "Models" + "name": "Shields" }, { - "name": "Telemetry" + "name": "SyntheticDataGeneration" }, { - "name": "PostTraining" + "name": "Inference" }, { - "name": "SyntheticDataGeneration" + "name": "Inspect" }, { "name": "BatchInference" }, { - "name": "Inference" + "name": "Memory" + }, + { + "name": "Datasets" }, { "name": "Agents" }, { - "name": "Memory" + "name": "PostTraining" }, { - "name": "Safety" + "name": "Telemetry" }, { - "name": "Shields" + "name": "Safety" }, { - "name": "MemoryBanks" + "name": "Evaluations" }, { "name": "BuiltinTool", diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml index c9822d6ca9..906d3934a7 100644 --- a/docs/resources/llama-stack-spec.yaml +++ b/docs/resources/llama-stack-spec.yaml @@ -501,14 +501,17 @@ components: CompletionResponse: additionalProperties: false properties: - completion_message: - $ref: '#/components/schemas/CompletionMessage' + content: + type: string logprobs: items: $ref: '#/components/schemas/TokenLogProbs' type: array + stop_reason: + $ref: '#/components/schemas/StopReason' required: - - completion_message + - content + - stop_reason title: Completion response. type: object CompletionResponseStreamChunk: @@ -2507,7 +2510,7 @@ info: description: "This is the specification of the llama stack that provides\n \ \ a set of endpoints and their corresponding interfaces that are tailored\ \ to\n best leverage Llama Models. 
The specification is still in\ - \ draft and subject to change.\n Generated at 2024-10-10 15:29:56.831109" + \ draft and subject to change.\n Generated at 2024-10-18 20:48:17.730988" title: '[DRAFT] Llama Stack Specification' version: 0.0.1 jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema @@ -3712,21 +3715,21 @@ security: servers: - url: http://any-hosted-llama-stack.com tags: -- name: Evaluations -- name: Inspect -- name: RewardScoring -- name: Datasets - name: Models -- name: Telemetry -- name: PostTraining +- name: RewardScoring +- name: MemoryBanks +- name: Shields - name: SyntheticDataGeneration -- name: BatchInference - name: Inference -- name: Agents +- name: Inspect +- name: BatchInference - name: Memory +- name: Datasets +- name: Agents +- name: PostTraining +- name: Telemetry - name: Safety -- name: Shields -- name: MemoryBanks +- name: Evaluations - description: name: BuiltinTool - description: AgentCreateResponse: ... - # This method is not `async def` because it can result in either an - # `AsyncGenerator` or a `AgentTurnCreateResponse` depending on the value of `stream`. @webmethod(route="/agents/turn/create") - def create_agent_turn( + async def create_agent_turn( self, agent_id: str, session_id: str, diff --git a/llama_stack/apis/agents/client.py b/llama_stack/apis/agents/client.py index 32bc9abdd5..b454473289 100644 --- a/llama_stack/apis/agents/client.py +++ b/llama_stack/apis/agents/client.py @@ -67,14 +67,14 @@ async def create_agent_session( response.raise_for_status() return AgentSessionCreateResponse(**response.json()) - def create_agent_turn( + async def create_agent_turn( self, request: AgentTurnCreateRequest, ) -> AsyncGenerator: if request.stream: return self._stream_agent_turn(request) else: - return self._nonstream_agent_turn(request) + return await self._nonstream_agent_turn(request) async def _stream_agent_turn( self, request: AgentTurnCreateRequest @@ -126,7 +126,7 @@ async def _run_agent( for content in user_prompts: cprint(f"User> {content}", color="white", attrs=["bold"]) - iterator = api.create_agent_turn( + iterator = await api.create_agent_turn( AgentTurnCreateRequest( agent_id=create_response.agent_id, session_id=session_response.session_id, diff --git a/llama_stack/apis/inference/client.py b/llama_stack/apis/inference/client.py index 79d2cc02ca..90636fa363 100644 --- a/llama_stack/apis/inference/client.py +++ b/llama_stack/apis/inference/client.py @@ -42,10 +42,10 @@ async def initialize(self) -> None: async def shutdown(self) -> None: pass - def completion(self, request: CompletionRequest) -> AsyncGenerator: + async def completion(self, request: CompletionRequest) -> AsyncGenerator: raise NotImplementedError() - def chat_completion( + async def chat_completion( self, model: str, messages: List[Message], @@ -139,7 +139,8 @@ async def run_main( else: logprobs_config = None - iterator = client.chat_completion( + assert stream, "Non streaming not supported here" + iterator = await client.chat_completion( model=model, messages=[message], stream=stream, diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 588dd37caa..5895e528e5 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -88,7 +88,8 @@ class CompletionRequest(BaseModel): class CompletionResponse(BaseModel): """Completion response.""" - completion_message: CompletionMessage + content: str + stop_reason: StopReason logprobs: Optional[List[TokenLogProbs]] = None @@ -113,7 +114,7 @@ class 
BatchCompletionRequest(BaseModel): class BatchCompletionResponse(BaseModel): """Batch completion response.""" - completion_message_batch: List[CompletionMessage] + batch: List[CompletionResponse] @json_schema_type @@ -165,7 +166,7 @@ class BatchChatCompletionRequest(BaseModel): @json_schema_type class BatchChatCompletionResponse(BaseModel): - completion_message_batch: List[CompletionMessage] + batch: List[ChatCompletionResponse] @json_schema_type @@ -181,10 +182,8 @@ def get_model(self, identifier: str) -> ModelDef: ... class Inference(Protocol): model_store: ModelStore - # This method is not `async def` because it can result in either an - # `AsyncGenerator` or a `CompletionResponse` depending on the value of `stream`. @webmethod(route="/inference/completion") - def completion( + async def completion( self, model: str, content: InterleavedTextMedia, @@ -196,7 +195,7 @@ def completion( # This method is not `async def` because it can result in either an # `AsyncGenerator` or a `ChatCompletionResponse` depending on the value of `stream`. @webmethod(route="/inference/chat_completion") - def chat_completion( + async def chat_completion( self, model: str, messages: List[Message], diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index cf62da1d01..a78e808d08 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -70,7 +70,7 @@ async def shutdown(self) -> None: async def register_model(self, model: ModelDef) -> None: await self.routing_table.register_model(model) - def chat_completion( + async def chat_completion( self, model: str, messages: List[Message], @@ -93,11 +93,11 @@ def chat_completion( ) provider = self.routing_table.get_provider_impl(model) if stream: - return (chunk async for chunk in provider.chat_completion(**params)) + return (chunk async for chunk in await provider.chat_completion(**params)) else: - return provider.chat_completion(**params) + return await provider.chat_completion(**params) - def completion( + async def completion( self, model: str, content: InterleavedTextMedia, @@ -114,9 +114,9 @@ def completion( logprobs=logprobs, ) if stream: - return (chunk async for chunk in provider.completion(**params)) + return (chunk async for chunk in await provider.completion(**params)) else: - return provider.completion(**params) + return await provider.completion(**params) async def embeddings( self, diff --git a/llama_stack/providers/adapters/inference/bedrock/bedrock.py b/llama_stack/providers/adapters/inference/bedrock/bedrock.py index 22f87ef6bd..8440ecc205 100644 --- a/llama_stack/providers/adapters/inference/bedrock/bedrock.py +++ b/llama_stack/providers/adapters/inference/bedrock/bedrock.py @@ -47,7 +47,7 @@ async def initialize(self) -> None: async def shutdown(self) -> None: self.client.close() - def completion( + async def completion( self, model: str, content: InterleavedTextMedia, @@ -283,7 +283,7 @@ def _tools_to_tool_config( ) return tool_config - def chat_completion( + async def chat_completion( self, model: str, messages: List[Message], diff --git a/llama_stack/providers/adapters/inference/databricks/databricks.py b/llama_stack/providers/adapters/inference/databricks/databricks.py index 1410511864..9f50ad227d 100644 --- a/llama_stack/providers/adapters/inference/databricks/databricks.py +++ b/llama_stack/providers/adapters/inference/databricks/databricks.py @@ -48,7 +48,7 @@ async def initialize(self) -> None: async def shutdown(self) -> None: pass - def 
completion( + async def completion( self, model: str, content: InterleavedTextMedia, @@ -58,7 +58,7 @@ def completion( ) -> AsyncGenerator: raise NotImplementedError() - def chat_completion( + async def chat_completion( self, model: str, messages: List[Message], @@ -84,7 +84,7 @@ def chat_completion( if stream: return self._stream_chat_completion(request, client) else: - return self._nonstream_chat_completion(request, client) + return await self._nonstream_chat_completion(request, client) async def _nonstream_chat_completion( self, request: ChatCompletionRequest, client: OpenAI diff --git a/llama_stack/providers/adapters/inference/fireworks/fireworks.py b/llama_stack/providers/adapters/inference/fireworks/fireworks.py index c82012cba6..537f3a6b4f 100644 --- a/llama_stack/providers/adapters/inference/fireworks/fireworks.py +++ b/llama_stack/providers/adapters/inference/fireworks/fireworks.py @@ -51,7 +51,7 @@ async def initialize(self) -> None: async def shutdown(self) -> None: pass - def completion( + async def completion( self, model: str, content: InterleavedTextMedia, @@ -61,7 +61,7 @@ def completion( ) -> AsyncGenerator: raise NotImplementedError() - def chat_completion( + async def chat_completion( self, model: str, messages: List[Message], @@ -87,7 +87,7 @@ def chat_completion( if stream: return self._stream_chat_completion(request, client) else: - return self._nonstream_chat_completion(request, client) + return await self._nonstream_chat_completion(request, client) async def _nonstream_chat_completion( self, request: ChatCompletionRequest, client: Fireworks diff --git a/llama_stack/providers/adapters/inference/ollama/ollama.py b/llama_stack/providers/adapters/inference/ollama/ollama.py index c50c869fd5..3a3e4b4516 100644 --- a/llama_stack/providers/adapters/inference/ollama/ollama.py +++ b/llama_stack/providers/adapters/inference/ollama/ollama.py @@ -84,7 +84,7 @@ async def list_models(self) -> List[ModelDef]: return ret - def completion( + async def completion( self, model: str, content: InterleavedTextMedia, @@ -94,7 +94,7 @@ def completion( ) -> AsyncGenerator: raise NotImplementedError() - def chat_completion( + async def chat_completion( self, model: str, messages: List[Message], @@ -118,7 +118,7 @@ def chat_completion( if stream: return self._stream_chat_completion(request) else: - return self._nonstream_chat_completion(request) + return await self._nonstream_chat_completion(request) def _get_params(self, request: ChatCompletionRequest) -> dict: return { diff --git a/llama_stack/providers/adapters/inference/tgi/tgi.py b/llama_stack/providers/adapters/inference/tgi/tgi.py index cd0afad0cc..3c610099cd 100644 --- a/llama_stack/providers/adapters/inference/tgi/tgi.py +++ b/llama_stack/providers/adapters/inference/tgi/tgi.py @@ -66,7 +66,7 @@ async def list_models(self) -> List[ModelDef]: async def shutdown(self) -> None: pass - def completion( + async def completion( self, model: str, content: InterleavedTextMedia, @@ -76,7 +76,7 @@ def completion( ) -> AsyncGenerator: raise NotImplementedError() - def chat_completion( + async def chat_completion( self, model: str, messages: List[Message], @@ -101,7 +101,7 @@ def chat_completion( if stream: return self._stream_chat_completion(request) else: - return self._nonstream_chat_completion(request) + return await self._nonstream_chat_completion(request) async def _nonstream_chat_completion( self, request: ChatCompletionRequest diff --git a/llama_stack/providers/adapters/inference/together/together.py 
b/llama_stack/providers/adapters/inference/together/together.py index 750ca126e2..8c73d75ecc 100644 --- a/llama_stack/providers/adapters/inference/together/together.py +++ b/llama_stack/providers/adapters/inference/together/together.py @@ -64,7 +64,7 @@ async def completion( ) -> AsyncGenerator: raise NotImplementedError() - def chat_completion( + async def chat_completion( self, model: str, messages: List[Message], @@ -101,7 +101,7 @@ def chat_completion( if stream: return self._stream_chat_completion(request, client) else: - return self._nonstream_chat_completion(request, client) + return await self._nonstream_chat_completion(request, client) async def _nonstream_chat_completion( self, request: ChatCompletionRequest, client: Together diff --git a/llama_stack/providers/impls/meta_reference/agents/agent_instance.py b/llama_stack/providers/impls/meta_reference/agents/agent_instance.py index 0d334fdad7..cbc7490fd8 100644 --- a/llama_stack/providers/impls/meta_reference/agents/agent_instance.py +++ b/llama_stack/providers/impls/meta_reference/agents/agent_instance.py @@ -424,7 +424,7 @@ async def _run( stop_reason = None with tracing.span("inference"): - async for chunk in self.inference_api.chat_completion( + async for chunk in await self.inference_api.chat_completion( self.agent_config.model, input_messages, tools=self._get_tools(), diff --git a/llama_stack/providers/impls/meta_reference/agents/agents.py b/llama_stack/providers/impls/meta_reference/agents/agents.py index 5a209d0b76..8b3ece978f 100644 --- a/llama_stack/providers/impls/meta_reference/agents/agents.py +++ b/llama_stack/providers/impls/meta_reference/agents/agents.py @@ -105,7 +105,7 @@ async def create_agent_session( session_id=session_id, ) - def create_agent_turn( + async def create_agent_turn( self, agent_id: str, session_id: str, diff --git a/llama_stack/providers/impls/meta_reference/inference/generation.py b/llama_stack/providers/impls/meta_reference/inference/generation.py index 20a8addc7f..9ca1281768 100644 --- a/llama_stack/providers/impls/meta_reference/inference/generation.py +++ b/llama_stack/providers/impls/meta_reference/inference/generation.py @@ -23,11 +23,6 @@ ) from llama_models.llama3.api.args import ModelArgs from llama_models.llama3.api.chat_format import ChatFormat, ModelInput -from llama_models.llama3.api.datatypes import ( - InterleavedTextMedia, - Message, - ToolPromptFormat, -) from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.llama3.reference_impl.model import Transformer from llama_models.llama3.reference_impl.multimodal.model import ( @@ -38,7 +33,11 @@ from pydantic import BaseModel from termcolor import cprint +from llama_stack.apis.inference import * # noqa: F403 from llama_stack.distribution.utils.model_utils import model_local_dir +from llama_stack.providers.utils.inference.prompt_adapter import ( + chat_completion_request_to_messages, +) from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig @@ -297,15 +296,12 @@ def generate( if all(eos_reached): break - def text_completion( + def completion( self, - content: InterleavedTextMedia, - temperature: float = 0.6, - top_p: float = 0.9, - max_gen_len: Optional[int] = None, - logprobs: bool = False, - echo: bool = False, + request: CompletionRequest, ) -> Generator: + sampling_params = request.sampling_params + max_gen_len = sampling_params.max_tokens if ( max_gen_len is None or max_gen_len == 0 @@ -313,26 +309,25 @@ def text_completion( ): max_gen_len = self.model.params.max_seq_len - 1 - 
model_input = self.formatter.encode_content(content) - + model_input = self.formatter.encode_content(request.content) yield from self.generate( model_input=model_input, max_gen_len=max_gen_len, - temperature=temperature, - top_p=top_p, - logprobs=logprobs, - echo=echo, + temperature=sampling_params.temperature, + top_p=sampling_params.top_p, + logprobs=bool(request.logprobs), + include_stop_token=True, + echo=False, ) def chat_completion( self, - messages: List[Message], - temperature: float = 0.6, - top_p: float = 0.9, - max_gen_len: Optional[int] = None, - logprobs: bool = False, - tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json, + request: ChatCompletionRequest, ) -> Generator: + messages = chat_completion_request_to_messages(request) + + sampling_params = request.sampling_params + max_gen_len = sampling_params.max_tokens if ( max_gen_len is None or max_gen_len == 0 @@ -343,12 +338,12 @@ def chat_completion( yield from self.generate( model_input=self.formatter.encode_dialog_prompt( messages, - tool_prompt_format, + request.tool_prompt_format, ), max_gen_len=max_gen_len, - temperature=temperature, - top_p=top_p, - logprobs=logprobs, + temperature=sampling_params.temperature, + top_p=sampling_params.top_p, + logprobs=bool(request.logprobs), include_stop_token=True, ) diff --git a/llama_stack/providers/impls/meta_reference/inference/inference.py b/llama_stack/providers/impls/meta_reference/inference/inference.py index 7edc279d03..34053343e6 100644 --- a/llama_stack/providers/impls/meta_reference/inference/inference.py +++ b/llama_stack/providers/impls/meta_reference/inference/inference.py @@ -13,9 +13,6 @@ from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403 from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate -from llama_stack.providers.utils.inference.prompt_adapter import ( - chat_completion_request_to_messages, -) from .config import MetaReferenceInferenceConfig from .generation import Llama @@ -58,7 +55,18 @@ async def shutdown(self) -> None: if self.config.create_distributed_process_group: self.generator.stop() - def completion( + def check_model(self, request) -> None: + model = resolve_model(request.model) + if model is None: + raise RuntimeError( + f"Unknown model: {request.model}, Run `llama model list`" + ) + elif model.descriptor() != self.model.descriptor(): + raise RuntimeError( + f"Model mismatch: {request.model} != {self.model.descriptor()}" + ) + + async def completion( self, model: str, content: InterleavedTextMedia, @@ -66,9 +74,114 @@ def completion( stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> Union[CompletionResponse, CompletionResponseStreamChunk]: - raise NotImplementedError() + if logprobs: + assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}" + + request = CompletionRequest( + model=model, + content=content, + sampling_params=sampling_params, + stream=stream, + logprobs=logprobs, + ) + self.check_model(request) + + if request.stream: + return self._stream_completion(request) + else: + return await self._nonstream_completion(request) + + async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator: + def impl(): + stop_reason = None + + for token_result in self.generator.completion(request): + if token_result.text == "<|eot_id|>": + stop_reason = StopReason.end_of_turn + text = "" + elif token_result.text == "<|eom_id|>": + stop_reason = StopReason.end_of_message + text = "" + else: + text = 
token_result.text + + logprobs = None + if stop_reason is None: + if request.logprobs: + assert len(token_result.logprobs) == 1 + + logprobs = [ + TokenLogProbs( + logprobs_by_token={ + token_result.text: token_result.logprobs[0] + } + ) + ] - def chat_completion( + yield CompletionResponseStreamChunk( + delta=text, + stop_reason=stop_reason, + logprobs=logprobs if request.logprobs else None, + ) + + if stop_reason is None: + yield CompletionResponseStreamChunk( + delta="", + stop_reason=StopReason.out_of_tokens, + ) + + if self.config.create_distributed_process_group: + async with SEMAPHORE: + for x in impl(): + yield x + else: + for x in impl(): + yield x + + async def _nonstream_completion( + self, request: CompletionRequest + ) -> CompletionResponse: + def impl(): + tokens = [] + logprobs = [] + stop_reason = None + + tokenizer = self.generator.formatter.tokenizer + for token_result in self.generator.completion(request): + tokens.append(token_result.token) + + if token_result.token in tokenizer.stop_tokens: + # not quite right semantically + stop_reason = StopReason.end_of_turn + + if request.logprobs: + assert len(token_result.logprobs) == 1 + + logprobs.append( + TokenLogProbs( + logprobs_by_token={ + token_result.text: token_result.logprobs[0] + } + ) + ) + + if stop_reason is None: + stop_reason = StopReason.out_of_tokens + + content = self.generator.formatter.tokenizer.decode(tokens) + return CompletionResponse( + content=content, + stop_reason=stop_reason, + logprobs=logprobs if request.logprobs else None, + ) + + if self.config.create_distributed_process_group: + async with SEMAPHORE: + return impl() + else: + return impl() + + async def chat_completion( self, model: str, messages: List[Message], @@ -93,16 +206,7 @@ def chat_completion( stream=stream, logprobs=logprobs, ) - - model = resolve_model(request.model) - if model is None: - raise RuntimeError( - f"Unknown model: {request.model}, Run `llama model list`" - ) - elif model.descriptor() != self.model.descriptor(): - raise RuntimeError( - f"Model mismatch: {request.model} != {self.model.descriptor()}" - ) + self.check_model(request) if self.config.create_distributed_process_group: if SEMAPHORE.locked(): @@ -111,26 +215,17 @@ def chat_completion( if request.stream: return self._stream_chat_completion(request) else: - return self._nonstream_chat_completion(request) + return await self._nonstream_chat_completion(request) async def _nonstream_chat_completion( self, request: ChatCompletionRequest ) -> ChatCompletionResponse: def impl(): - messages = chat_completion_request_to_messages(request) - tokens = [] logprobs = [] stop_reason = None - for token_result in self.generator.chat_completion( - messages=messages, - temperature=request.sampling_params.temperature, - top_p=request.sampling_params.top_p, - max_gen_len=request.sampling_params.max_tokens, - logprobs=request.logprobs, - tool_prompt_format=request.tool_prompt_format, - ): + for token_result in self.generator.chat_completion(request): tokens.append(token_result.token) if token_result.text == "<|eot_id|>": @@ -170,8 +265,6 @@ async def _stream_chat_completion( self, request: ChatCompletionRequest ) -> AsyncGenerator: def impl(): - messages = chat_completion_request_to_messages(request) - yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( event_type=ChatCompletionResponseEventType.start, @@ -184,14 +277,7 @@ def impl(): stop_reason = None ipython = False - for token_result in self.generator.chat_completion( - messages=messages, - 
temperature=request.sampling_params.temperature, - top_p=request.sampling_params.top_p, - max_gen_len=request.sampling_params.max_tokens, - logprobs=request.logprobs, - tool_prompt_format=request.tool_prompt_format, - ): + for token_result in self.generator.chat_completion(request): tokens.append(token_result.token) if not ipython and token_result.text.startswith("<|python_tag|>"): diff --git a/llama_stack/providers/impls/meta_reference/inference/model_parallel.py b/llama_stack/providers/impls/meta_reference/inference/model_parallel.py index e8f483f300..7e7831185b 100644 --- a/llama_stack/providers/impls/meta_reference/inference/model_parallel.py +++ b/llama_stack/providers/impls/meta_reference/inference/model_parallel.py @@ -7,16 +7,17 @@ import os from copy import deepcopy from functools import partial -from typing import Generator, List, Optional +from typing import Any, Generator from llama_models.llama3.api.chat_format import ChatFormat -from llama_models.llama3.api.datatypes import Message, ToolPromptFormat from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.sku_list import resolve_model +from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest + from .config import MetaReferenceInferenceConfig from .generation import Llama, model_checkpoint_dir -from .parallel_utils import InferenceArgs, ModelParallelProcessGroup +from .parallel_utils import ModelParallelProcessGroup class ModelRunner: @@ -24,15 +25,13 @@ def __init__(self, llama): self.llama = llama # the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()` - def __call__(self, task: InferenceArgs): - return self.llama.chat_completion( - task.messages, - task.temperature, - task.top_p, - task.max_gen_len, - task.logprobs, - task.tool_prompt_format, - ) + def __call__(self, req: Any): + if isinstance(req, ChatCompletionRequest): + return self.llama.chat_completion(req) + elif isinstance(req, CompletionRequest): + return self.llama.completion(req) + else: + raise ValueError(f"Unexpected task type {type(req)}") def init_model_cb(config: MetaReferenceInferenceConfig): @@ -77,23 +76,18 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, exc_traceback): self.group.stop() - def chat_completion( + def completion( self, - messages: List[Message], - temperature: float = 0.6, - top_p: float = 0.9, - max_gen_len: Optional[int] = None, - logprobs: bool = False, - tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json, + request: CompletionRequest, ) -> Generator: - req_obj = InferenceArgs( - messages=deepcopy(messages), - temperature=temperature, - top_p=top_p, - max_gen_len=max_gen_len, - logprobs=logprobs or False, - tool_prompt_format=tool_prompt_format, - ) + req_obj = deepcopy(request) + gen = self.group.run_inference(req_obj) + yield from gen + def chat_completion( + self, + request: ChatCompletionRequest, + ) -> Generator: + req_obj = deepcopy(request) gen = self.group.run_inference(req_obj) yield from gen diff --git a/llama_stack/providers/impls/meta_reference/inference/parallel_utils.py b/llama_stack/providers/impls/meta_reference/inference/parallel_utils.py index 7dbedd0f0a..62eeefaacb 100644 --- a/llama_stack/providers/impls/meta_reference/inference/parallel_utils.py +++ b/llama_stack/providers/impls/meta_reference/inference/parallel_utils.py @@ -4,6 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +# Copyright (c) Meta Platforms, IAny, nc. and affiliates. 
+# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + import json import multiprocessing import os @@ -11,10 +17,9 @@ import time import uuid from enum import Enum -from typing import Callable, Generator, List, Literal, Optional, Union +from typing import Callable, Generator, Literal, Optional, Union import torch - import zmq from fairscale.nn.model_parallel.initialize import ( @@ -23,23 +28,14 @@ get_model_parallel_src_rank, ) -from llama_models.llama3.api.datatypes import Message, ToolPromptFormat - from pydantic import BaseModel, Field from torch.distributed.launcher.api import elastic_launch, LaunchConfig from typing_extensions import Annotated -from .generation import TokenResult +from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest - -class InferenceArgs(BaseModel): - messages: List[Message] - temperature: float - top_p: float - max_gen_len: int - logprobs: bool - tool_prompt_format: ToolPromptFormat +from .generation import TokenResult class ProcessingMessageName(str, Enum): @@ -80,7 +76,7 @@ class TaskRequest(BaseModel): type: Literal[ProcessingMessageName.task_request] = ( ProcessingMessageName.task_request ) - task: InferenceArgs + task: Union[CompletionRequest, ChatCompletionRequest] class TaskResponse(BaseModel): @@ -349,11 +345,13 @@ def stop(self): self.process.join() self.started = False - def run_inference(self, inference_args: InferenceArgs) -> Generator: + def run_inference( + self, req: Union[CompletionRequest, ChatCompletionRequest] + ) -> Generator: assert not self.running, "inference already running" self.running = True - self.request_socket.send(encode_msg(TaskRequest(task=inference_args))) + self.request_socket.send(encode_msg(TaskRequest(task=req))) try: while True: obj_json = self.request_socket.recv() diff --git a/llama_stack/providers/impls/meta_reference/safety/llama_guard.py b/llama_stack/providers/impls/meta_reference/safety/llama_guard.py index a6f450fae5..99b1c29be3 100644 --- a/llama_stack/providers/impls/meta_reference/safety/llama_guard.py +++ b/llama_stack/providers/impls/meta_reference/safety/llama_guard.py @@ -184,7 +184,7 @@ async def run(self, messages: List[Message]) -> ShieldResponse: # TODO: llama-stack inference protocol has issues with non-streaming inference code content = "" - async for chunk in self.inference_api.chat_completion( + async for chunk in await self.inference_api.chat_completion( model=self.model, messages=[shield_input_message], stream=True, diff --git a/llama_stack/providers/impls/vllm/vllm.py b/llama_stack/providers/impls/vllm/vllm.py index 5cdb1a2ab5..c977c738d3 100644 --- a/llama_stack/providers/impls/vllm/vllm.py +++ b/llama_stack/providers/impls/vllm/vllm.py @@ -134,7 +134,7 @@ async def shutdown(self): if self.engine: self.engine.shutdown_background_loop() - def completion( + async def completion( self, model: str, content: InterleavedTextMedia, @@ -152,7 +152,7 @@ def completion( logprobs=logprobs, ) - def chat_completion( + async def chat_completion( self, model: str, messages: list[Message], @@ -189,7 +189,7 @@ def chat_completion( if stream: return self._stream_chat_completion(request, results_generator) else: - return self._nonstream_chat_completion(request, results_generator) + return await self._nonstream_chat_completion(request, results_generator) async def _nonstream_chat_completion( self, request: ChatCompletionRequest, results_generator: AsyncGenerator diff --git 
a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py index 6774d3f1fc..9c34c3a28e 100644 --- a/llama_stack/providers/tests/agents/test_agents.py +++ b/llama_stack/providers/tests/agents/test_agents.py @@ -116,7 +116,7 @@ async def test_create_agent_turn(agents_settings, sample_messages): ) turn_response = [ - chunk async for chunk in agents_impl.create_agent_turn(**turn_request) + chunk async for chunk in await agents_impl.create_agent_turn(**turn_request) ] assert len(turn_response) > 0 @@ -204,7 +204,7 @@ async def test_rag_agent_as_attachments( ) turn_response = [ - chunk async for chunk in agents_impl.create_agent_turn(**turn_request) + chunk async for chunk in await agents_impl.create_agent_turn(**turn_request) ] assert len(turn_response) > 0 @@ -218,7 +218,7 @@ async def test_rag_agent_as_attachments( ) turn_response = [ - chunk async for chunk in agents_impl.create_agent_turn(**turn_request) + chunk async for chunk in await agents_impl.create_agent_turn(**turn_request) ] assert len(turn_response) > 0 @@ -270,7 +270,7 @@ async def test_create_agent_turn_with_brave_search( ) turn_response = [ - chunk async for chunk in agents_impl.create_agent_turn(**turn_request) + chunk async for chunk in await agents_impl.create_agent_turn(**turn_request) ] assert len(turn_response) > 0 diff --git a/llama_stack/providers/tests/inference/test_inference.py b/llama_stack/providers/tests/inference/test_inference.py index 581a0d4288..09d6a69db6 100644 --- a/llama_stack/providers/tests/inference/test_inference.py +++ b/llama_stack/providers/tests/inference/test_inference.py @@ -126,6 +126,45 @@ async def test_model_list(inference_settings): assert model_def.identifier == params["model"] +@pytest.mark.asyncio +async def test_completion(inference_settings): + inference_impl = inference_settings["impl"] + params = inference_settings["common_params"] + + provider = inference_impl.routing_table.get_provider_impl(params["model"]) + if provider.__provider_id__ != "meta-reference": + pytest.skip("Other inference providers don't support completion() yet") + + response = await inference_impl.completion( + content="Roses are red,", + stream=False, + model=params["model"], + sampling_params=SamplingParams( + max_tokens=50, + ), + ) + + assert isinstance(response, CompletionResponse) + assert "violets are blue" in response.content + + chunks = [ + r + async for r in await inference_impl.completion( + content="Roses are red,", + stream=True, + model=params["model"], + sampling_params=SamplingParams( + max_tokens=50, + ), + ) + ] + + assert all(isinstance(chunk, CompletionResponseStreamChunk) for chunk in chunks) + assert len(chunks) == 51 + last = chunks[-1] + assert last.stop_reason == StopReason.out_of_tokens + + @pytest.mark.asyncio async def test_chat_completion_non_streaming(inference_settings, sample_messages): inference_impl = inference_settings["impl"] @@ -146,7 +185,7 @@ async def test_chat_completion_streaming(inference_settings, sample_messages): inference_impl = inference_settings["impl"] response = [ r - async for r in inference_impl.chat_completion( + async for r in await inference_impl.chat_completion( messages=sample_messages, stream=True, **inference_settings["common_params"], @@ -217,7 +256,7 @@ async def test_chat_completion_with_tool_calling_streaming( response = [ r - async for r in inference_impl.chat_completion( + async for r in await inference_impl.chat_completion( messages=messages, tools=[sample_tool_definition], stream=True,
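
Net effect of the API changes above, for reference: `completion()`, `chat_completion()` and `create_agent_turn()` are now `async def`, so callers must `await` the call itself and, when `stream=True`, iterate the awaited result with `async for`; the non-streaming `CompletionResponse` now carries `content` and `stop_reason` instead of `completion_message`. A minimal usage sketch under those assumptions follows; the `inference_impl` handle, model id, and prompts are placeholders for illustration, not part of the diff.

# Usage sketch only: `inference_impl` is assumed to be an object implementing the
# updated async `Inference` protocol from this change; model id and prompts are
# hypothetical.
from llama_stack.apis.inference import *  # noqa: F403


async def demo(inference_impl: Inference, model: str) -> None:
    # Non-streaming completion: await the call; the result is a CompletionResponse
    # exposing `content` and `stop_reason` (the old `completion_message` is gone).
    response = await inference_impl.completion(
        model=model,
        content="Roses are red,",
        sampling_params=SamplingParams(max_tokens=50),
        stream=False,
    )
    print(response.content, response.stop_reason)

    # Streaming completion: await the call first, then iterate the async generator
    # of CompletionResponseStreamChunk objects.
    async for chunk in await inference_impl.completion(
        model=model,
        content="Roses are red,",
        sampling_params=SamplingParams(max_tokens=50),
        stream=True,
    ):
        print(chunk.delta, end="")

    # Streaming chat completion follows the same await-then-iterate pattern.
    async for chunk in await inference_impl.chat_completion(
        model=model,
        messages=[UserMessage(content="Hello!")],
        stream=True,
    ):
        pass  # each chunk is a ChatCompletionResponseStreamChunk

Awaiting the call before iterating is what lets a single `async def` return either a plain response object or an async generator depending on `stream`, which is why the old "not `async def` because..." comments were removed in this diff.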