Skip to content

Add tool_choice to ModelSettings #825

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 40 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
3777d06
Update openai.py
webcoderz Jan 31, 2025
6746781
Update pyproject.toml
webcoderz Jan 31, 2025
7d2c2ef
adding to model settings, removing monkeypatch
webcoderz Jan 31, 2025
2567bbc
Update openai.py
webcoderz Jan 31, 2025
4bc72f9
Update settings.py
webcoderz Feb 4, 2025
ac10ee1
backing this out
webcoderz Feb 6, 2025
ce38756
Update pyproject.toml
webcoderz Feb 6, 2025
fe341b1
Merge branch 'main' into webcoderz-model-settings
webcoderz Feb 7, 2025
12015c9
Update groq.py
webcoderz Feb 9, 2025
6dd9987
removing fallback comment
webcoderz Feb 12, 2025
3250aff
adding as per recommendation
webcoderz Feb 12, 2025
a0b7454
removing tool_choice from ModelSettings and placing in each individu…
webcoderz Feb 12, 2025
79acaf3
the conditional checking tool_choice was not evaluating when i added …
webcoderz Feb 12, 2025
67e6ac3
adding _get_tool_choice to groq,cohere, openai
webcoderz Feb 12, 2025
9aa905c
unsure if these are necessary since seem supported already in mistral…
webcoderz Feb 12, 2025
9349412
fixing tool_choice across all models
webcoderz Feb 21, 2025
1bd0cf3
Merge branch 'pydantic:main' into webcoderz-model-settings
webcoderz Feb 22, 2025
a03fcfe
moving to top level settings
webcoderz Feb 24, 2025
89946a9
Merge branch 'pydantic:main' into webcoderz-model-settings
webcoderz Feb 24, 2025
20d8c8c
Merge branch 'webcoderz-model-settings'
webcoderz Feb 24, 2025
396b89c
Merge branch 'pydantic:main' into webcoderz-model-settings
webcoderz Feb 25, 2025
61b360b
Merge branch 'webcoderz-model-settings'
webcoderz Feb 25, 2025
e6df5fb
Merge branch 'pydantic:main' into webcoderz-model-settings
webcoderz Feb 28, 2025
1e4b0c5
Merge branch 'pydantic:main' into webcoderz-model-settings
webcoderz Mar 5, 2025
1419b5a
Merge branch 'webcoderz-model-settings'
webcoderz Mar 5, 2025
8b6f102
fixing ChatCompletionNamedToolChoiceParam
webcoderz Mar 5, 2025
1af96f3
Merge branch 'webcoderz-model-settings'
webcoderz Mar 5, 2025
0f47da9
Update cohere.py
webcoderz Mar 5, 2025
a23c014
Update openai.py
webcoderz Mar 5, 2025
9cefe9e
Update openai.py
webcoderz Mar 5, 2025
445c7ba
Update openai.py
webcoderz Mar 5, 2025
6ac1857
Merge remote-tracking branch 'origin/main' into webcoderz/main
Kludex Mar 7, 2025
2e238a2
Refactor
Kludex Mar 7, 2025
a86f5ce
Add Anthropic
Kludex Mar 7, 2025
9159aeb
full implementation
Kludex Mar 7, 2025
5cb233b
merge
Kludex Mar 31, 2025
059f92e
Merge remote-tracking branch 'origin/main' into webcoderz/main
Kludex Apr 7, 2025
30560b5
Merge remote-tracking branch 'origin/main' into webcoderz/main
Kludex Apr 15, 2025
ae1ab1c
Make GeminiModelSettings total=False
Kludex Apr 15, 2025
32cf6a6
Check safety settings on gemini properly
Kludex Apr 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 28 additions & 14 deletions pydantic_ai_slim/pydantic_ai/models/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
UserPromptPart,
)
from ..providers import Provider, infer_provider
from ..settings import ModelSettings
from ..settings import ForcedFunctionToolChoice, ModelSettings
from ..tools import ToolDefinition
from . import (
Model,
Expand Down Expand Up @@ -209,19 +209,7 @@ async def _messages_create(
) -> AnthropicMessage | AsyncStream[RawMessageStreamEvent]:
# standalone function to make it easier to override
tools = self._get_tools(model_request_parameters)
tool_choice: ToolChoiceParam | None

if not tools:
tool_choice = None
else:
if not model_request_parameters.allow_text_output:
tool_choice = {'type': 'any'}
else:
tool_choice = {'type': 'auto'}

if (allow_parallel_tool_calls := model_settings.get('parallel_tool_calls')) is not None:
tool_choice['disable_parallel_tool_use'] = not allow_parallel_tool_calls

tool_choice = self._map_tool_choice(model_settings, model_request_parameters, tools)
system_prompt, anthropic_messages = await self._map_message(messages)

try:
Expand Down Expand Up @@ -281,6 +269,32 @@ def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[T
tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
return tools

@staticmethod
def _map_tool_choice(
model_settings: AnthropicModelSettings,
model_request_parameters: ModelRequestParameters,
tools: list[ToolParam],
) -> ToolChoiceParam | None:
"""Determine the `tool_choice` setting for the model.

Anthropic only supports `'auto'`, `'any'`, `'none'`, and a named tool.
"""
tool_choice = model_settings.get('tool_choice', 'auto')
disable_parallel_tool_use = not model_settings.get('parallel_tool_calls', True)

if tool_choice == 'auto' and tools and not model_request_parameters.allow_text_output:
return {'type': 'any', 'disable_parallel_tool_use': disable_parallel_tool_use}
elif tool_choice == 'required':
return {'type': 'any', 'disable_parallel_tool_use': disable_parallel_tool_use}
elif tool_choice == 'auto':
return {'type': 'auto', 'disable_parallel_tool_use': disable_parallel_tool_use}
elif tool_choice == 'none':
return {'type': 'none'}
elif isinstance(tool_choice, ForcedFunctionToolChoice):
return {'type': 'tool', 'name': tool_choice.name, 'disable_parallel_tool_use': disable_parallel_tool_use}
else:
assert_never(tool_choice)

async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]:
"""Just maps a `pydantic_ai.Message` to a `anthropic.types.MessageParam`."""
system_prompt: str = ''
Expand Down
55 changes: 34 additions & 21 deletions pydantic_ai_slim/pydantic_ai/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
)
from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse, cached_async_http_client
from pydantic_ai.providers import Provider, infer_provider
from pydantic_ai.settings import ModelSettings
from pydantic_ai.settings import ForcedFunctionToolChoice, ModelSettings
from pydantic_ai.tools import ToolDefinition

if TYPE_CHECKING:
Expand All @@ -54,7 +54,7 @@
PerformanceConfigurationTypeDef,
PromptVariableValuesTypeDef,
SystemContentBlockTypeDef,
ToolChoiceTypeDef,
ToolConfigurationTypeDef,
ToolTypeDef,
VideoBlockTypeDef,
)
Expand Down Expand Up @@ -275,36 +275,28 @@ async def _messages_create(
self,
messages: list[ModelMessage],
stream: Literal[True],
model_settings: BedrockModelSettings | None,
model_settings: BedrockModelSettings,
model_request_parameters: ModelRequestParameters,
) -> EventStream[ConverseStreamOutputTypeDef]:
pass
) -> EventStream[ConverseStreamOutputTypeDef]: ...

@overload
async def _messages_create(
self,
messages: list[ModelMessage],
stream: Literal[False],
model_settings: BedrockModelSettings | None,
model_settings: BedrockModelSettings,
model_request_parameters: ModelRequestParameters,
) -> ConverseResponseTypeDef:
pass
) -> ConverseResponseTypeDef: ...

async def _messages_create(
self,
messages: list[ModelMessage],
stream: bool,
model_settings: BedrockModelSettings | None,
model_settings: BedrockModelSettings,
model_request_parameters: ModelRequestParameters,
) -> ConverseResponseTypeDef | EventStream[ConverseStreamOutputTypeDef]:
tools = self._get_tools(model_request_parameters)
support_tools_choice = self.model_name.startswith(('anthropic', 'us.anthropic'))
if not tools or not support_tools_choice:
tool_choice: ToolChoiceTypeDef = {}
elif not model_request_parameters.allow_text_output:
tool_choice = {'any': {}}
else:
tool_choice = {'auto': {}}
tool_config = self._get_tool_config(model_settings, model_request_parameters, tools)

system_prompt, bedrock_messages = await self._map_messages(messages)
inference_config = self._map_inference_config(model_settings)
Expand All @@ -315,6 +307,8 @@ async def _messages_create(
'system': system_prompt,
'inferenceConfig': inference_config,
}
if tool_config:
params['toolConfig'] = tool_config

# Bedrock supports a set of specific extra parameters
if model_settings:
Expand All @@ -333,18 +327,37 @@ async def _messages_create(
if prompt_variables := model_settings.get('bedrock_prompt_variables', None):
params['promptVariables'] = prompt_variables

if tools:
params['toolConfig'] = {'tools': tools}
if tool_choice:
params['toolConfig']['toolChoice'] = tool_choice

if stream:
model_response = await anyio.to_thread.run_sync(functools.partial(self.client.converse_stream, **params))
model_response = model_response['stream']
else:
model_response = await anyio.to_thread.run_sync(functools.partial(self.client.converse, **params))
return model_response

def _get_tool_config(
self,
model_settings: BedrockModelSettings,
model_request_parameters: ModelRequestParameters,
tools: list[ToolTypeDef],
) -> ToolConfigurationTypeDef | None:
support_tools_choice = self.model_name.startswith(('anthropic', 'us.anthropic'))
tool_choice = model_settings.get('tool_choice', 'auto')

if not tools or not support_tools_choice:
return None
elif tool_choice == 'auto' and not model_request_parameters.allow_text_output:
return {'tools': tools, 'toolChoice': {'any': {}}}
elif tool_choice == 'auto':
return {'tools': tools, 'toolChoice': {'auto': {}}}
elif tool_choice == 'none':
return None
elif tool_choice == 'required':
return {'tools': tools, 'toolChoice': {'any': {}}}
elif isinstance(tool_choice, ForcedFunctionToolChoice):
return {'tools': tools, 'toolChoice': {'tool': {'name': tool_choice.name}}}
else:
assert_never(tool_choice)

@staticmethod
def _map_inference_config(
model_settings: ModelSettings | None,
Expand Down
38 changes: 32 additions & 6 deletions pydantic_ai_slim/pydantic_ai/models/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

from typing_extensions import assert_never

from pydantic_ai.exceptions import UserError

from .. import ModelHTTPError, usage
from .._utils import generate_tool_call_id as _generate_tool_call_id, guard_tool_call_id as _guard_tool_call_id
from ..messages import (
Expand All @@ -22,13 +24,9 @@
UserPromptPart,
)
from ..providers import Provider, infer_provider
from ..settings import ModelSettings
from ..settings import ForcedFunctionToolChoice, ModelSettings
from ..tools import ToolDefinition
from . import (
Model,
ModelRequestParameters,
check_allow_model_requests,
)
from . import Model, ModelRequestParameters, check_allow_model_requests

try:
from cohere import (
Expand All @@ -44,6 +42,7 @@
ToolV2,
ToolV2Function,
UserChatMessageV2,
V2ChatRequestToolChoice,
)
from cohere.core.api_error import ApiError
from cohere.v2.client import OMIT
Expand Down Expand Up @@ -156,12 +155,14 @@ async def _chat(
model_request_parameters: ModelRequestParameters,
) -> ChatResponse:
tools = self._get_tools(model_request_parameters)
tool_choice = self._map_tool_choice(model_settings, model_request_parameters, tools)
cohere_messages = list(chain(*(self._map_message(m) for m in messages)))
try:
return await self.client.chat(
model=self._model_name,
messages=cohere_messages,
tools=tools or OMIT,
tool_choice=tool_choice or OMIT,
max_tokens=model_settings.get('max_tokens', OMIT),
stop_sequences=model_settings.get('stop_sequences', OMIT),
temperature=model_settings.get('temperature', OMIT),
Expand Down Expand Up @@ -223,6 +224,31 @@ def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[T
tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
return tools

@staticmethod
def _map_tool_choice(
model_settings: CohereModelSettings, model_request_parameters: ModelRequestParameters, tools: list[ToolV2]
) -> V2ChatRequestToolChoice | None:
"""Determine the `tool_choice` setting for the model.

Cohere only supports `'REQUIRED'` and `'NONE'` for tool choice.
See [Cohere's docs](https://docs.cohere.com/v2/docs/tool-use-usage-patterns#forcing-tool-usage) for more details.
"""
tool_choice = model_settings.get('tool_choice', 'auto')

if tool_choice == 'auto' and tools and not model_request_parameters.allow_text_output:
return 'REQUIRED'
elif tool_choice == 'auto':
return None
elif isinstance(tool_choice, ForcedFunctionToolChoice):
raise UserError(
'Cohere does not support forcing a specific tool. '
'Please choose a different value for the `tool_choice` parameter in the model settings.'
)
elif tool_choice in ('none', 'required'):
return tool_choice.upper()
else:
assert_never(tool_choice)

@staticmethod
def _map_tool_call(t: ToolCallPart) -> ToolCallV2:
return ToolCallV2(
Expand Down
55 changes: 39 additions & 16 deletions pydantic_ai_slim/pydantic_ai/models/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
UserPromptPart,
VideoUrl,
)
from ..settings import ModelSettings
from ..settings import ForcedFunctionToolChoice, ModelSettings
from ..tools import ToolDefinition
from . import (
Model,
Expand Down Expand Up @@ -72,7 +72,7 @@
"""


class GeminiModelSettings(ModelSettings):
class GeminiModelSettings(ModelSettings, total=False):
"""Settings used for a Gemini model request.

ALL FIELDS MUST BE `gemini_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
Expand Down Expand Up @@ -180,15 +180,35 @@ def _get_tools(self, model_request_parameters: ModelRequestParameters) -> _Gemin
tools += [_function_from_abstract_tool(t) for t in model_request_parameters.output_tools]
return _GeminiTools(function_declarations=tools) if tools else None

@staticmethod
def _get_tool_config(
self, model_request_parameters: ModelRequestParameters, tools: _GeminiTools | None
model_settings: GeminiModelSettings,
model_request_parameters: ModelRequestParameters,
tools: _GeminiTools | None,
) -> _GeminiToolConfig | None:
if model_request_parameters.allow_text_output:
"""Determine the `tool_choice` setting for the model.

AUTO: The default model behavior. The model decides to predict either a function call or a natural language response.
ANY: The model is constrained to always predict a function call. If allowed_function_names is not provided,
the model picks from all of the available function declarations. If allowed_function_names is provided,
the model picks from the set of allowed functions.
NONE: The model won't predict a function call. In this case, the model behavior is the same as if you don't
pass any function declarations.
"""
tool_choice = model_settings.get('tool_choice', 'auto')

if tool_choice == 'auto' and tools and not model_request_parameters.allow_text_output:
return {'function_calling_config': {'mode': 'ANY'}}
elif tool_choice == 'auto':
return None
elif tools:
return _tool_config([t['name'] for t in tools['function_declarations']])
elif tool_choice == 'none':
return {'function_calling_config': {'mode': 'NONE'}}
elif tool_choice == 'required':
return {'function_calling_config': {'mode': 'ANY'}}
elif isinstance(tool_choice, ForcedFunctionToolChoice):
return {'function_calling_config': {'mode': 'ANY', 'allowed_function_names': [tool_choice.name]}}
else:
return _tool_config([])
assert_never(tool_choice)

@asynccontextmanager
async def _make_request(
Expand All @@ -199,7 +219,7 @@ async def _make_request(
model_request_parameters: ModelRequestParameters,
) -> AsyncIterator[HTTPResponse]:
tools = self._get_tools(model_request_parameters)
tool_config = self._get_tool_config(model_request_parameters, tools)
tool_config = self._get_tool_config(model_settings, model_request_parameters, tools)
sys_prompt_parts, contents = await self._message_to_gemini_content(messages)

request_data = _GeminiRequest(contents=contents)
Expand All @@ -222,7 +242,7 @@ async def _make_request(
generation_config['presence_penalty'] = presence_penalty
if (frequency_penalty := model_settings.get('frequency_penalty')) is not None:
generation_config['frequency_penalty'] = frequency_penalty
if (gemini_safety_settings := model_settings.get('gemini_safety_settings')) != []:
if gemini_safety_settings := model_settings.get('gemini_safety_settings'):
request_data['safety_settings'] = gemini_safety_settings
if generation_config:
request_data['generation_config'] = generation_config
Expand Down Expand Up @@ -666,15 +686,18 @@ class _GeminiToolConfig(TypedDict):
function_calling_config: _GeminiFunctionCallingConfig


def _tool_config(function_names: list[str]) -> _GeminiToolConfig:
return _GeminiToolConfig(
function_calling_config=_GeminiFunctionCallingConfig(mode='ANY', allowed_function_names=function_names)
)
class _GeminiFunctionCallingConfig(TypedDict):
"""The function calling config for the Gemini API.

See <https://ai.google.dev/gemini-api/docs/function-calling>
"""

class _GeminiFunctionCallingConfig(TypedDict):
mode: Literal['ANY', 'AUTO']
allowed_function_names: list[str]
mode: Literal['ANY', 'AUTO', 'NONE']
allowed_function_names: NotRequired[list[str]]
"""If not provided, all functions are allowed.

It can only be used with `mode` set to `'ANY'`.
"""


@pydantic.with_config(pydantic.ConfigDict(defer_build=True))
Expand Down
Loading
Loading