From 77d32c21152551fb6d66b1b4793723deb73efeee Mon Sep 17 00:00:00 2001
From: Aidan Do
Date: Sat, 14 Dec 2024 20:48:38 +1100
Subject: [PATCH] Add tool calls to groq inference adapter

---
 .../providers/remote/inference/groq/groq.py  |  10 +-
 .../remote/inference/groq/groq_utils.py      | 140 +++++++--
 .../tests/inference/groq/test_groq_utils.py  | 281 ++++++++++++++++--
 .../tests/inference/test_text_inference.py   |  24 +-
 4 files changed, 398 insertions(+), 57 deletions(-)

diff --git a/llama_stack/providers/remote/inference/groq/groq.py b/llama_stack/providers/remote/inference/groq/groq.py
index 1a19b4d79e..22369e7e6b 100644
--- a/llama_stack/providers/remote/inference/groq/groq.py
+++ b/llama_stack/providers/remote/inference/groq/groq.py
@@ -7,6 +7,7 @@
 import warnings
 from typing import AsyncIterator, List, Optional, Union

+import groq
 from groq import Groq
 from llama_models.datatypes import SamplingParams
 from llama_models.llama3.api.datatypes import ToolDefinition, ToolPromptFormat
@@ -124,7 +125,14 @@ async def chat_completion(
            )
        )

-        response = self._get_client().chat.completions.create(**request)
+        try:
+            response = self._get_client().chat.completions.create(**request)
+        except groq.BadRequestError as e:
+            if e.body.get("error", {}).get("code") == "tool_use_failed":
+                # For smaller models, Groq may fail to call a tool even when the request is well formed
+                raise ValueError("Groq failed to call a tool", e.body.get("error", {}))
+            else:
+                raise e

        if stream:
            return convert_chat_completion_response_stream(response)

diff --git a/llama_stack/providers/remote/inference/groq/groq_utils.py b/llama_stack/providers/remote/inference/groq/groq_utils.py
index 74c6178a39..032f4c8d45 100644
--- a/llama_stack/providers/remote/inference/groq/groq_utils.py
+++ b/llama_stack/providers/remote/inference/groq/groq_utils.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import json
 import warnings
 from typing import AsyncGenerator, Literal

@@ -14,14 +15,20 @@
 )
 from groq.types.chat.chat_completion_chunk import ChatCompletionChunk
 from groq.types.chat.chat_completion_message_param import ChatCompletionMessageParam
+from groq.types.chat.chat_completion_message_tool_call import (
+    ChatCompletionMessageToolCall,
+)
 from groq.types.chat.chat_completion_system_message_param import (
     ChatCompletionSystemMessageParam,
 )
+from groq.types.chat.chat_completion_tool_param import ChatCompletionToolParam
 from groq.types.chat.chat_completion_user_message_param import (
     ChatCompletionUserMessageParam,
 )
-
 from groq.types.chat.completion_create_params import CompletionCreateParams
+from groq.types.shared.function_definition import FunctionDefinition
+
+from llama_models.llama3.api.datatypes import ToolParamDefinition

 from llama_stack.apis.inference import (
     ChatCompletionRequest,
@@ -32,6 +39,11 @@
     CompletionMessage,
     Message,
     StopReason,
+    ToolCall,
+    ToolCallDelta,
+    ToolCallParseStatus,
+    ToolDefinition,
+    ToolPromptFormat,
 )

@@ -59,8 +71,8 @@ def convert_chat_completion_request(
        # so we exclude it for now
        warnings.warn("repetition_penalty is not supported")

-    if request.tools:
-        warnings.warn("tools are not supported yet")
+    if request.tool_prompt_format != ToolPromptFormat.json:
+        warnings.warn("tool_prompt_format is not used by Groq. Ignoring.")

    return CompletionCreateParams(
        model=request.model,
@@ -71,6 +83,8 @@
        max_tokens=request.sampling_params.max_tokens or None,
        temperature=request.sampling_params.temperature,
        top_p=request.sampling_params.top_p,
+        tools=[_convert_groq_tool_definition(tool) for tool in request.tools or []],
+        tool_choice=request.tool_choice.value if request.tool_choice else None,
    )

@@ -87,17 +101,64 @@ def _convert_message(message: Message) -> ChatCompletionMessageParam:
        raise ValueError(f"Invalid message role: {message.role}")


+def _convert_groq_tool_definition(tool_definition: ToolDefinition) -> dict:
+    # Groq requires a description for function tools
+    if tool_definition.description is None:
+        raise AssertionError("tool_definition.description is required")
+
+    tool_parameters = tool_definition.parameters or {}
+    return ChatCompletionToolParam(
+        type="function",
+        function=FunctionDefinition(
+            name=tool_definition.tool_name,
+            description=tool_definition.description,
+            parameters={
+                key: _convert_groq_tool_parameter(param)
+                for key, param in tool_parameters.items()
+            },
+        ),
+    )
+
+
+def _convert_groq_tool_parameter(tool_parameter: ToolParamDefinition) -> dict:
+    param = {
+        "type": tool_parameter.param_type,
+    }
+    if tool_parameter.description is not None:
+        param["description"] = tool_parameter.description
+    if tool_parameter.required is not None:
+        param["required"] = tool_parameter.required
+    if tool_parameter.default is not None:
+        param["default"] = tool_parameter.default
+    return param
+
+
 def convert_chat_completion_response(
     response: ChatCompletion,
 ) -> ChatCompletionResponse:
    # groq only supports n=1 at time of writing, so there is only one choice
    choice = response.choices[0]
-    return ChatCompletionResponse(
-        completion_message=CompletionMessage(
-            content=choice.message.content,
-            stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason),
-        ),
-    )
+    if choice.finish_reason == "tool_calls":
+        tool_calls = [
+            _convert_groq_tool_call(tool_call)
+            for tool_call in choice.message.tool_calls
+        ]
+        return ChatCompletionResponse(
+            completion_message=CompletionMessage(
+                tool_calls=tool_calls,
+                stop_reason=StopReason.end_of_message,
+                # Content is not optional
+                content="",
+            ),
+            logprobs=None,
+        )
+    else:
+        return ChatCompletionResponse(
+            completion_message=CompletionMessage(
+                content=choice.message.content,
+                stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason),
+            ),
+        )


 def _map_finish_reason_to_stop_reason(
@@ -116,7 +177,7 @@ def _map_finish_reason_to_stop_reason(
    elif finish_reason == "length":
        return StopReason.out_of_tokens
    elif finish_reason == "tool_calls":
-        raise NotImplementedError("tool_calls is not supported yet")
+        return StopReason.end_of_message
    else:
        raise ValueError(f"Invalid finish reason: {finish_reason}")

@@ -129,25 +190,50 @@ async def convert_chat_completion_response_stream(
    for chunk in stream:
        choice = chunk.choices[0]

-        # We assume there's only one finish_reason for the entire stream.
-        # We collect the last finish_reason
        if choice.finish_reason:
-            stop_reason = _map_finish_reason_to_stop_reason(choice.finish_reason)
-
-        yield ChatCompletionResponseStreamChunk(
-            event=ChatCompletionResponseEvent(
-                event_type=event_type,
-                delta=choice.delta.content or "",
-                logprobs=None,
+            yield ChatCompletionResponseStreamChunk(
+                event=ChatCompletionResponseEvent(
+                    event_type=ChatCompletionResponseEventType.complete,
+                    delta=choice.delta.content or "",
+                    logprobs=None,
+                    stop_reason=_map_finish_reason_to_stop_reason(choice.finish_reason),
+                )
+            )
+        elif choice.delta.tool_calls:
+            # We assume there is only one tool call per chunk, but emit a warning in case we're wrong
+            if len(choice.delta.tool_calls) > 1:
+                warnings.warn(
+                    "Groq returned multiple tool calls in one chunk. Using the first one, ignoring the rest."
+                )
+
+            # We assume Groq produces fully formed tool calls for each chunk
+            tool_call = _convert_groq_tool_call(choice.delta.tool_calls[0])
+            yield ChatCompletionResponseStreamChunk(
+                event=ChatCompletionResponseEvent(
+                    event_type=event_type,
+                    delta=ToolCallDelta(
+                        content=tool_call,
+                        parse_status=ToolCallParseStatus.success,
+                    ),
+                )
+            )
+        else:
+            yield ChatCompletionResponseStreamChunk(
+                event=ChatCompletionResponseEvent(
+                    event_type=event_type,
+                    delta=choice.delta.content or "",
+                    logprobs=None,
+                )
            )
-        )
        event_type = ChatCompletionResponseEventType.progress

-    yield ChatCompletionResponseStreamChunk(
-        event=ChatCompletionResponseEvent(
-            event_type=ChatCompletionResponseEventType.complete,
-            delta="",
-            logprobs=None,
-            stop_reason=stop_reason,
-        )
+
+def _convert_groq_tool_call(tool_call: ChatCompletionMessageToolCall) -> ToolCall:
+    return ToolCall(
+        call_id=tool_call.id,
+        tool_name=tool_call.function.name,
+        # Note that Groq may return a string that is not valid JSON here
+        # So this may raise a 500 error. Going to leave this as is to see
+        # how big of an issue this is and what we can do about it.
+        arguments=json.loads(tool_call.function.arguments),
    )

diff --git a/llama_stack/providers/tests/inference/groq/test_groq_utils.py b/llama_stack/providers/tests/inference/groq/test_groq_utils.py
index 53b5c29cb0..f3f263cb1e 100644
--- a/llama_stack/providers/tests/inference/groq/test_groq_utils.py
+++ b/llama_stack/providers/tests/inference/groq/test_groq_utils.py
@@ -4,21 +4,33 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import json
+
 import pytest
 from groq.types.chat.chat_completion import ChatCompletion, Choice
 from groq.types.chat.chat_completion_chunk import (
     ChatCompletionChunk,
     Choice as StreamChoice,
     ChoiceDelta,
+    ChoiceDeltaToolCall,
+    ChoiceDeltaToolCallFunction,
 )
 from groq.types.chat.chat_completion_message import ChatCompletionMessage
-
+from groq.types.chat.chat_completion_message_tool_call import (
+    ChatCompletionMessageToolCall,
+    Function,
+)
+from groq.types.shared.function_definition import FunctionDefinition
+from llama_models.llama3.api.datatypes import ToolParamDefinition
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponseEventType,
     CompletionMessage,
     StopReason,
     SystemMessage,
+    ToolCall,
+    ToolChoice,
+    ToolDefinition,
     UserMessage,
 )
 from llama_stack.providers.remote.inference.groq.groq_utils import (
@@ -140,12 +152,6 @@ def test_includes_max_tokens_if_set(self):

        assert converted["max_tokens"] == 100

-    def _dummy_chat_completion_request(self):
-        return ChatCompletionRequest(
-            model="Llama-3.2-3B",
-            messages=[UserMessage(content="Hello World")],
-        )
-
    def test_includes_temperature(self):
        request = self._dummy_chat_completion_request()
        request.sampling_params.temperature = 0.5
@@ -162,6 +168,112 @@ def test_includes_top_p(self):

        assert converted["top_p"] == 0.95

+    def test_includes_tool_choice(self):
+        request = self._dummy_chat_completion_request()
+        request.tool_choice = ToolChoice.required
+
+        converted = convert_chat_completion_request(request)
+
+        assert converted["tool_choice"] == "required"
+
+    def test_includes_tools(self):
+        request = self._dummy_chat_completion_request()
+        request.tools = [
+            ToolDefinition(
+                tool_name="get_flight_info",
+                description="Get flight information between two destinations.",
+                parameters={
+                    "origin": ToolParamDefinition(
+                        param_type="string",
+                        description="The origin airport code. E.g., AU",
+                        required=True,
+                    ),
+                    "destination": ToolParamDefinition(
+                        param_type="string",
+                        description="The destination airport code. E.g., 'LAX'",
+                        required=True,
+                    ),
+                    "passengers": ToolParamDefinition(
+                        param_type="array",
+                        description="The passengers",
+                        required=False,
+                    ),
+                },
+            ),
+            ToolDefinition(
+                tool_name="log",
+                description="Calculate the logarithm of a number",
+                parameters={
+                    "number": ToolParamDefinition(
+                        param_type="float",
+                        description="The number to calculate the logarithm of",
+                        required=True,
+                    ),
+                    "base": ToolParamDefinition(
+                        param_type="integer",
+                        description="The base of the logarithm",
+                        required=False,
+                        default=10,
+                    ),
+                },
+            ),
+        ]
+
+        converted = convert_chat_completion_request(request)
+
+        assert converted["tools"] == [
+            {
+                "type": "function",
+                "function": FunctionDefinition(
+                    name="get_flight_info",
+                    description="Get flight information between two destinations.",
+                    parameters={
+                        "origin": {
+                            "type": "string",
+                            "description": "The origin airport code. E.g., AU",
+                            "required": True,
+                        },
+                        "destination": {
+                            "type": "string",
+                            "description": "The destination airport code. E.g., 'LAX'",
+                            "required": True,
+                        },
+                        "passengers": {
+                            "type": "array",
+                            "description": "The passengers",
+                            "required": False,
+                        },
+                    },
+                ),
+            },
+            {
+                "type": "function",
+                "function": FunctionDefinition(
+                    name="log",
+                    description="Calculate the logarithm of a number",
+                    parameters={
+                        "number": {
+                            "type": "float",
+                            "description": "The number to calculate the logarithm of",
+                            "required": True,
+                        },
+                        "base": {
+                            "type": "integer",
+                            "description": "The base of the logarithm",
+                            "required": False,
+                            "default": 10,
+                        },
+                    },
+                ),
+            },
+        ]
+
+    def _dummy_chat_completion_request(self):
+        return ChatCompletionRequest(
+            model="Llama-3.2-3B",
+            messages=[UserMessage(content="Hello World")],
+        )
+

 class TestConvertNonStreamChatCompletionResponse:
    def test_returns_response(self):
@@ -188,6 +300,49 @@ def test_maps_length_to_end_of_message(self):

        assert converted.completion_message.stop_reason == StopReason.out_of_tokens

+    def test_maps_tool_call_to_end_of_message(self):
+        response = self._dummy_chat_completion_response_with_tool_call()
+
+        converted = convert_chat_completion_response(response)
+
+        assert converted.completion_message.stop_reason == StopReason.end_of_message
+
+    def test_converts_multiple_tool_calls(self):
+        response = self._dummy_chat_completion_response_with_tool_call()
+        response.choices[0].message.tool_calls = [
+            ChatCompletionMessageToolCall(
+                id="tool_call_id",
+                type="function",
+                function=Function(
+                    name="get_flight_info",
+                    arguments='{"origin": "AU", "destination": "LAX"}',
+                ),
+            ),
+            ChatCompletionMessageToolCall(
+                id="tool_call_id_2",
+                type="function",
+                function=Function(
+                    name="log",
+                    arguments='{"number": 10, "base": 2}',
+                ),
+            ),
+        ]
+
+        converted = convert_chat_completion_response(response)
+
+        assert converted.completion_message.tool_calls == [
+            ToolCall(
+                call_id="tool_call_id",
+                tool_name="get_flight_info",
+                arguments={"origin": "AU", "destination": "LAX"},
+            ),
+            ToolCall(
+                call_id="tool_call_id_2",
+                tool_name="log",
+                arguments={"number": 10, "base": 2},
+            ),
+        ]
+
    def _dummy_chat_completion_response(self):
        return ChatCompletion(
            id="chatcmpl-123",
@@ -205,6 +360,33 @@ def _dummy_chat_completion_response(self):
            object="chat.completion",
        )

+    def _dummy_chat_completion_response_with_tool_call(self):
+        return ChatCompletion(
+            id="chatcmpl-123",
+            model="Llama-3.2-3B",
+            choices=[
+                Choice(
+                    index=0,
+                    message=ChatCompletionMessage(
+                        role="assistant",
+                        tool_calls=[
+                            ChatCompletionMessageToolCall(
+                                id="tool_call_id",
+                                type="function",
+                                function=Function(
+                                    name="get_flight_info",
+                                    arguments='{"origin": "AU", "destination": "LAX"}',
+                                ),
+                            )
+                        ],
+                    ),
+                    finish_reason="tool_calls",
+                )
+            ],
+            created=1729382400,
+            object="chat.completion",
+        )
+

 class TestConvertStreamChatCompletionResponse:
    @pytest.mark.asyncio
    async def test_returns_stream(self):
        def chat_completion_stream():
            for i, message in enumerate(messages):
                chunk = self._dummy_chat_completion_chunk()
                chunk.choices[0].delta.content = message
-                if i == len(messages) - 1:
-                    chunk.choices[0].finish_reason = "stop"
-                else:
-                    chunk.choices[0].finish_reason = None
                yield chunk

            chunk = self._dummy_chat_completion_chunk()
@@ -241,12 +419,6 @@ def chat_completion_stream():
        assert chunk.event.event_type == ChatCompletionResponseEventType.progress
        assert chunk.event.delta == " !"

-        # Dummy chunk to ensure the last chunk is really the end of the stream
-        # This one technically maps to Groq's final "stop" chunk
-        chunk = await iter.__anext__()
-        assert chunk.event.event_type == ChatCompletionResponseEventType.progress
-        assert chunk.event.delta == ""
-
        chunk = await iter.__anext__()
        assert chunk.event.event_type == ChatCompletionResponseEventType.complete
        assert chunk.event.delta == ""
@@ -255,6 +427,53 @@ def chat_completion_stream():
        with pytest.raises(StopAsyncIteration):
            await iter.__anext__()

+    @pytest.mark.asyncio
+    async def test_returns_tool_calls_stream(self):
+        def tool_call_stream():
+            tool_calls = [
+                ToolCall(
+                    call_id="tool_call_id",
+                    tool_name="get_flight_info",
+                    arguments={"origin": "AU", "destination": "LAX"},
+                ),
+                ToolCall(
+                    call_id="tool_call_id_2",
+                    tool_name="log",
+                    arguments={"number": 10, "base": 2},
+                ),
+            ]
+            for i, tool_call in enumerate(tool_calls):
+                chunk = self._dummy_chat_completion_chunk_with_tool_call()
+                chunk.choices[0].delta.tool_calls = [
+                    ChoiceDeltaToolCall(
+                        index=0,
+                        type="function",
+                        id=tool_call.call_id,
+                        function=ChoiceDeltaToolCallFunction(
+                            name=tool_call.tool_name,
+                            arguments=json.dumps(tool_call.arguments),
+                        ),
+                    ),
+                ]
+                yield chunk
+
+            chunk = self._dummy_chat_completion_chunk_with_tool_call()
+            chunk.choices[0].delta.content = None
+            chunk.choices[0].finish_reason = "stop"
+            yield chunk
+
+        stream = tool_call_stream()
+        converted = convert_chat_completion_response_stream(stream)
+
+        iter = converted.__aiter__()
+        chunk = await iter.__anext__()
+        assert chunk.event.event_type == ChatCompletionResponseEventType.start
+        assert chunk.event.delta.content == ToolCall(
+            call_id="tool_call_id",
+            tool_name="get_flight_info",
+            arguments={"origin": "AU", "destination": "LAX"},
+        )
+
    def _dummy_chat_completion_chunk(self):
        return ChatCompletionChunk(
            id="chatcmpl-123",
@@ -269,3 +488,31 @@ def _dummy_chat_completion_chunk(self):
            object="chat.completion.chunk",
            x_groq=None,
        )
+
+    def _dummy_chat_completion_chunk_with_tool_call(self):
+        return ChatCompletionChunk(
+            id="chatcmpl-123",
+            model="Llama-3.2-3B",
+            choices=[
+                StreamChoice(
+                    index=0,
+                    delta=ChoiceDelta(
+                        role="assistant",
+                        content="Hello World",
+                        tool_calls=[
+                            ChoiceDeltaToolCall(
+                                index=0,
+                                type="function",
+                                function=ChoiceDeltaToolCallFunction(
+                                    name="get_flight_info",
+                                    arguments='{"origin": "AU", "destination": "LAX"}',
+                                ),
+                            )
+                        ],
+                    ),
+                )
+            ],
+            created=1729382400,
+            object="chat.completion.chunk",
+            x_groq=None,
+        )

diff --git a/llama_stack/providers/tests/inference/test_text_inference.py b/llama_stack/providers/tests/inference/test_text_inference.py
index 02851830b8..414d0261b5 100644
--- a/llama_stack/providers/tests/inference/test_text_inference.py
+++ b/llama_stack/providers/tests/inference/test_text_inference.py
@@ -352,13 +352,13 @@ async def test_chat_completion_with_tool_calling(
 ):
    inference_impl, _ = inference_stack
    provider = inference_impl.routing_table.get_provider_impl(inference_model)
-    if provider.__provider_spec__.provider_type in ("remote::groq",):
-        pytest.skip(
-            provider.__provider_spec__.provider_type
-            + " doesn't support tool calling yet"
-        )
+    if (
+        provider.__provider_spec__.provider_type == "remote::groq"
+        and "Llama-3.2" in inference_model
+    ):
+        # TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better
+        pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well")

-    inference_impl, _ = inference_stack
    messages = sample_messages + [
        UserMessage(
            content="What's the weather like in San Francisco?",
@@ -399,11 +399,12 @@ async def test_chat_completion_with_tool_calling_streaming(
 ):
    inference_impl, _ = inference_stack
    provider = inference_impl.routing_table.get_provider_impl(inference_model)
-    if provider.__provider_spec__.provider_type in ("remote::groq",):
-        pytest.skip(
-            provider.__provider_spec__.provider_type
-            + " doesn't support tool calling yet"
-        )
+    if (
+        provider.__provider_spec__.provider_type == "remote::groq"
+        and "Llama-3.2" in inference_model
+    ):
+        # TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better
+        pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well")

    messages = sample_messages + [
        UserMessage(
            content="What's the weather like in San Francisco?",
@@ -421,7 +422,6 @@ async def test_chat_completion_with_tool_calling_streaming(
            **common_params,
        )
    ]
-    assert len(response) > 0
    assert all(
        isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response