Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding tool_choice to ModelSettings #825

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
1 change: 1 addition & 0 deletions pydantic_ai_slim/pydantic_ai/models/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from json import JSONDecodeError, loads as json_loads
from typing import Any, Literal, Union, cast, overload


from httpx import AsyncClient as AsyncHTTPClient
from typing_extensions import assert_never

Expand Down
31 changes: 30 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from collections.abc import Iterable
from dataclasses import dataclass, field
from itertools import chain
from typing import Literal, Union, cast
from typing import Literal, Union, Any, cast

from cohere import TextAssistantMessageContentItem
from httpx import AsyncClient as AsyncHTTPClient
Expand Down Expand Up @@ -71,10 +71,15 @@

CohereModelName = Union[NamedCohereModels, str]

# NOTE(review): Union[Literal[...], Any] collapses to Any for type checkers,
# so this alias does not actually constrain values. Presumably the `Any` arm
# stands in for cohere's own ToolChoice type -- confirm and narrow if possible.
V2ChatRequestToolChoice = Union[Literal["REQUIRED", "NONE"], Any]

class CohereModelSettings(ModelSettings):
    """Settings used for a Cohere model request."""

    tool_choice: V2ChatRequestToolChoice
    """Whether to require a tool to be used: 'REQUIRED' forces the model to
    use at least one tool, 'NONE' forbids tool use; unset lets the model choose."""


# Additional cohere-specific settings can be added here in the future.


Expand Down Expand Up @@ -166,6 +171,29 @@ async def request(
response = await self._chat(messages, cast(CohereModelSettings, model_settings or {}))
return self._process_response(response), _map_usage(response)

def _get_tool_choice(self, model_settings: CohereModelSettings) -> V2ChatRequestToolChoice | None:
"""Determine the tool_choice setting for the model.

Allowed values in model_settings:
- 'REQUIRED': The model must use at least one tool.
- 'NONE': The model is forced not to use a tool.
If not provided, the model is free to choose:
- If no tools are available, leave unspecified.
- If text responses are disallowed, force tool usage ('REQUIRED').
- If text responses are allowed, leave unspecified (free to choose).
"""
tool_choice: V2ChatRequestToolChoice | None = getattr(model_settings, 'tool_choice', None)

if tool_choice is None:
if not self.tools:
tool_choice = None
elif not self.allow_text_result:
tool_choice = 'REQUIRED'
else:
tool_choice = None

return tool_choice

async def _chat(
self,
messages: list[ModelMessage],
Expand All @@ -176,6 +204,7 @@ async def _chat(
model=self.model_name,
messages=cohere_messages,
tools=self.tools or OMIT,
tool_choice=self._get_tool_choice(model_settings) or OMIT,
max_tokens=model_settings.get('max_tokens', OMIT),
temperature=model_settings.get('temperature', OMIT),
p=model_settings.get('top_p', OMIT),
Expand Down
3 changes: 2 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,11 @@
See [the Gemini API docs](https://ai.google.dev/gemini-api/docs/models/gemini#model-variations) for a full list.
"""

# Allowed function-calling modes for a Gemini request.
FunctionCallConfigMode = Literal["ANY", "NONE", "AUTO"]

class GeminiModelSettings(ModelSettings):
    """Settings used for a Gemini model request."""

    # Presumably maps onto Gemini's FunctionCallingConfig mode
    # ("ANY" / "NONE" / "AUTO") -- confirm against the Gemini API docs.
    # NOTE(review): unlike the other providers in this change, no
    # _get_tool_choice consumer for Gemini is visible here -- verify this
    # setting is actually wired into the request-building code.
    tool_choice: FunctionCallConfigMode
    # Additional gemini-specific settings can be added here in the future.


Expand Down
40 changes: 31 additions & 9 deletions pydantic_ai_slim/pydantic_ai/models/groq.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
from dataclasses import dataclass, field
from datetime import datetime, timezone
from itertools import chain
from typing import Literal, cast, overload
from typing import Literal, Dict, Any, cast, overload

from httpx import AsyncClient as AsyncHTTPClient
from typing_extensions import assert_never
from typing_extensions import TypedDict, assert_never

from .. import UnexpectedModelBehavior, _utils, usage
from .._utils import guard_tool_call_id as _guard_tool_call_id
Expand Down Expand Up @@ -63,10 +63,18 @@
See [the Groq docs](https://console.groq.com/docs/models) for a full list.
"""

class ChatCompletionNamedToolChoiceParam(TypedDict):
    """Parameters intended to force the model to call one specific named tool.

    NOTE(review): this type is not referenced by ``GroqModelSettings.tool_choice``
    below (which only accepts 'none'/'required'/'auto'), and its shape does not
    match the OpenAI-compatible named tool_choice payload
    ``{"type": "function", "function": {"name": ...}}`` -- confirm whether it
    is needed at all.
    """

    type: Literal["named"]
    name: str
    parameters: Dict[str, Any]

class GroqModelSettings(ModelSettings):
    """Settings used for a Groq model request."""

    tool_choice: Literal['none', 'required', 'auto']
    """Whether to require a tool to be used: 'none' prevents tool use,
    'required' forces tool use, 'auto' lets the model decide."""


# Additional groq-specific settings can be added here in the future.


Expand Down Expand Up @@ -180,16 +188,30 @@ async def _completions_create(
) -> chat.ChatCompletion:
pass

def _get_tool_choice(self, model_settings: GroqModelSettings) -> Literal['none', 'required', 'auto'] | None:
"""Get tool choice for the model.

- "auto": Default mode. Model decides if it uses the tool or not.
- "none": Prevents tool use.
- "required": Forces tool use.
"""
tool_choice: Literal['none', 'required', 'auto'] | None = getattr(model_settings, 'tool_choice', None)

if tool_choice is None:
if not self.tools:
tool_choice = None
elif not self.allow_text_result:
tool_choice = 'required'
else:
tool_choice = 'auto'

return tool_choice

async def _completions_create(
self, messages: list[ModelMessage], stream: bool, model_settings: GroqModelSettings
) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
# standalone function to make it easier to override
if not self.tools:
tool_choice: Literal['none', 'required', 'auto'] | None = None
elif not self.allow_text_result:
tool_choice = 'required'
else:
tool_choice = 'auto'


groq_messages = list(chain(*(self._map_message(m) for m in messages)))

Expand All @@ -199,7 +221,7 @@ async def _completions_create(
n=1,
parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
tools=self.tools or NOT_GIVEN,
tool_choice=tool_choice or NOT_GIVEN,
tool_choice=self._get_tool_choice(model_settings) or NOT_GIVEN,
stream=stream,
max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
temperature=model_settings.get('temperature', NOT_GIVEN),
Expand Down
47 changes: 35 additions & 12 deletions pydantic_ai_slim/pydantic_ai/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@
from dataclasses import dataclass, field
from datetime import datetime, timezone
from itertools import chain
from typing import Literal, Union, cast, overload

from typing import Literal, Union, cast, overload, Any, Dict
from httpx import AsyncClient as AsyncHTTPClient
from typing_extensions import assert_never
from typing_extensions import TypedDict, assert_never

from .. import UnexpectedModelBehavior, _utils, usage
from .._utils import guard_tool_call_id as _guard_tool_call_id
Expand Down Expand Up @@ -53,10 +52,21 @@

OpenAISystemPromptRole = Literal['system', 'developer', 'user']

class _NamedFunction(TypedDict):
    """Identifies the specific function the model is forced to call."""

    name: str


class ChatCompletionNamedToolChoiceParam(TypedDict):
    """tool_choice payload that forces the model to call one named function.

    Matches the OpenAI chat-completions API shape:
    ``{"type": "function", "function": {"name": "my_tool"}}``.
    (The previous ``{"type": "named", "name": ..., "parameters": ...}`` shape
    is not accepted by the OpenAI API.)
    """

    type: Literal["function"]
    function: _NamedFunction


class OpenAIModelSettings(ModelSettings):
    """Settings used for an OpenAI model request."""

    tool_choice: Union[
        Literal["none", "auto", "required"],
        ChatCompletionNamedToolChoiceParam
    ]
    """Whether to require a tool to be used: 'none' prevents tool use,
    'required' forces tool use, 'auto' lets the model decide; a
    ChatCompletionNamedToolChoiceParam forces one specific named tool."""

    # Additional openai-specific settings can be added here in the future.


Expand Down Expand Up @@ -182,17 +192,30 @@ async def _completions_create(
) -> chat.ChatCompletion:
pass

def _get_tool_choice(self, model_settings: OpenAIModelSettings) -> Literal['none', 'required', 'auto'] | None:
"""Get tool choice for the model.

- "auto": Default mode. Model decides if it uses the tool or not.
- "none": Prevents tool use.
- "required": Forces tool use.
"""
tool_choice: Literal['none', 'required', 'auto'] | None = getattr(model_settings, 'tool_choice', None)

if tool_choice is None:
if not self.tools:
tool_choice = None
elif not self.allow_text_result:
tool_choice = 'required'
else:
tool_choice = 'auto'

return tool_choice


async def _completions_create(
self, messages: list[ModelMessage], stream: bool, model_settings: OpenAIModelSettings
) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
# standalone function to make it easier to override
if not self.tools:
tool_choice: Literal['none', 'required', 'auto'] | None = None
elif not self.allow_text_result:
tool_choice = 'required'
else:
tool_choice = 'auto'


openai_messages = list(chain(*(self._map_message(m) for m in messages)))

return await self.client.chat.completions.create(
Expand All @@ -201,7 +224,7 @@ async def _completions_create(
n=1,
parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
tools=self.tools or NOT_GIVEN,
tool_choice=tool_choice or NOT_GIVEN,
tool_choice=self._get_tool_choice(model_settings) or NOT_GIVEN,
stream=stream,
stream_options={'include_usage': True} if stream else NOT_GIVEN,
max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
Expand Down
6 changes: 6 additions & 0 deletions pydantic_ai_slim/pydantic_ai/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,12 @@ class ModelSettings(TypedDict, total=False):
"""








def merge_model_settings(base: ModelSettings | None, overrides: ModelSettings | None) -> ModelSettings | None:
"""Merge two sets of model settings, preferring the overrides.

Expand Down