Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[#432] Add Groq Provider - tool calls #630

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

aidando73
Copy link
Contributor

@aidando73 aidando73 commented Dec 14, 2024

What does this PR do?

Contributes to issue #432

  • Adds tool calls to Groq provider
  • Enables tool call integration tests

PR Train

Test Plan

Environment:

export GROQ_API_KEY=<api-key>

# build.yaml and run.yaml files
wget https://raw.githubusercontent.com/aidando73/llama-stack/9165502582cd7cb178bc1dcf89955b45768ab6c1/build.yaml
wget https://raw.githubusercontent.com/aidando73/llama-stack/9165502582cd7cb178bc1dcf89955b45768ab6c1/run.yaml

# Create environment if not already
conda create --prefix ./envs python=3.10
conda activate ./envs

# Build and run
pip install -e . \
&& llama stack build --config ./build.yaml --image-type conda \
&& llama stack run ./run.yaml \
  --port 5001
Unit tests: pytest llama_stack/providers/tests/inference/groq/test_groq_utils.py -vv -k groq -s
llama_stack/providers/tests/inference/groq/test_groq_utils.py .....................

======================================== 21 passed, 1 warning in 0.05s ========================================
Integration tests: pytest llama_stack/providers/tests/inference/test_text_inference.py -k groq -s
llama_stack/providers/tests/inference/test_text_inference.py .sss.s.ss.sss.s...

========================== 8 passed, 10 skipped, 180 deselected, 7 warnings in 2.73s ==========================
Manual

Via this Jupyter notebook: https://github.com/aidando73/llama-stack/blob/9165502582cd7cb178bc1dcf89955b45768ab6c1/hello.ipynb

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Ran pre-commit to handle lint / formatting issues.
  • Read the contributor guideline,
    Pull Request section?
  • Updated relevant documentation. (no relevant documentation it seems)
  • Wrote necessary unit or integration tests.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label Dec 14, 2024
README.md Show resolved Hide resolved
@aidando73 aidando73 changed the title [#432] Add tool calls to groq inference adapter [#432] Add Groq Provider - tool calls Dec 15, 2024
except groq.BadRequestError as e:
if e.body.get("error", {}).get("code") == "tool_use_failed":
# For smaller models, Groq may fail to call a tool even when the request is well formed
raise ValueError("Groq failed to call a tool", e.body.get("error", {}))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I couldn't find a better Error class for this. RequestValidationError could work [1] - but that's a fast api error. Not sure if it's a good idea to have fast api errors in adapter code.

So just going to use ValueError for now. Maybe we could use a set of Llama-stack specific error classes? E.g. llama_stack.BadRequestError, llama_stack.ServiceUnavailableError.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@aidando73 Does groq support the raw completions API? If it does that, I'd rather always use that instead of the chat completions API. Llama stack can then format the tool formats exactly as needed and re-parse the tool calls from groq.

Copy link
Contributor Author

@aidando73 aidando73 Dec 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ashwinb no unfortunately not - they only support the chat completions API. These are all their endpoints:

Screen.Recording.2024-12-21.at.19.30.55.mov

Here's their client sdk: https://github.com/groq/groq-python/blob/main/api.md#chat

@ricklamers do you know if there's any plans to support the completions API?

elif choice.delta.tool_calls:
# We assume there is only one tool call per chunk, but emit a warning in case we're wrong
if len(choice.delta.tool_calls) > 1:
warnings.warn("Groq returned multiple tool calls in one chunk. Using the first one, ignoring the rest.")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No explicit documentation by Groq on this, but based on my testing this seems to be the case.

For multiple tool calls, each one is a separate chunk:

Choice(delta=ChoiceDelta(content=None, function_call=None, role=None, tool_calls=[ChoiceDeltaToolCall(index=0, id='call_gjk5', function=ChoiceDeltaToolCallFunction(arguments='{"message": 10}', name='calculate_log'), type='function')]), finish_reason=None, index=0, logprobs=None)
Choice(delta=ChoiceDelta(content=None, function_call=None, role=None, tool_calls=[ChoiceDeltaToolCall(index=1, id='call_qw70', function=ChoiceDeltaToolCallFunction(arguments='{"a": 3.52, "b": 4.89}', name='add'), type='function')]), finish_reason=None, index=0, logprobs=None)
Choice(delta=ChoiceDelta(content=None, function_call=None, role=None, tool_calls=[ChoiceDeltaToolCall(index=2, id='call_f64g', function=ChoiceDeltaToolCallFunction(arguments='{"a": 3.52, "b": 4.89}', name='multiply'), type='function')]), finish_reason=None, index=0, logprobs=None)

if len(choice.delta.tool_calls) > 1:
warnings.warn("Groq returned multiple tool calls in one chunk. Using the first one, ignoring the rest.")

# We assume Groq produces fully formed tool calls for each chunk
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No explicit documentation by Groq on this, but based on my testing this seems to be the case.

E.g., I made it run a tool call with a 10k character document and Grok puts the whole tool call into one chunk:

ChatCompletionResponseStreamChunk(event=ChatCompletionResponseStreamChunkEvent(delta=ChatCompletionResponseStreamChunkEventDeltaToolCallDelta(content=ToolCall(arguments={'text': 'Donec eu lorem eget quam accumsan iaculis. Class aptent taciti sociosqu ad litora torquent per conubia nostra, per inceptos himenaeos. Aenean felis tortor, tincidunt eu purus at, lacinia mollis mi. Aliquam lacinia molestie augue ac vestibulum. Duis ante lacus, vulputate a sollicitudin in, consectetur fermentum augue. Maecenas eget risus a dolor mattis feugiat. Integer accumsan tempor elit vel imperdiet. Donec et dignissim velit. Pellentesque habitant morbi tristique senectus et netus et malesuada fames ac turpis egestas. Mauris dictum, nibh id varius accumsan, massa risus aliquam diam, at pulvinar risus risus sit amet erat. In nec lorem metus. Nunc molestie mollis enim, vitae volutpat elit blandit non.Donec eu lorem eget quam accumsan iaculis. Class aptent taciti sociosqu ad litora torquent per conubia nostra, per inceptos himenaeos. Aenean felis tortor, tincidunt eu purus at, lacinia mollis mi. Aliquam lacinia molestie augue ac vestibulum. Duis ante lacus, vulputate a sollicitudin in, consectetur fermentum augue. Maecenas eget risus a dolor mattis feugiat. Integer accumsan tempor elit vel imperdiet. Donec et dignissim velit. Pellentesque habitant morbi tristique senectus et netus et malesuada fames ac turpis egestas. Mauris dictum, nibh id varius accumsan, massa risus aliquam diam, at pulvinar risus risus sit amet erat. In nec lorem metus. Nunc molestie mollis enim, vitae volutpat elit blandit non.Donec eu lorem eget quam accumsan iaculis. Class aptent taciti sociosqu ad litora torquent per conubia nostra, per inceptos himenaeos. Aenean felis tortor, tincidunt eu purus at, lacinia mollis mi. Aliquam lacinia molestie augue ac vestibulum. Duis ante lacus, vulputate a sollicitudin in, consectetur fermentum augue. Maecenas eget risus a dolor mattis feugiat. Integer accumsan tempor elit vel imperdiet. Donec et dignissim velit. Pellentesque habitant morbi tristique senectus et netus et malesuada fames ac turpis egestas. Mauris dictum, nibh id varius accumsan, massa risus aliquam diam, at pulvinar risus risus sit amet erat. In nec lorem metus. Nunc molestie mollis enim, vitae volutpat elit blandit non.Donec eu lorem eget quam accumsan iaculis. Class aptent taciti sociosqu ad litora torquent per conubia nostra, per inceptos himenaeos. Aenean felis tortor, tincidunt eu purus at, lacinia mollis mi. Aliquam lacinia molestie augue ac vestibulum. Duis ante lacus, vulputate a sollicitudin in, consectetur fermentum augue. Maecenas eget risus a dolor mattis feugiat. Integer accumsan tempor elit vel imperdiet. Donec et dignissim velit. Pellentesque habitant morbi tristique senectus et netus et malesuada fames ac turpis egestas. Mauris dictum, nibh id varius accumsan, massa risus aliquam diam, at pulvinar risus risus sit amet erat. In nec lorem metus. Nunc molestie mollis enim, vitae volutpat elit blandit non.Donec eu lorem eget quam accumsan iaculis. Class aptent taciti sociosqu ad litora torquent per conubia nostra, per inceptos himenaeos. Aenean felis tortor, tincidunt eu purus at, lacinia mollis mi. Aliquam lacinia molestie augue ac vestibulum. Duis ante lacus, vulputate a sollicitudin in, consectetur fermentum augue. Maecenas eget risus a dolor mattis feugiat. Integer accumsan tempor elit vel imperdiet. Donec et dignissim velit. Pellentesque habitant morbi tristique senectus et netus et malesuada fames ac turpis egestas. Mauris dictum, nibh id varius accumsan, massa risus aliquam diam, at pulvinar risus risus sit amet erat. In nec lorem metus. Nunc molestie mollis enim, vitae volutpat elit blandit non.Donec eu lorem eget quam accumsan iaculis. Class aptent taciti sociosqu ad litora torquent per conubia nostra, per inceptos himenaeos. Aenean felis tortor, tincidunt eu purus at, lacinia mollis mi. Aliquam lacinia molestie augue ac vestibulum. Duis ante lacus, vulputate a sollicitudin in, consectetur fermentum augue. Maecenas eget risus a dolor mattis feugiat. Integer accumsan tempor elit vel imperdiet. Donec et dignissim velit. Pellentesque habitant morbi tristique senectus et netus et malesuada fames ac turpis egestas. Mauris dictum, nibh id varius accumsan, massa risus aliquam diam, at pulvinar risus risus sit amet erat. In nec lorem metus. Nunc molestie mollis enim, vitae volutpat elit blandit non.Donec eu lorem eget quam accumsan iaculis. Class aptent taciti sociosqu ad litora torquent per conubia nostra, per inceptos himenaeos. Aenean felis tortor, tincidunt eu purus at, lacinia mollis mi. Aliquam lacinia molestie augue ac vestibulum. Duis ante lacus, vulputate a sollicitudin in, consectetur fermentum augue. Maecenas eget risus a dolor mattis feugiat. Integer accumsan tempor elit vel imperdiet. Donec et dignissim velit. Pellentesque habitant morbi tristique senectus et netus et malesuada fames ac turpis egestas. Mauris dictum, nibh id varius accumsan, massa risus aliquam diam, at pulvinar risus risus sit amet erat. In nec lorem metus. Nunc molestie mollis enim, vitae volutpat elit blandit non.Donec eu lorem eget quam accumsan iaculis. Class aptent taciti sociosqu ad litora torquent per conubia nostra, per inceptos himenaeos. Aenean felis tortor, tincidunt eu purus at, lacinia mollis mi. Aliquam lacinia molestie augue ac vestibulum. Duis ante lacus, vulputate a sollicitudin in, consectetur fermentum augue. Maecenas eget risus a dolor mattis feugiat. Integer accumsan tempor elit vel imperdiet. Donec et dignissim velit. Pellentesque habitant morbi tristique senectus et netus et malesuada fames ac turpis egestas. Mauris dictum, nibh id varius accumsan, massa risus aliquam diam, at pulvinar risus risus sit amet erat. In nec lorem metus. Nunc molestie mollis enim, vitae volutpat elit blandit non.Donec eu lorem eget quam accumsan iaculis. Class aptent taciti sociosqu ad litora torquent per conubia nostra, per inceptos himenaeos. Aenean felis tortor, tincidunt eu purus at, lacinia mollis mi. Aliquam lacinia molestie augue ac vestibulum. Duis ante lacus, vulputate a sollicitudin in, consectetur fermentum augue. Maecenas eget risus a dolor mattis feugiat. Integer accumsan tempor elit vel imperdiet. Donec et dignissim velit. Pellentesque habitant morbi tristique senectus et netus et malesuada fames ac turpis egestas. Mauris dictum, nibh id varius accumsan, massa risus aliquam diam, at pulvinar risus risus sit amet erat. In nec lorem metus. Nunc molestie mollis enim, vitae volutpat elit blandit non.'}, call_id='call_676h', tool_name='count_characters'), parse_status='in_progress'), event_type='progress', logprobs=None, stop_reason=None))

# Note that Groq may return a string that is not valid JSON here
# So this may raise a 500 error. Going to leave this as is to see
# how big of an issue this is and what we can do about it.
arguments=json.loads(tool_call.function.arguments),
Copy link
Contributor Author

@aidando73 aidando73 Dec 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lmk if you want me to handle invalid JSON here. If so, lmk which approach you'd prefer:

  • Return an empty arguments dict
  • Stringify the entire tool call as a string

After that, I'm assuming we'd want to set:

parse_status=ToolCallParseStatus.failure,

and "Llama-3.2" in inference_model
):
# TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better
pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For 3.2-3B, Groq doesn't parse the tool call properly:

CompletionMessage(role='assistant', content='<function=get_weather>{"location": "San Francisco, CA"}', stop_reason=<StopReason.end_of_turn: 'end_of_turn'>, tool_calls=[])

and "Llama-3.2" in inference_model
):
# TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better
pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well")
Copy link
Contributor Author

@aidando73 aidando73 Dec 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For 3.2-3B, Groq returns an error:

ChatCompletionResponseStreamChunk(event=ChatCompletionResponseEvent(event_type=<ChatCompletionResponseEventType.complete: 'complete'>, delta='Tool use failed: JSON does not match the expected schema for tool calls', logprobs=None, stop_reason=<StopReason.end_of_turn: 'end_of_turn'>))

stop_reason=stop_reason,
)
def _convert_groq_tool_call(tool_call: ChatCompletionMessageToolCall) -> ToolCall:
return ToolCall(
Copy link
Contributor Author

@aidando73 aidando73 Dec 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note unrelated to this PR: ToolCall fails for more complex nested types:

E.g.,

# For a param definition:
"passengers": {
    "param_type": "array",
    "description": "The passengers",
},

# Groq wants to return something like:
res = [
  {'name': 'John', 'age': 35, 'class': 'Economy'},
  {'name': 'Jane', 'age': 32, 'class': 'Economy'},
  {'name': 'Tim', 'age': 5, 'class': 'Economy'}
]

# But we run into error:
pydantic_core._pydantic_core.ValidationError: 17 validation errors for ToolCall
arguments.passengers.str
  Input should be a valid string [type=string_type, input_value=[{'name': 'John', 'age': ... 5, 'class': 'Economy'}], input_type=list]
    For further information visit https://errors.pydantic.dev/2.10/v/string_type
arguments.passengers.int
  Input should be a valid integer [type=int_type, input_value=[{'name': 'John', 'age': ... 5, 'class': 'Economy'}], input_type=list]
    For further information visit https://errors.pydantic.dev/2.10/v/int_type
arguments.passengers.float
  Input should be a valid number [type=float_type, input_value=[{'name': 'John', 'age': ... 5, 'class': 'Economy'}], input_type=list]
    For further information visit https://errors.pydantic.dev/2.10/v/float_type
arguments.passengers.bool
  Input should be a valid boolean [type=bool_type, input_value=[{'name': 'John', 'age': ... 5, 'class': 'Economy'}], input_type=list]
    For further information visit https://errors.pydantic.dev/2.10/v/bool_type
arguments.passengers.list[nullable[union[str,int,float,bool]]].0.str
  Input should be a valid string [type=string_type, input_value={'name': 'John', 'age': 35, 'class': 'Economy'}, input_type=dict]
    For further information visit https://errors.pydantic.dev/2.10/v/string_type
arguments.passengers.list[nullable[union[str,int,float,bool]]].0.int
  Input should be a valid integer [type=int_type, input_value={'name': 'John', 'age': 35, 'class': 'Economy'}, input_type=dict]
    For further information visit https://errors.pydantic.dev/2.10/v/int_type
arguments.passengers.list[nullable[union[str,int,float,bool]]].0.float
  Input should be a valid number [type=float_type, input_value={'name': 'John', 'age': 35, 'class': 'Economy'}, input_type=dict]
    For further information visit https://errors.pydantic.dev/2.10/v/float_type
arguments.passengers.list[nullable[union[str,int,float,bool]]].0.bool
  Input should be a valid boolean [type=bool_type, input_value={'name': 'John', 'age': 35, 'class': 'Economy'}, input_type=dict]
    For further information visit https://errors.pydantic.dev/2.10/v/bool_type
...
Code
# Tool call with object and array parameters
response = client.inference.chat_completion(
    model_id="Llama3.1-8B-Instruct",
    messages=[
        {"role": "system", "content": "You are a helpful assistant helping users book flights and hotels."},
        {"role": "user", "content": """
         When's the next flight from Adelaide to Sydney? (I only want direct flights and the flight should have wifi and a meal.)
         The flight should fit 2 adults and 1 child. Economy class.
         Also find a hotel in the Sydney area for 3 nights starting on the 1st of January 2024. (Should be smoking friendly.)
         """},
    ],
    # stream=True,
    tools=[
        {
            "tool_name": "get_flight_info",
            "description": "Get the flight information for a given origin and destination",
            "parameters": {
                "origin": {
                    "param_type": "string",
                    "description": "The origin airport code. E.g., AU",
                    "required": True,
                },
                "destination": {
                    "param_type": "string",
                    "description": "The destination airport code. E.g., 'LAX'",
                    "required": True,
                },
                "passengers": {
                    "param_type": "array",
                    "description": "The passengers",
                },
            }
        },
        {
            "tool_name": "get_hotel_info",
            "description": "Get the hotel information for a given destination",
            "parameters": {
                "address": {
                    "param_type": "object",
                    "description": "The address of the hotel. E.g., {'street_address': '123 Main St', 'city': 'Sydney', 'state': 'NSW', 'post_code': '2000'}",
                    "properties": {
                        "street_address": {
                            "param_type": "string",
                            "description": "The street address of the hotel. E.g., '123 Main St'",
                            "required": True,
                        },
                        "city": {
                            "param_type": "string",
                            "description": "The city of the hotel. E.g., 'Sydney'",
                            "required": True,
                        },
                        "state": {
                            "param_type": "string",
                            "description": "The state of the hotel. E.g., 'NSW'",
                            "required": True,
                        },
                        "post_code": {
                            "param_type": "string",
                            "description": "The post code of the hotel. E.g., '2000'",
                            "required": True,
                        },
                    },
                    "required": True,
                },
                "num_nights": {
                    "param_type": "integer",
                    "description": "The number of nights to stay. E.g., 3",
                    "required": True,
                },
                "date_from": {
                    "param_type": "string",
                    "description": "The date to start the stay formatted as YYYY-MM-DD. E.g., '2024-01-01'",
                    "required": True,
                },
                "smoking_friendly": {
                    "param_type": "boolean",
                    "description": "Whether the hotel is smoking friendly. E.g., True",
                    "required": False,
                },
            }
        },
    ]
)

print(response)

for tool_call in response.completion_message.tool_calls:
    print(tool_call)

Going to ignore for now. Would like to see whether our users actually want this before implementing.

@@ -84,6 +84,7 @@ Additionally, we have designed every element of the Stack such that APIs as well
| Fireworks | Hosted | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | |
Copy link
Contributor Author

@aidando73 aidando73 Dec 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

📣 Reviewers - action required:

Don't review the current diff.

Please use review this diff:
https://github.com/meta-llama/llama-stack/pull/630/files/c0757fd16971cf0c3e5c8e52e511e3a90563bc64..HEAD 👈

Your comments will still appear on this PR

GitHub doesn't support PR trains to forks (the current diff contains previous PRs in the train)

@aidando73
Copy link
Contributor Author

I've rebased the PR train

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Meta Open Source bot.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants