diff --git a/README.md b/README.md index dadafae904..29b475b910 100644 --- a/README.md +++ b/README.md @@ -22,10 +22,10 @@ Our goal is to provide pre-packaged implementations which can be operated in a v > ⚠️ **Note** > The Stack APIs are rapidly improving, but still very much work in progress and we invite feedback as well as direct contributions. - ## APIs We have working implementations of the following APIs today: + - Inference - Safety - Memory @@ -74,19 +74,21 @@ There is a vibrant ecosystem of Providers which provide efficient inference or s Additionally, we have designed every element of the Stack such that APIs as well as Resources (like Models) can be federated. - ## Supported Llama Stack Implementations + ### API Providers + | **API Provider Builder** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** | -|:------------------------------------------------------------------------------------------:|:----------------------:|:------------------:|:------------------:|:------------------:|:------------------:|:------------------:| +| :----------------------------------------------------------------------------------------: | :--------------------: | :----------------: | :----------------: | :----------------: | :----------------: | :----------------: | | Meta Reference | Single Node | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | Cerebras | Hosted | | :heavy_check_mark: | | | | | Fireworks | Hosted | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | | | AWS Bedrock | Hosted | | :heavy_check_mark: | | :heavy_check_mark: | | | Together | Hosted | :heavy_check_mark: | :heavy_check_mark: | | :heavy_check_mark: | | -| Ollama | Single Node | | :heavy_check_mark: | | | | -| TGI | Hosted and Single Node | | :heavy_check_mark: | | | | -| [NVIDIA NIM](https://build.nvidia.com/nim?filters=nimType%3Anim_type_run_anywhere&q=llama) | Hosted and Single Node | | :heavy_check_mark: | | | | +| Groq | Hosted | | :heavy_check_mark: | | :heavy_check_mark: | | +| Ollama | Single Node | | :heavy_check_mark: | | | +| TGI | Hosted and Single Node | | :heavy_check_mark: | | | +| [NVIDIA NIM](https://build.nvidia.com/nim?filters=nimType%3Anim_type_run_anywhere&q=llama) | Hosted and Single Node | | :heavy_check_mark: | | | | Chroma | Single Node | | | :heavy_check_mark: | | | | PG Vector | Single Node | | | :heavy_check_mark: | | | | PyTorch ExecuTorch | On-device iOS | :heavy_check_mark: | :heavy_check_mark: | | | | @@ -94,16 +96,16 @@ Additionally, we have designed every element of the Stack such that APIs as well ### Distributions -| **Distribution** | **Llama Stack Docker** | Start This Distribution | -|:---------------------------------------------:|:-------------------------------------------------------------------------------------------------------------------------------------------------------------:|:------------------------------------------------------------------------------------------------------------------------:| -| Meta Reference | [llamastack/distribution-meta-reference-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-gpu.html) | -| Meta Reference Quantized | [llamastack/distribution-meta-reference-quantized-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-quantized-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-quantized-gpu.html) | -| Cerebras | [llamastack/distribution-cerebras](https://hub.docker.com/repository/docker/llamastack/distribution-cerebras/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/cerebras.html) | -| Ollama | [llamastack/distribution-ollama](https://hub.docker.com/repository/docker/llamastack/distribution-ollama/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/ollama.html) | -| TGI | [llamastack/distribution-tgi](https://hub.docker.com/repository/docker/llamastack/distribution-tgi/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/tgi.html) | -| Together | [llamastack/distribution-together](https://hub.docker.com/repository/docker/llamastack/distribution-together/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/together.html) | -| Fireworks | [llamastack/distribution-fireworks](https://hub.docker.com/repository/docker/llamastack/distribution-fireworks/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/fireworks.html) | -| [vLLM](https://github.com/vllm-project/vllm) | [llamastack/distribution-remote-vllm](https://hub.docker.com/repository/docker/llamastack/distribution-remote-vllm/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/remote-vllm.html) | +| **Distribution** | **Llama Stack Docker** | Start This Distribution | +| :------------------------------------------: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------: | +| Meta Reference | [llamastack/distribution-meta-reference-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-gpu.html) | +| Meta Reference Quantized | [llamastack/distribution-meta-reference-quantized-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-quantized-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-quantized-gpu.html) | +| Cerebras | [llamastack/distribution-cerebras](https://hub.docker.com/repository/docker/llamastack/distribution-cerebras/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/cerebras.html) | +| Ollama | [llamastack/distribution-ollama](https://hub.docker.com/repository/docker/llamastack/distribution-ollama/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/ollama.html) | +| TGI | [llamastack/distribution-tgi](https://hub.docker.com/repository/docker/llamastack/distribution-tgi/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/tgi.html) | +| Together | [llamastack/distribution-together](https://hub.docker.com/repository/docker/llamastack/distribution-together/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/together.html) | +| Fireworks | [llamastack/distribution-fireworks](https://hub.docker.com/repository/docker/llamastack/distribution-fireworks/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/fireworks.html) | +| [vLLM](https://github.com/vllm-project/vllm) | [llamastack/distribution-remote-vllm](https://hub.docker.com/repository/docker/llamastack/distribution-remote-vllm/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/remote-vllm.html) | ## Installation @@ -111,6 +113,7 @@ You have two ways to install this repository: 1. **Install as a package**: You can install the repository directly from [PyPI](https://pypi.org/project/llama-stack/) by running the following command: + ```bash pip install llama-stack ``` @@ -118,6 +121,7 @@ You have two ways to install this repository: 2. **Install from source**: If you prefer to install from the source code, make sure you have [conda installed](https://docs.conda.io/projects/conda/en/stable). Then, follow these steps: + ```bash mkdir -p ~/local cd ~/local @@ -134,24 +138,24 @@ You have two ways to install this repository: Please checkout our [Documentation](https://llama-stack.readthedocs.io/en/latest/index.html) page for more details. -* [CLI reference](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/index.html) - * Guide using `llama` CLI to work with Llama models (download, study prompts), and building/starting a Llama Stack distribution. -* [Getting Started](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html) - * Quick guide to start a Llama Stack server. - * [Jupyter notebook](./docs/getting_started.ipynb) to walk-through how to use simple text and vision inference llama_stack_client APIs - * The complete Llama Stack lesson [Colab notebook](https://colab.research.google.com/drive/1dtVmxotBsI4cGZQNsJRYPrLiDeT0Wnwt) of the new [Llama 3.2 course on Deeplearning.ai](https://learn.deeplearning.ai/courses/introducing-multimodal-llama-3-2/lesson/8/llama-stack). - * A [Zero-to-Hero Guide](https://github.com/meta-llama/llama-stack/tree/main/docs/zero_to_hero_guide) that guide you through all the key components of llama stack with code samples. -* [Contributing](CONTRIBUTING.md) - * [Adding a new API Provider](https://llama-stack.readthedocs.io/en/latest/contributing/new_api_provider.html) to walk-through how to add a new API provider. +- [CLI reference](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/index.html) + - Guide using `llama` CLI to work with Llama models (download, study prompts), and building/starting a Llama Stack distribution. +- [Getting Started](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html) + - Quick guide to start a Llama Stack server. + - [Jupyter notebook](./docs/getting_started.ipynb) to walk-through how to use simple text and vision inference llama_stack_client APIs + - The complete Llama Stack lesson [Colab notebook](https://colab.research.google.com/drive/1dtVmxotBsI4cGZQNsJRYPrLiDeT0Wnwt) of the new [Llama 3.2 course on Deeplearning.ai](https://learn.deeplearning.ai/courses/introducing-multimodal-llama-3-2/lesson/8/llama-stack). + - A [Zero-to-Hero Guide](https://github.com/meta-llama/llama-stack/tree/main/docs/zero_to_hero_guide) that guide you through all the key components of llama stack with code samples. +- [Contributing](CONTRIBUTING.md) + - [Adding a new API Provider](https://llama-stack.readthedocs.io/en/latest/contributing/new_api_provider.html) to walk-through how to add a new API provider. ## Llama Stack Client SDKs -| **Language** | **Client SDK** | **Package** | -| :----: | :----: | :----: | -| Python | [llama-stack-client-python](https://github.com/meta-llama/llama-stack-client-python) | [![PyPI version](https://img.shields.io/pypi/v/llama_stack_client.svg)](https://pypi.org/project/llama_stack_client/) -| Swift | [llama-stack-client-swift](https://github.com/meta-llama/llama-stack-client-swift) | [![Swift Package Index](https://img.shields.io/endpoint?url=https%3A%2F%2Fswiftpackageindex.com%2Fapi%2Fpackages%2Fmeta-llama%2Fllama-stack-client-swift%2Fbadge%3Ftype%3Dswift-versions)](https://swiftpackageindex.com/meta-llama/llama-stack-client-swift) -| Node | [llama-stack-client-node](https://github.com/meta-llama/llama-stack-client-node) | [![NPM version](https://img.shields.io/npm/v/llama-stack-client.svg)](https://npmjs.org/package/llama-stack-client) -| Kotlin | [llama-stack-client-kotlin](https://github.com/meta-llama/llama-stack-client-kotlin) | [![Maven version](https://img.shields.io/maven-central/v/com.llama.llamastack/llama-stack-client-kotlin)](https://central.sonatype.com/artifact/com.llama.llamastack/llama-stack-client-kotlin) +| **Language** | **Client SDK** | **Package** | +| :----------: | :----------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| Python | [llama-stack-client-python](https://github.com/meta-llama/llama-stack-client-python) | [![PyPI version](https://img.shields.io/pypi/v/llama_stack_client.svg)](https://pypi.org/project/llama_stack_client/) | +| Swift | [llama-stack-client-swift](https://github.com/meta-llama/llama-stack-client-swift) | [![Swift Package Index](https://img.shields.io/endpoint?url=https%3A%2F%2Fswiftpackageindex.com%2Fapi%2Fpackages%2Fmeta-llama%2Fllama-stack-client-swift%2Fbadge%3Ftype%3Dswift-versions)](https://swiftpackageindex.com/meta-llama/llama-stack-client-swift) | +| Node | [llama-stack-client-node](https://github.com/meta-llama/llama-stack-client-node) | [![NPM version](https://img.shields.io/npm/v/llama-stack-client.svg)](https://npmjs.org/package/llama-stack-client) | +| Kotlin | [llama-stack-client-kotlin](https://github.com/meta-llama/llama-stack-client-kotlin) | [![Maven version](https://img.shields.io/maven-central/v/com.llama.llamastack/llama-stack-client-kotlin)](https://central.sonatype.com/artifact/com.llama.llamastack/llama-stack-client-kotlin) | Check out our client SDKs for connecting to Llama Stack server in your preferred language, you can choose from [python](https://github.com/meta-llama/llama-stack-client-python), [node](https://github.com/meta-llama/llama-stack-client-node), [swift](https://github.com/meta-llama/llama-stack-client-swift), and [kotlin](https://github.com/meta-llama/llama-stack-client-kotlin) programming languages to quickly build your applications. diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index 0ff557b9f9..beaa626455 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -149,6 +149,15 @@ def available_providers() -> List[ProviderSpec]: provider_data_validator="llama_stack.providers.remote.inference.together.TogetherProviderDataValidator", ), ), + remote_provider_spec( + api=Api.inference, + adapter=AdapterSpec( + adapter_type="groq", + pip_packages=["groq"], + module="llama_stack.providers.remote.inference.groq", + config_class="llama_stack.providers.remote.inference.groq.GroqConfig", + ), + ), remote_provider_spec( api=Api.inference, adapter=AdapterSpec( diff --git a/llama_stack/providers/remote/inference/groq/__init__.py b/llama_stack/providers/remote/inference/groq/__init__.py new file mode 100644 index 0000000000..81c3ed7d3a --- /dev/null +++ b/llama_stack/providers/remote/inference/groq/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack.apis.inference import Inference + +from .config import GroqConfig + + +async def get_adapter_impl(config: GroqConfig, _deps) -> Inference: + # import dynamically so `llama stack build` does not fail due to missing dependencies + from .groq import GroqInferenceAdapter + + if not isinstance(config, GroqConfig): + raise RuntimeError(f"Unexpected config type: {type(config)}") + + adapter = GroqInferenceAdapter(config) + return adapter diff --git a/llama_stack/providers/remote/inference/groq/config.py b/llama_stack/providers/remote/inference/groq/config.py new file mode 100644 index 0000000000..7c50234103 --- /dev/null +++ b/llama_stack/providers/remote/inference/groq/config.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Optional + +from llama_models.schema_utils import json_schema_type +from pydantic import BaseModel, Field + + +@json_schema_type +class GroqConfig(BaseModel): + api_key: Optional[str] = Field( + # The Groq client library loads the GROQ_API_KEY environment variable by default + default=None, + description="The Groq API key", + ) diff --git a/llama_stack/providers/remote/inference/groq/groq.py b/llama_stack/providers/remote/inference/groq/groq.py new file mode 100644 index 0000000000..ccd1fc0c51 --- /dev/null +++ b/llama_stack/providers/remote/inference/groq/groq.py @@ -0,0 +1,142 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import warnings +from typing import AsyncIterator, List, Optional, Union + +from groq import Groq +from llama_models.datatypes import SamplingParams +from llama_models.llama3.api.datatypes import ( + InterleavedTextMedia, + Message, + ToolChoice, + ToolDefinition, + ToolPromptFormat, +) +from llama_models.sku_list import CoreModelId + +from llama_stack.apis.inference import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseStreamChunk, + CompletionResponse, + CompletionResponseStreamChunk, + EmbeddingsResponse, + Inference, + LogProbConfig, + ResponseFormat, +) +from llama_stack.providers.remote.inference.groq.config import GroqConfig +from llama_stack.providers.utils.inference.model_registry import ( + build_model_alias, + build_model_alias_with_just_provider_model_id, + ModelRegistryHelper, +) +from .groq_utils import ( + convert_chat_completion_request, + convert_chat_completion_response, + convert_chat_completion_response_stream, +) + +_MODEL_ALIASES = [ + build_model_alias( + "llama3-8b-8192", + CoreModelId.llama3_1_8b_instruct.value, + ), + build_model_alias_with_just_provider_model_id( + "llama-3.1-8b-instant", + CoreModelId.llama3_1_8b_instruct.value, + ), + build_model_alias( + "llama3-70b-8192", + CoreModelId.llama3_70b_instruct.value, + ), + build_model_alias( + "llama-3.3-70b-versatile", + CoreModelId.llama3_3_70b_instruct.value, + ), + # Groq only contains a preview version for llama-3.2-3b + # Preview models aren't recommended for production use, but we include this one + # to pass the test fixture + # TODO(aidand): Replace this with a stable model once Groq supports it + build_model_alias( + "llama-3.2-3b-preview", + CoreModelId.llama3_2_3b_instruct.value, + ), +] + + +class GroqInferenceAdapter(Inference, ModelRegistryHelper): + _client: Groq + + def __init__(self, config: GroqConfig): + ModelRegistryHelper.__init__(self, model_aliases=_MODEL_ALIASES) + self._client = Groq(api_key=config.api_key) + + def completion( + self, + model_id: str, + content: InterleavedTextMedia, + sampling_params: Optional[SamplingParams] = SamplingParams(), + response_format: Optional[ResponseFormat] = None, + stream: Optional[bool] = False, + logprobs: Optional[LogProbConfig] = None, + ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]: + # Groq doesn't support non-chat completion as of time of writing + raise NotImplementedError() + + async def chat_completion( + self, + model_id: str, + messages: List[Message], + sampling_params: Optional[SamplingParams] = SamplingParams(), + response_format: Optional[ResponseFormat] = None, + tools: Optional[List[ToolDefinition]] = None, + tool_choice: Optional[ToolChoice] = ToolChoice.auto, + tool_prompt_format: Optional[ + ToolPromptFormat + ] = None, # API default is ToolPromptFormat.json, we default to None to detect user input + stream: Optional[bool] = False, + logprobs: Optional[LogProbConfig] = None, + ) -> Union[ + ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk] + ]: + + if model_id == "llama-3.2-3b-preview": + warnings.warn( + "Groq only contains a preview version for llama-3.2-3b-instruct. " + "Preview models aren't recommended for production use. " + "They can be discontinued on short notice." + ) + + model_id = self.get_provider_model_id(model_id) + request = convert_chat_completion_request( + request=ChatCompletionRequest( + model=model_id, + messages=messages, + sampling_params=sampling_params, + response_format=response_format, + tools=tools, + tool_choice=tool_choice, + tool_prompt_format=tool_prompt_format, + stream=stream, + logprobs=logprobs, + ) + ) + + response = self._client.chat.completions.create(**request) + + if stream: + return convert_chat_completion_response_stream(response) + else: + return convert_chat_completion_response(response) + + async def embeddings( + self, + model_id: str, + contents: List[InterleavedTextMedia], + ) -> EmbeddingsResponse: + raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/groq/groq_utils.py b/llama_stack/providers/remote/inference/groq/groq_utils.py new file mode 100644 index 0000000000..ee1e3c5e11 --- /dev/null +++ b/llama_stack/providers/remote/inference/groq/groq_utils.py @@ -0,0 +1,162 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import warnings +from typing import AsyncGenerator, Generator, Literal + +from groq import Stream +from groq.types.chat.chat_completion import ChatCompletion +from groq.types.chat.chat_completion_assistant_message_param import ( + ChatCompletionAssistantMessageParam, +) +from groq.types.chat.chat_completion_chunk import ChatCompletionChunk +from groq.types.chat.chat_completion_message_param import ChatCompletionMessageParam +from groq.types.chat.chat_completion_system_message_param import ( + ChatCompletionSystemMessageParam, +) +from groq.types.chat.chat_completion_user_message_param import ( + ChatCompletionUserMessageParam, +) + +from groq.types.chat.completion_create_params import CompletionCreateParams + +from llama_stack.apis.inference import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseEvent, + ChatCompletionResponseEventType, + ChatCompletionResponseStreamChunk, + CompletionMessage, + Message, + Role, + StopReason, +) + + +def convert_chat_completion_request( + request: ChatCompletionRequest, +) -> CompletionCreateParams: + """ + Convert a ChatCompletionRequest to a Groq API-compatible dictionary. + Warns client if request contains unsupported features. + """ + + if request.logprobs: + # Groq doesn't support logprobs at the time of writing + warnings.warn("logprobs are not supported yet") + + if request.response_format: + # Groq's JSON mode is beta at the time of writing + warnings.warn("response_format is not supported yet") + + if request.sampling_params.repetition_penalty: + # groq supports frequency_penalty, but frequency_penalty and sampling_params.repetition_penalty + # seem to have different semantics + # frequency_penalty defaults to 0 is a float between -2.0 and 2.0 + # repetition_penalty defaults to 1 and is often set somewhere between 1.0 and 2.0 + # so we exclude it for now + warnings.warn("repetition_penalty is not supported") + + if request.tools: + warnings.warn("tools are not supported yet") + + return CompletionCreateParams( + model=request.model, + messages=[_convert_message(message) for message in request.messages], + logprobs=None, + frequency_penalty=None, + stream=request.stream, + max_tokens=request.sampling_params.max_tokens or None, + temperature=request.sampling_params.temperature, + top_p=request.sampling_params.top_p, + ) + + +def _convert_message(message: Message) -> ChatCompletionMessageParam: + if message.role == Role.system.value: + return ChatCompletionSystemMessageParam(role="system", content=message.content) + elif message.role == Role.user.value: + return ChatCompletionUserMessageParam(role="user", content=message.content) + elif message.role == Role.assistant.value: + return ChatCompletionAssistantMessageParam( + role="assistant", content=message.content + ) + else: + raise ValueError(f"Invalid message role: {message.role}") + + +def convert_chat_completion_response( + response: ChatCompletion, +) -> ChatCompletionResponse: + # groq only supports n=1 at time of writing, so there is only one choice + choice = response.choices[0] + return ChatCompletionResponse( + completion_message=CompletionMessage( + content=choice.message.content, + stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason), + ), + ) + + +def _map_finish_reason_to_stop_reason( + finish_reason: Literal["stop", "length", "tool_calls"] +) -> StopReason: + """ + Convert a Groq chat completion finish_reason to a StopReason. + + finish_reason: Literal["stop", "length", "tool_calls"] + - stop -> model hit a natural stop point or a provided stop sequence + - length -> maximum number of tokens specified in the request was reached + - tool_calls -> model called a tool + """ + if finish_reason == "stop": + return StopReason.end_of_turn + elif finish_reason == "length": + return StopReason.end_of_message + elif finish_reason == "tool_calls": + raise NotImplementedError("tool_calls is not supported yet") + else: + raise ValueError(f"Invalid finish reason: {finish_reason}") + + +async def convert_chat_completion_response_stream( + stream: Stream[ChatCompletionChunk], +) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]: + + # generate a stream of ChatCompletionResponseEventType: start -> progress -> progress -> ... + def _event_type_generator() -> ( + Generator[ChatCompletionResponseEventType, None, None] + ): + yield ChatCompletionResponseEventType.start + while True: + yield ChatCompletionResponseEventType.progress + + event_types = _event_type_generator() + + for chunk in stream: + choice = chunk.choices[0] + + # We assume there's only one finish_reason for the entire stream. + # We collect the last finish_reason + if choice.finish_reason: + stop_reason = _map_finish_reason_to_stop_reason(choice.finish_reason) + + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=next(event_types), + delta=choice.delta.content or "", + logprobs=None, + ) + ) + + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.complete, + delta="", + logprobs=None, + stop_reason=stop_reason, + ) + ) diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py index d9c0cb1889..194c741f8d 100644 --- a/llama_stack/providers/tests/inference/fixtures.py +++ b/llama_stack/providers/tests/inference/fixtures.py @@ -19,6 +19,7 @@ from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig +from llama_stack.providers.remote.inference.groq import GroqConfig from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig from llama_stack.providers.remote.inference.ollama import OllamaImplConfig from llama_stack.providers.remote.inference.tgi import TGIImplConfig @@ -150,6 +151,22 @@ def inference_together() -> ProviderFixture: ) +@pytest.fixture(scope="session") +def inference_groq() -> ProviderFixture: + return ProviderFixture( + providers=[ + Provider( + provider_id="groq", + provider_type="remote::groq", + config=GroqConfig().model_dump(), + ) + ], + provider_data=dict( + groq_api_key=get_env_or_fail("GROQ_API_KEY"), + ), + ) + + @pytest.fixture(scope="session") def inference_bedrock() -> ProviderFixture: return ProviderFixture( @@ -222,6 +239,7 @@ def model_id(inference_model) -> str: "ollama", "fireworks", "together", + "groq", "vllm_remote", "remote", "bedrock", diff --git a/llama_stack/providers/tests/inference/groq/test_groq_utils.py b/llama_stack/providers/tests/inference/groq/test_groq_utils.py new file mode 100644 index 0000000000..930ad54eb9 --- /dev/null +++ b/llama_stack/providers/tests/inference/groq/test_groq_utils.py @@ -0,0 +1,278 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import pytest +from groq.types.chat.chat_completion import ChatCompletion, Choice +from groq.types.chat.chat_completion_chunk import ( + ChatCompletionChunk, + Choice as StreamChoice, + ChoiceDelta, +) +from groq.types.chat.chat_completion_message import ChatCompletionMessage + +from llama_stack.apis.inference import ( + ChatCompletionRequest, + ChatCompletionResponseEventType, + CompletionMessage, + StopReason, + SystemMessage, + UserMessage, +) +from llama_stack.providers.remote.inference.groq.groq_utils import ( + convert_chat_completion_request, + convert_chat_completion_response, + convert_chat_completion_response_stream, +) + + +class TestConvertChatCompletionRequest: + def test_sets_model(self): + request = self._dummy_chat_completion_request() + request.model = "Llama-3.2-3B" + + converted = convert_chat_completion_request(request) + + assert converted["model"] == "Llama-3.2-3B" + + def test_converts_user_message(self): + request = self._dummy_chat_completion_request() + request.messages = [UserMessage(content="Hello World")] + + converted = convert_chat_completion_request(request) + + assert converted["messages"] == [ + {"role": "user", "content": "Hello World"}, + ] + + def test_converts_system_message(self): + request = self._dummy_chat_completion_request() + request.messages = [SystemMessage(content="You are a helpful assistant.")] + + converted = convert_chat_completion_request(request) + + assert converted["messages"] == [ + {"role": "system", "content": "You are a helpful assistant."}, + ] + + def test_converts_completion_message(self): + request = self._dummy_chat_completion_request() + request.messages = [ + UserMessage(content="Hello World"), + CompletionMessage( + content="Hello World! How can I help you today?", + stop_reason=StopReason.end_of_message, + ), + ] + + converted = convert_chat_completion_request(request) + + assert converted["messages"] == [ + {"role": "user", "content": "Hello World"}, + {"role": "assistant", "content": "Hello World! How can I help you today?"}, + ] + + def test_does_not_include_logprobs(self): + request = self._dummy_chat_completion_request() + request.logprobs = True + + with pytest.warns(Warning) as warnings: + converted = convert_chat_completion_request(request) + + assert "logprobs are not supported yet" in warnings[0].message.args[0] + assert converted.get("logprobs") is None + + def test_does_not_include_response_format(self): + request = self._dummy_chat_completion_request() + request.response_format = { + "type": "json_object", + "json_schema": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "number"}, + }, + }, + } + + with pytest.warns(Warning) as warnings: + converted = convert_chat_completion_request(request) + + assert "response_format is not supported yet" in warnings[0].message.args[0] + assert converted.get("response_format") is None + + def test_does_not_include_repetition_penalty(self): + request = self._dummy_chat_completion_request() + request.sampling_params.repetition_penalty = 1.5 + + with pytest.warns(Warning) as warnings: + converted = convert_chat_completion_request(request) + + assert "repetition_penalty is not supported" in warnings[0].message.args[0] + assert converted.get("repetition_penalty") is None + assert converted.get("frequency_penalty") is None + + def test_includes_stream(self): + request = self._dummy_chat_completion_request() + request.stream = True + + converted = convert_chat_completion_request(request) + + assert converted["stream"] is True + + def test_n_is_1(self): + request = self._dummy_chat_completion_request() + + converted = convert_chat_completion_request(request) + + assert converted["n"] == 1 + + def test_if_max_tokens_is_0_then_it_is_not_included(self): + request = self._dummy_chat_completion_request() + # 0 is the default value for max_tokens + # So we assume that if it's 0, the user didn't set it + request.sampling_params.max_tokens = 0 + + converted = convert_chat_completion_request(request) + + assert converted.get("max_tokens") is None + + def test_includes_max_tokens_if_set(self): + request = self._dummy_chat_completion_request() + request.sampling_params.max_tokens = 100 + + converted = convert_chat_completion_request(request) + + assert converted["max_tokens"] == 100 + + def _dummy_chat_completion_request(self): + return ChatCompletionRequest( + model="Llama-3.2-3B", + messages=[UserMessage(content="Hello World")], + ) + + def test_includes_temperature(self): + request = self._dummy_chat_completion_request() + request.sampling_params.temperature = 0.5 + + converted = convert_chat_completion_request(request) + + assert converted["temperature"] == 0.5 + + def test_includes_top_p(self): + request = self._dummy_chat_completion_request() + request.sampling_params.top_p = 0.95 + + converted = convert_chat_completion_request(request) + + assert converted["top_p"] == 0.95 + + +class TestConvertNonStreamChatCompletionResponse: + def test_returns_response(self): + response = self._dummy_chat_completion_response() + response.choices[0].message.content = "Hello World" + + converted = convert_chat_completion_response(response) + + assert converted.completion_message.content == "Hello World" + + def test_maps_stop_to_end_of_message(self): + response = self._dummy_chat_completion_response() + response.choices[0].finish_reason = "stop" + + converted = convert_chat_completion_response(response) + + assert converted.completion_message.stop_reason == StopReason.end_of_turn + + def test_maps_length_to_end_of_message(self): + response = self._dummy_chat_completion_response() + response.choices[0].finish_reason = "length" + + converted = convert_chat_completion_response(response) + + assert converted.completion_message.stop_reason == StopReason.end_of_message + + def _dummy_chat_completion_response(self): + return ChatCompletion( + id="chatcmpl-123", + model="Llama-3.2-3B", + choices=[ + Choice( + index=0, + message=ChatCompletionMessage( + role="assistant", content="Hello World" + ), + finish_reason="stop", + ) + ], + created=1729382400, + object="chat.completion", + ) + + +class TestConvertStreamChatCompletionResponse: + @pytest.mark.asyncio + async def test_returns_stream(self): + def chat_completion_stream(): + messages = ["Hello ", "World ", " !"] + for i, message in enumerate(messages): + chunk = self._dummy_chat_completion_chunk() + chunk.choices[0].delta.content = message + if i == len(messages) - 1: + chunk.choices[0].finish_reason = "stop" + else: + chunk.choices[0].finish_reason = None + yield chunk + + chunk = self._dummy_chat_completion_chunk() + chunk.choices[0].delta.content = None + chunk.choices[0].finish_reason = "stop" + yield chunk + + stream = chat_completion_stream() + converted = convert_chat_completion_response_stream(stream) + + iter = converted.__aiter__() + chunk = await iter.__anext__() + assert chunk.event.event_type == ChatCompletionResponseEventType.start + assert chunk.event.delta == "Hello " + + chunk = await iter.__anext__() + assert chunk.event.event_type == ChatCompletionResponseEventType.progress + assert chunk.event.delta == "World " + + chunk = await iter.__anext__() + assert chunk.event.event_type == ChatCompletionResponseEventType.progress + assert chunk.event.delta == " !" + + # Dummy chunk to ensure the last chunk is really the end of the stream + # This one technically maps to Groq's final "stop" chunk + chunk = await iter.__anext__() + assert chunk.event.event_type == ChatCompletionResponseEventType.progress + assert chunk.event.delta == "" + + chunk = await iter.__anext__() + assert chunk.event.event_type == ChatCompletionResponseEventType.complete + assert chunk.event.delta == "" + assert chunk.event.stop_reason == StopReason.end_of_turn + + with pytest.raises(StopAsyncIteration): + await iter.__anext__() + + def _dummy_chat_completion_chunk(self): + return ChatCompletionChunk( + id="chatcmpl-123", + model="Llama-3.2-3B", + choices=[ + StreamChoice( + index=0, + delta=ChoiceDelta(role="assistant", content="Hello World"), + ) + ], + created=1729382400, + object="chat.completion.chunk", + x_groq=None, + ) diff --git a/llama_stack/providers/tests/inference/groq/test_init.py b/llama_stack/providers/tests/inference/groq/test_init.py new file mode 100644 index 0000000000..d23af5934d --- /dev/null +++ b/llama_stack/providers/tests/inference/groq/test_init.py @@ -0,0 +1,29 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import pytest +from llama_stack.apis.inference import Inference +from llama_stack.providers.remote.inference.groq import get_adapter_impl +from llama_stack.providers.remote.inference.groq.config import GroqConfig +from llama_stack.providers.remote.inference.groq.groq import GroqInferenceAdapter + +from llama_stack.providers.remote.inference.ollama import OllamaImplConfig + + +class TestGroqInit: + @pytest.mark.asyncio + async def test_raises_runtime_error_if_config_is_not_groq_config(self): + config = OllamaImplConfig(model="llama3.1-8b-8192") + + with pytest.raises(RuntimeError): + await get_adapter_impl(config, None) + + @pytest.mark.asyncio + async def test_returns_groq_adapter(self): + config = GroqConfig() + adapter = await get_adapter_impl(config, None) + assert type(adapter) is GroqInferenceAdapter + assert isinstance(adapter, Inference) diff --git a/llama_stack/providers/tests/inference/test_text_inference.py b/llama_stack/providers/tests/inference/test_text_inference.py index 99a62ac080..02851830b8 100644 --- a/llama_stack/providers/tests/inference/test_text_inference.py +++ b/llama_stack/providers/tests/inference/test_text_inference.py @@ -350,6 +350,14 @@ async def test_chat_completion_with_tool_calling( sample_messages, sample_tool_definition, ): + inference_impl, _ = inference_stack + provider = inference_impl.routing_table.get_provider_impl(inference_model) + if provider.__provider_spec__.provider_type in ("remote::groq",): + pytest.skip( + provider.__provider_spec__.provider_type + + " doesn't support tool calling yet" + ) + inference_impl, _ = inference_stack messages = sample_messages + [ UserMessage( @@ -390,6 +398,13 @@ async def test_chat_completion_with_tool_calling_streaming( sample_tool_definition, ): inference_impl, _ = inference_stack + provider = inference_impl.routing_table.get_provider_impl(inference_model) + if provider.__provider_spec__.provider_type in ("remote::groq",): + pytest.skip( + provider.__provider_spec__.provider_type + + " doesn't support tool calling yet" + ) + messages = sample_messages + [ UserMessage( content="What's the weather like in San Francisco?",