FallbackModel support #894

Open · wants to merge 28 commits into base: main

Changes from 9 commits

Commits (28):
- d66b3a2 fallback proof of concept (sydney-runkle, Feb 11, 2025)
- 9956c96 remove model name updates (sydney-runkle, Feb 11, 2025)
- a821380 catching 4xx and 5xx business (sydney-runkle, Feb 11, 2025)
- a4b4ebb fixing test (sydney-runkle, Feb 11, 2025)
- b4a6b1a move groq import (sydney-runkle, Feb 11, 2025)
- 4e2b089 return after yield (sydney-runkle, Feb 11, 2025)
- 9e7f53b intro docs (sydney-runkle, Feb 11, 2025)
- 737cf2e non streaming testing (sydney-runkle, Feb 11, 2025)
- 3113ae5 initial tests (sydney-runkle, Feb 11, 2025)
- bf477e2 Comprehension golfing, thanks @alexmojaki (sydney-runkle, Feb 12, 2025)
- 47b4901 fix linting issue (sydney-runkle, Feb 13, 2025)
- 3b95211 openai test with exceptions (sydney-runkle, Feb 13, 2025)
- 992d540 Merge branch 'main' into fallback-model-updated (sydney-runkle, Feb 13, 2025)
- 25afb86 adding model_name and system abstract methods (sydney-runkle, Feb 13, 2025)
- 7a23eb1 Merge branch 'main' into fallback-model-updated (sydney-runkle, Feb 13, 2025)
- 31150ff using sequence and fixing 3.9 tests (sydney-runkle, Feb 13, 2025)
- b83bc5e using list with type hinting (sydney-runkle, Feb 13, 2025)
- 154d907 tests (sydney-runkle, Feb 13, 2025)
- 0f49228 streaming tests (sydney-runkle, Feb 13, 2025)
- 7cb8267 Get tests passing (dmontagu, Feb 13, 2025)
- ff7b596 Minor cleanup (dmontagu, Feb 13, 2025)
- 68b0f0b type alias testing cleanup (sydney-runkle, Feb 14, 2025)
- d614df2 exception group like (sydney-runkle, Feb 14, 2025)
- 9ac1f8e adding fallback failure example (sydney-runkle, Feb 14, 2025)
- 82b6580 fix f string issue (sydney-runkle, Feb 14, 2025)
- 538ac12 try moving type alias definitions into protected import block (sydney-runkle, Feb 14, 2025)
- ba5f778 docs updates + fixing 3.9 tests (sydney-runkle, Feb 14, 2025)
- b11d971 Merge branch 'main' into fallback-model-updated (dmontagu, Feb 18, 2025)
3 changes: 3 additions & 0 deletions docs/api/models/fallback.md
@@ -0,0 +1,3 @@
# pydantic_ai.models.fallback

::: pydantic_ai.models.fallback
49 changes: 49 additions & 0 deletions docs/models.md
@@ -635,3 +635,52 @@ For streaming, you'll also need to implement the following abstract base class:
The best place to start is to review the source code for existing implementations, e.g. [`OpenAIModel`](https://github.com/pydantic/pydantic-ai/blob/main/pydantic_ai_slim/pydantic_ai/models/openai.py).

For details on when we'll accept contributions adding new models to PydanticAI, see the [contributing guidelines](contributing.md#new-model-rules).


## Fallback Models

You can use [`FallbackModel`][pydantic_ai.models.fallback.FallbackModel] to try a list of models in sequence until one returns a successful result. Under the hood, PydanticAI automatically switches from one model to the next if a 4xx or 5xx status code is returned by the current model.

In the following example, the agent first makes a request to the OpenAI model (which fails with an invalid API key), then falls back to the Anthropic model. The `ModelResponse` message indicates that the result was returned by the Anthropic model, which is the second model provided to the `FallbackModel`.

```python {title="fallback_model.py"}
from pydantic_ai import Agent
from pydantic_ai.models.anthropic import AnthropicModel
from pydantic_ai.models.fallback import FallbackModel
from pydantic_ai.models.openai import OpenAIModel

openai_model = OpenAIModel('gpt-4o', api_key='not-valid')
anthropic_model = AnthropicModel('claude-3-5-sonnet-latest')
fallback_model = FallbackModel(openai_model, anthropic_model)

agent = Agent(fallback_model)
response = agent.run_sync('What is the capital of France?')
print(response.data)
#> Paris

print(response.all_messages())
"""
[
    ModelRequest(
        parts=[
            UserPromptPart(
                content='What is the capital of France?',
                timestamp=datetime.datetime(...),
                part_kind='user-prompt',
            )
        ],
        kind='request',
    ),
    ModelResponse(
        parts=[TextPart(content='Paris', part_kind='text')],
        model_name='claude-3-5-sonnet-latest',
        timestamp=datetime.datetime(...),
        kind='response',
    ),
]
"""
```

!!! note
    You should configure each model's options individually: e.g. `base_url`, `api_key`, custom clients, etc. should be set on each model itself, not on the `FallbackModel`.
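For instance, here's a sketch of per-model configuration. The `api_key` and `http_client` keyword arguments are assumptions based on the individual model classes (check each model's signature); the key values are placeholders:

```python {title="fallback_model_config.py"}
import httpx

from pydantic_ai.models.anthropic import AnthropicModel
from pydantic_ai.models.fallback import FallbackModel
from pydantic_ai.models.openai import OpenAIModel

# Credentials, endpoints, and clients live on each model;
# the FallbackModel only orders the models to try.
openai_model = OpenAIModel(
    'gpt-4o',
    api_key='your-openai-key',
    http_client=httpx.AsyncClient(timeout=httpx.Timeout(30)),
)
anthropic_model = AnthropicModel('claude-3-5-sonnet-latest', api_key='your-anthropic-key')

fallback_model = FallbackModel(openai_model, anthropic_model)
```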
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -57,6 +57,7 @@ nav:
- api/models/mistral.md
- api/models/test.md
- api/models/function.md
- api/models/fallback.md
- api/pydantic_graph/graph.md
- api/pydantic_graph/nodes.md
- api/pydantic_graph/state.md
10 changes: 9 additions & 1 deletion pydantic_ai_slim/pydantic_ai/__init__.py
@@ -1,7 +1,14 @@
from importlib.metadata import version

from .agent import Agent, capture_run_messages
from .exceptions import AgentRunError, ModelRetry, UnexpectedModelBehavior, UsageLimitExceeded, UserError
from .exceptions import (
    AgentRunError,
    ModelRetry,
    ModelStatusError,
    UnexpectedModelBehavior,
    UsageLimitExceeded,
    UserError,
)
from .tools import RunContext, Tool

__all__ = (
@@ -11,6 +18,7 @@
    'Tool',
    'AgentRunError',
    'ModelRetry',
    'ModelStatusError',
    'UnexpectedModelBehavior',
    'UsageLimitExceeded',
    'UserError',
25 changes: 25 additions & 0 deletions pydantic_ai_slim/pydantic_ai/exceptions.py
@@ -72,3 +72,28 @@ def __str__(self) -> str:
            return f'{self.message}, body:\n{self.body}'
        else:
            return self.message


class ModelStatusError(AgentRunError):
    """Raised when a model provider response has a status code of 4xx or 5xx."""

    status_code: int
    """The HTTP status code returned by the API."""

    model_name: str
    """The name of the model associated with the error."""

    body: object | None
    """The body of the response, if available."""

    message: str
    """The error message with the status code and response body, if available."""

    def __init__(self, status_code: int, model_name: str, body: object | None = None):
        self.status_code = status_code
        self.model_name = model_name
        self.body = body
        if body is None:
            message = f'status_code: {status_code}, model_name: {model_name}'
        else:
            message = f'status_code: {status_code}, model_name: {model_name}, body: {body}'
        super().__init__(message)
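Since `ModelStatusError` is exported from the package root, callers can handle provider failures directly. A minimal sketch (the model name and prompt are illustrative only):

```python
from pydantic_ai import Agent, ModelStatusError

agent = Agent('openai:gpt-4o')

try:
    result = agent.run_sync('What is the capital of France?')
    print(result.data)
except ModelStatusError as exc:
    # status_code and model_name identify the failing provider call
    print(f'{exc.model_name} failed with HTTP {exc.status_code}')
```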
35 changes: 20 additions & 15 deletions pydantic_ai_slim/pydantic_ai/models/anthropic.py
@@ -10,7 +10,7 @@
from httpx import AsyncClient as AsyncHTTPClient
from typing_extensions import assert_never

from .. import UnexpectedModelBehavior, _utils, usage
from .. import ModelStatusError, UnexpectedModelBehavior, _utils, usage
from .._utils import guard_tool_call_id as _guard_tool_call_id
from ..messages import (
    ModelMessage,
@@ -36,7 +36,7 @@
)

try:
    from anthropic import NOT_GIVEN, AsyncAnthropic, AsyncStream
    from anthropic import NOT_GIVEN, APIStatusError, AsyncAnthropic, AsyncStream
    from anthropic.types import (
        Message as AnthropicMessage,
        MessageParam,
@@ -206,19 +206,24 @@ async def _messages_create(

        system_prompt, anthropic_messages = self._map_message(messages)

        return await self.client.messages.create(
            max_tokens=model_settings.get('max_tokens', 1024),
            system=system_prompt or NOT_GIVEN,
            messages=anthropic_messages,
            model=self._model_name,
            tools=tools or NOT_GIVEN,
            tool_choice=tool_choice or NOT_GIVEN,
            stream=stream,
            temperature=model_settings.get('temperature', NOT_GIVEN),
            top_p=model_settings.get('top_p', NOT_GIVEN),
            timeout=model_settings.get('timeout', NOT_GIVEN),
            metadata=model_settings.get('anthropic_metadata', NOT_GIVEN),
        )
        try:
            return await self.client.messages.create(
                max_tokens=model_settings.get('max_tokens', 1024),
                system=system_prompt or NOT_GIVEN,
                messages=anthropic_messages,
                model=self._model_name,
                tools=tools or NOT_GIVEN,
                tool_choice=tool_choice or NOT_GIVEN,
                stream=stream,
                temperature=model_settings.get('temperature', NOT_GIVEN),
                top_p=model_settings.get('top_p', NOT_GIVEN),
                timeout=model_settings.get('timeout', NOT_GIVEN),
                metadata=model_settings.get('anthropic_metadata', NOT_GIVEN),
            )
        except APIStatusError as e:
            if (status_code := e.status_code) >= 400:
                raise ModelStatusError(status_code=status_code, model_name=self.model_name, body=e.body) from e
            raise

    def _process_response(self, response: AnthropicMessage) -> ModelResponse:
        """Process a non-streamed response, and prepare a message to return."""
30 changes: 18 additions & 12 deletions pydantic_ai_slim/pydantic_ai/models/cohere.py
@@ -9,7 +9,7 @@
from httpx import AsyncClient as AsyncHTTPClient
from typing_extensions import assert_never

from .. import result
from .. import ModelStatusError, result
from .._utils import guard_tool_call_id as _guard_tool_call_id
from ..messages import (
    ModelMessage,
@@ -45,6 +45,7 @@
        ToolV2Function,
        UserChatMessageV2,
    )
    from cohere.core.api_error import ApiError
    from cohere.v2.client import OMIT
except ImportError as _import_error:
    raise ImportError(
@@ -144,17 +145,22 @@ async def _chat(
    ) -> ChatResponse:
        tools = self._get_tools(model_request_parameters)
        cohere_messages = list(chain(*(self._map_message(m) for m in messages)))
        return await self.client.chat(
            model=self._model_name,
            messages=cohere_messages,
            tools=tools or OMIT,
            max_tokens=model_settings.get('max_tokens', OMIT),
            temperature=model_settings.get('temperature', OMIT),
            p=model_settings.get('top_p', OMIT),
            seed=model_settings.get('seed', OMIT),
            presence_penalty=model_settings.get('presence_penalty', OMIT),
            frequency_penalty=model_settings.get('frequency_penalty', OMIT),
        )
        try:
            return await self.client.chat(
                model=self._model_name,
                messages=cohere_messages,
                tools=tools or OMIT,
                max_tokens=model_settings.get('max_tokens', OMIT),
                temperature=model_settings.get('temperature', OMIT),
                p=model_settings.get('top_p', OMIT),
                seed=model_settings.get('seed', OMIT),
                presence_penalty=model_settings.get('presence_penalty', OMIT),
                frequency_penalty=model_settings.get('frequency_penalty', OMIT),
            )
        except ApiError as e:
            if (status_code := e.status_code) and status_code >= 400:
                raise ModelStatusError(status_code=status_code, model_name=self.model_name, body=e.body) from e
            raise

    def _process_response(self, response: ChatResponse) -> ModelResponse:
        """Process a non-streamed response, and prepare a message to return."""
84 changes: 84 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/fallback.py
@@ -0,0 +1,84 @@
from __future__ import annotations as _annotations

from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from typing import TYPE_CHECKING

from ..exceptions import ModelStatusError
from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse, infer_model

if TYPE_CHECKING:
    from ..messages import ModelMessage, ModelResponse
    from ..settings import ModelSettings
    from ..usage import Usage


@dataclass(init=False)
class FallbackModel(Model):
    """A model that uses one or more fallback models upon failure.

    Apart from `__init__`, all methods are private or match those of the base class.
    """

    models: list[Model]

    _model_name: str = field(repr=False)
    _system: str | None = field(default=None, repr=False)

    def __init__(
        self,
        default_model: Model | KnownModelName,
        *fallback_models: Model | KnownModelName,
    ):
        """Initialize a fallback model instance.

        Args:
            default_model: The name or instance of the default model to use.
            fallback_models: The names or instances of the fallback models to use upon failure.
        """
        default_model_ = default_model if isinstance(default_model, Model) else infer_model(default_model)
        fallback_models_ = [model if isinstance(model, Model) else infer_model(model) for model in fallback_models]

        self.models = [default_model_, *fallback_models_]

        self._model_name = f'FallbackModel[{", ".join(model.model_name for model in self.models)}]'

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> tuple[ModelResponse, Usage]:
        """Try each model in sequence until one succeeds."""
        errors: list[ModelStatusError] = []

        for model in self.models:
            try:
                return await model.request(messages, model_settings, model_request_parameters)
            except ModelStatusError as exc_info:
                errors.append(exc_info)
                continue

        raise RuntimeError(f'All fallback models failed: {errors}')

    @asynccontextmanager
    async def request_stream(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> AsyncIterator[StreamedResponse]:
        """Try each model in sequence until one succeeds."""
        errors: list[ModelStatusError] = []

        for model in self.models:
            try:
                async with model.request_stream(messages, model_settings, model_request_parameters) as response:
                    yield response
                    return
            except ModelStatusError as exc_info:
                errors.append(exc_info)
                continue

        raise RuntimeError(f'All fallback models failed: {errors}')
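For the streaming path, here's a sketch of how fallback plays out through the agent API, assuming `Agent.run_stream` as documented elsewhere (model names are illustrative):

```python
import asyncio

from pydantic_ai import Agent
from pydantic_ai.models.fallback import FallbackModel

# FallbackModel also accepts known model names, which it resolves via infer_model.
fallback_model = FallbackModel('openai:gpt-4o', 'anthropic:claude-3-5-sonnet-latest')
agent = Agent(fallback_model)


async def main():
    # If the first model's stream fails with a 4xx/5xx, the next model is tried.
    async with agent.run_stream('What is the capital of France?') as result:
        async for text in result.stream_text():
            print(text)


asyncio.run(main())
```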
16 changes: 9 additions & 7 deletions pydantic_ai_slim/pydantic_ai/models/gemini.py
@@ -14,7 +14,7 @@
from httpx import USE_CLIENT_DEFAULT, AsyncClient as AsyncHTTPClient, Response as HTTPResponse
from typing_extensions import NotRequired, TypedDict, assert_never

from .. import UnexpectedModelBehavior, _utils, exceptions, usage
from .. import ModelStatusError, UnexpectedModelBehavior, UserError, _utils, usage
from ..messages import (
    ModelMessage,
    ModelRequest,
@@ -108,7 +108,7 @@ def __init__(
            if env_api_key := os.getenv('GEMINI_API_KEY'):
                api_key = env_api_key
            else:
                raise exceptions.UserError('API key must be provided or set in the GEMINI_API_KEY environment variable')
                raise UserError('API key must be provided or set in the GEMINI_API_KEY environment variable')
        self.http_client = http_client or cached_async_http_client()
        self._auth = ApiKeyAuth(api_key)
        self._url = url_template.format(model=model_name)
@@ -219,9 +219,11 @@ async def _make_request(
            headers=headers,
            timeout=model_settings.get('timeout', USE_CLIENT_DEFAULT),
        ) as r:
            if r.status_code != 200:
            if (status_code := r.status_code) != 200:
                await r.aread()
                raise exceptions.UnexpectedModelBehavior(f'Unexpected response from gemini {r.status_code}', r.text)
                if status_code >= 400:
                    raise ModelStatusError(status_code=status_code, model_name=self.model_name, body=r.text)
                raise UnexpectedModelBehavior(f'Unexpected response from gemini {status_code}', r.text)
            yield r

    def _process_response(self, response: _GeminiResponse) -> ModelResponse:
@@ -499,7 +501,7 @@ def _process_response_from_parts(
                )
            )
        elif 'function_response' in part:
            raise exceptions.UnexpectedModelBehavior(
            raise UnexpectedModelBehavior(
                f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
            )
    return ModelResponse(parts=items, model_name=model_name, timestamp=timestamp or _utils.now_utc())
@@ -707,7 +709,7 @@ def _simplify(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None
            # noinspection PyTypeChecker
            key = re.sub(r'^#/\$defs/', '', ref)
            if key in refs_stack:
                raise exceptions.UserError('Recursive `$ref`s in JSON Schema are not supported by Gemini')
                raise UserError('Recursive `$ref`s in JSON Schema are not supported by Gemini')
            refs_stack += (key,)
            schema_def = self.defs[key]
            self._simplify(schema_def, refs_stack)
@@ -741,7 +743,7 @@ def _simplify(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None
    def _object(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None:
        ad_props = schema.pop('additionalProperties', None)
        if ad_props:
            raise exceptions.UserError('Additional properties in JSON Schema are not supported by Gemini')
            raise UserError('Additional properties in JSON Schema are not supported by Gemini')

        if properties := schema.get('properties'):  # pragma: no branch
            for value in properties.values():