
FallbackModel support #894

Open
wants to merge 28 commits into base: main
Commits
28 commits
d66b3a2
fallback proof of concept
sydney-runkle Feb 11, 2025
9956c96
remove model name updates
sydney-runkle Feb 11, 2025
a821380
catching 4xx and 5xx business
sydney-runkle Feb 11, 2025
a4b4ebb
fixing test
sydney-runkle Feb 11, 2025
b4a6b1a
move groq import
sydney-runkle Feb 11, 2025
4e2b089
return after yield
sydney-runkle Feb 11, 2025
9e7f53b
intro docs
sydney-runkle Feb 11, 2025
737cf2e
non streaming testing
sydney-runkle Feb 11, 2025
3113ae5
initial tests
sydney-runkle Feb 11, 2025
bf477e2
Comprehension golfing, thanks @alexmojaki
sydney-runkle Feb 12, 2025
47b4901
fix linting issue
sydney-runkle Feb 13, 2025
3b95211
openai test with exceptions
sydney-runkle Feb 13, 2025
992d540
Merge branch 'main' into fallback-model-updated
sydney-runkle Feb 13, 2025
25afb86
adding model_name and system abstract methods
sydney-runkle Feb 13, 2025
7a23eb1
Merge branch 'main' into fallback-model-updated
sydney-runkle Feb 13, 2025
31150ff
using sequence and fixing 3.9 tests
sydney-runkle Feb 13, 2025
b83bc5e
using list with type hinting
sydney-runkle Feb 13, 2025
154d907
tests
sydney-runkle Feb 13, 2025
0f49228
streaming tests
sydney-runkle Feb 13, 2025
7cb8267
Get tests passing
dmontagu Feb 13, 2025
ff7b596
Minor cleanup
dmontagu Feb 13, 2025
68b0f0b
type alias testing cleanup
sydney-runkle Feb 14, 2025
d614df2
exception group like
sydney-runkle Feb 14, 2025
9ac1f8e
adding fallback failure example
sydney-runkle Feb 14, 2025
82b6580
fix f string issue
sydney-runkle Feb 14, 2025
538ac12
try moving type alias definitions into protected import block
sydney-runkle Feb 14, 2025
ba5f778
docs updates + fixing 3.9 tests
sydney-runkle Feb 14, 2025
b11d971
Merge branch 'main' into fallback-model-updated
dmontagu Feb 18, 2025
3 changes: 3 additions & 0 deletions docs/api/models/fallback.md
@@ -0,0 +1,3 @@
# pydantic_ai.models.fallback

::: pydantic_ai.models.fallback
92 changes: 92 additions & 0 deletions docs/models.md
@@ -653,3 +653,95 @@ For streaming, you'll also need to implement the following abstract base class:
The best place to start is to review the source code for existing implementations, e.g. [`OpenAIModel`](https://github.com/pydantic/pydantic-ai/blob/main/pydantic_ai_slim/pydantic_ai/models/openai.py).

For details on when we'll accept contributions adding new models to PydanticAI, see the [contributing guidelines](contributing.md#new-model-rules).


## Fallback Models

You can use [`FallbackModel`][pydantic_ai.models.fallback.FallbackModel] to try a list of models in sequence until one returns a successful result. Under the hood, PydanticAI automatically switches from one model to the next if a 4xx or 5xx status code is returned by the current model.

In the following example, the agent first makes a request to the OpenAI model (which fails due to an invalid API key), then falls back to the Anthropic model. The `ModelResponse` message shows that the result was returned by the Anthropic model, the second model passed to the `FallbackModel`.

```python {title="fallback_model.py"}
from pydantic_ai import Agent
from pydantic_ai.models.anthropic import AnthropicModel
from pydantic_ai.models.fallback import FallbackModel
from pydantic_ai.models.openai import OpenAIModel

openai_model = OpenAIModel('gpt-4o', api_key='not-valid')
anthropic_model = AnthropicModel('claude-3-5-sonnet-latest')
fallback_model = FallbackModel(openai_model, anthropic_model)

agent = Agent(fallback_model)
response = agent.run_sync('What is the capital of France?')
print(response.data)
#> Paris

print(response.all_messages())
"""
[
ModelRequest(
parts=[
UserPromptPart(
content='What is the capital of France?',
timestamp=datetime.datetime(...),
part_kind='user-prompt',
)
],
kind='request',
),
ModelResponse(
parts=[TextPart(content='Paris', part_kind='text')],
model_name='function:model_logic',
timestamp=datetime.datetime(...),
kind='response',
),
]
"""
```

!!! note
    You should configure each model's options individually; for example, `base_url`, `api_key`, and custom clients should be set on each model, not on the `FallbackModel`.

The next example demonstrates the exception-handling capabilities of `FallbackModel`. If all models fail,
a [`FallbackModelFailure`][pydantic_ai.exceptions.FallbackModelFailure] exception is raised, containing a
list of all the [`ModelStatusError`][pydantic_ai.exceptions.ModelStatusError] exceptions raised during the run.

```python {title="fallback_model_failure.py"}
from pydantic_ai import Agent, FallbackModelFailure
from pydantic_ai.models.anthropic import AnthropicModel
from pydantic_ai.models.fallback import FallbackModel
from pydantic_ai.models.openai import OpenAIModel

openai_model = OpenAIModel('gpt-4o', api_key='not-valid')
anthropic_model = AnthropicModel('claude-3-5-sonnet-latest', api_key='not-valid')
fallback_model = FallbackModel(openai_model, anthropic_model)

agent = Agent(fallback_model)
try:
response = agent.run_sync('What is the capital of France?')
except FallbackModelFailure as exc_info:
print(exc_info)
"""
FallbackModelFailure caused by:
- ModelStatusError:
status_code: 401
model_name: gpt-4o
body: {
'message': 'Incorrect API key provided: not-valid. You can find your API key at https://platform.openai.com/account/api-keys.',
'type': 'invalid_request_error',
'param': None,
'code': 'invalid_api_key'
}
- ModelStatusError:
status_code: 401
model_name: claude-3-5-sonnet-latest
body: {
'type': 'error',
'error': {
'type': 'authentication_error',
'message': 'invalid x-api-key'
}
}
"""
```
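The control flow both examples rely on can be sketched without pydantic-ai at all: try candidates in order, remember each failure, and raise an aggregate error only once every candidate has failed. The names below (`try_each`, `CandidateError`, `AllCandidatesFailed`) are illustrative stand-ins, not part of the pydantic-ai API.

```python
class CandidateError(Exception):
    """Stands in for a provider error carrying an HTTP status code."""

    def __init__(self, name, status_code):
        self.name = name
        self.status_code = status_code
        super().__init__(f'{name} failed with status {status_code}')


class AllCandidatesFailed(Exception):
    """Stands in for FallbackModelFailure: carries every collected error."""

    def __init__(self, errors):
        self.errors = errors
        super().__init__('; '.join(str(e) for e in errors))


def try_each(candidates):
    errors = []
    for candidate in candidates:
        try:
            return candidate()
        except CandidateError as exc:
            errors.append(exc)  # remember why this candidate failed, try the next
    raise AllCandidatesFailed(errors)


def primary():
    raise CandidateError('primary', 401)


def backup():
    return 'Paris'


print(try_each([primary, backup]))
#> Paris
```

Collecting the errors rather than swallowing them is what lets the final exception report every upstream failure, as in the `fallback_model_failure.py` example above.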
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -57,6 +57,7 @@ nav:
- api/models/mistral.md
- api/models/test.md
- api/models/function.md
- api/models/fallback.md
- api/pydantic_graph/graph.md
- api/pydantic_graph/nodes.md
- api/pydantic_graph/state.md
12 changes: 11 additions & 1 deletion pydantic_ai_slim/pydantic_ai/__init__.py
@@ -1,7 +1,15 @@
from importlib.metadata import version

from .agent import Agent, capture_run_messages
from .exceptions import AgentRunError, ModelRetry, UnexpectedModelBehavior, UsageLimitExceeded, UserError
from .exceptions import (
AgentRunError,
FallbackModelFailure,
ModelRetry,
ModelStatusError,
UnexpectedModelBehavior,
UsageLimitExceeded,
UserError,
)
from .tools import RunContext, Tool

__all__ = (
@@ -10,7 +18,9 @@
'RunContext',
'Tool',
'AgentRunError',
'FallbackModelFailure',
'ModelRetry',
'ModelStatusError',
'UnexpectedModelBehavior',
'UsageLimitExceeded',
'UserError',
46 changes: 45 additions & 1 deletion pydantic_ai_slim/pydantic_ai/exceptions.py
@@ -2,7 +2,15 @@

import json

__all__ = 'ModelRetry', 'UserError', 'AgentRunError', 'UnexpectedModelBehavior', 'UsageLimitExceeded'
__all__ = (
'ModelRetry',
'UserError',
'AgentRunError',
'UnexpectedModelBehavior',
'UsageLimitExceeded',
'ModelStatusError',
'FallbackModelFailure',
)


class ModelRetry(Exception):
@@ -72,3 +80,39 @@ def __str__(self) -> str:
return f'{self.message}, body:\n{self.body}'
else:
return self.message


class ModelStatusError(AgentRunError):
"""Raised when a model provider response has a 4xx or 5xx status code."""

status_code: int
"""The HTTP status code returned by the API."""

model_name: str
"""The name of the model associated with the error."""

body: object | None
"""The body of the response, if available."""

message: str
"""The error message with the status code and response body, if available."""

def __init__(self, status_code: int, model_name: str, body: object | None = None):
self.status_code = status_code
self.model_name = model_name
self.body = body
message = f'status_code: {status_code}, model_name: {model_name}, body: {body}'
super().__init__(message)


class FallbackModelFailure(AgentRunError):
"""Raised when all models in a `FallbackModel` fail."""

errors: list[ModelStatusError]
"""The collection of model status errors that ultimately caused the fallback to fail."""

def __init__(self, errors: list[ModelStatusError]):
self.errors = errors
stringified_errors = '\n'.join(f'{type(e).__name__}: {e}' for e in errors)
message = f'\nFallbackModelFailure caused by:\n{stringified_errors}'
super().__init__(message)
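The message strings these two exceptions produce follow directly from the f-strings in the diff above. As a self-contained restatement (mirroring the diff rather than importing pydantic-ai):

```python
# Minimal re-implementation of the message format built by the
# ModelStatusError / FallbackModelFailure additions above.
class ModelStatusError(Exception):
    def __init__(self, status_code, model_name, body=None):
        self.status_code = status_code
        self.model_name = model_name
        self.body = body
        # Same single-line format as the diff: all three fields joined.
        super().__init__(f'status_code: {status_code}, model_name: {model_name}, body: {body}')


err = ModelStatusError(401, 'gpt-4o', body={'code': 'invalid_api_key'})
print(err)
#> status_code: 401, model_name: gpt-4o, body: {'code': 'invalid_api_key'}

# FallbackModelFailure joins each error as "<class name>: <message>", one per line:
errors = [err, ModelStatusError(401, 'claude-3-5-sonnet-latest')]
stringified = '\n'.join(f'{type(e).__name__}: {e}' for e in errors)
print(stringified.splitlines()[0])
#> ModelStatusError: status_code: 401, model_name: gpt-4o, body: {'code': 'invalid_api_key'}
```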
35 changes: 20 additions & 15 deletions pydantic_ai_slim/pydantic_ai/models/anthropic.py
@@ -10,7 +10,7 @@
from httpx import AsyncClient as AsyncHTTPClient
from typing_extensions import assert_never

from .. import UnexpectedModelBehavior, _utils, usage
from .. import ModelStatusError, UnexpectedModelBehavior, _utils, usage
from .._utils import guard_tool_call_id as _guard_tool_call_id
from ..messages import (
ModelMessage,
@@ -36,7 +36,7 @@
)

try:
from anthropic import NOT_GIVEN, AsyncAnthropic, AsyncStream
from anthropic import NOT_GIVEN, APIStatusError, AsyncAnthropic, AsyncStream
from anthropic.types import (
Message as AnthropicMessage,
MessageParam,
@@ -216,19 +216,24 @@ async def _messages_create(

system_prompt, anthropic_messages = self._map_message(messages)

return await self.client.messages.create(
max_tokens=model_settings.get('max_tokens', 1024),
system=system_prompt or NOT_GIVEN,
messages=anthropic_messages,
model=self._model_name,
tools=tools or NOT_GIVEN,
tool_choice=tool_choice or NOT_GIVEN,
stream=stream,
temperature=model_settings.get('temperature', NOT_GIVEN),
top_p=model_settings.get('top_p', NOT_GIVEN),
timeout=model_settings.get('timeout', NOT_GIVEN),
metadata=model_settings.get('anthropic_metadata', NOT_GIVEN),
)
try:
return await self.client.messages.create(
max_tokens=model_settings.get('max_tokens', 1024),
system=system_prompt or NOT_GIVEN,
messages=anthropic_messages,
model=self._model_name,
tools=tools or NOT_GIVEN,
tool_choice=tool_choice or NOT_GIVEN,
stream=stream,
temperature=model_settings.get('temperature', NOT_GIVEN),
top_p=model_settings.get('top_p', NOT_GIVEN),
timeout=model_settings.get('timeout', NOT_GIVEN),
metadata=model_settings.get('anthropic_metadata', NOT_GIVEN),
)
except APIStatusError as e:
if (status_code := e.status_code) >= 400:
raise ModelStatusError(status_code=status_code, model_name=self.model_name, body=e.body) from e
raise

def _process_response(self, response: AnthropicMessage) -> ModelResponse:
"""Process a non-streamed response, and prepare a message to return."""
30 changes: 18 additions & 12 deletions pydantic_ai_slim/pydantic_ai/models/cohere.py
@@ -9,7 +9,7 @@
from httpx import AsyncClient as AsyncHTTPClient
from typing_extensions import assert_never

from .. import result
from .. import ModelStatusError, result
from .._utils import guard_tool_call_id as _guard_tool_call_id
from ..messages import (
ModelMessage,
@@ -45,6 +45,7 @@
ToolV2Function,
UserChatMessageV2,
)
from cohere.core.api_error import ApiError
from cohere.v2.client import OMIT
except ImportError as _import_error:
raise ImportError(
@@ -154,17 +155,22 @@
) -> ChatResponse:
tools = self._get_tools(model_request_parameters)
cohere_messages = list(chain(*(self._map_message(m) for m in messages)))
return await self.client.chat(
model=self._model_name,
messages=cohere_messages,
tools=tools or OMIT,
max_tokens=model_settings.get('max_tokens', OMIT),
temperature=model_settings.get('temperature', OMIT),
p=model_settings.get('top_p', OMIT),
seed=model_settings.get('seed', OMIT),
presence_penalty=model_settings.get('presence_penalty', OMIT),
frequency_penalty=model_settings.get('frequency_penalty', OMIT),
)
try:
return await self.client.chat(
model=self._model_name,
messages=cohere_messages,
tools=tools or OMIT,
max_tokens=model_settings.get('max_tokens', OMIT),
temperature=model_settings.get('temperature', OMIT),
p=model_settings.get('top_p', OMIT),
seed=model_settings.get('seed', OMIT),
presence_penalty=model_settings.get('presence_penalty', OMIT),
frequency_penalty=model_settings.get('frequency_penalty', OMIT),
)
except ApiError as e:
if (status_code := e.status_code) and status_code >= 400:
raise ModelStatusError(status_code=status_code, model_name=self.model_name, body=e.body) from e
raise

def _process_response(self, response: ChatResponse) -> ModelResponse:
"""Process a non-streamed response, and prepare a message to return."""
92 changes: 92 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/fallback.py
@@ -0,0 +1,92 @@
from __future__ import annotations as _annotations

from collections.abc import AsyncIterator
from contextlib import AsyncExitStack, asynccontextmanager
from dataclasses import dataclass, field
from typing import TYPE_CHECKING

from ..exceptions import FallbackModelFailure, ModelStatusError
from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse, infer_model

if TYPE_CHECKING:
from ..messages import ModelMessage, ModelResponse
from ..settings import ModelSettings
from ..usage import Usage


@dataclass(init=False)
class FallbackModel(Model):
"""A model that uses one or more fallback models upon failure.

Apart from `__init__`, all methods are private or match those of the base class.
"""

models: list[Model]

_model_name: str = field(repr=False)

def __init__(
self,
default_model: Model | KnownModelName,
*fallback_models: Model | KnownModelName,
):
"""Initialize a fallback model instance.

Args:
default_model: The name or instance of the default model to use.
fallback_models: The names or instances of the fallback models to use upon failure.
"""
self.models = [infer_model(default_model), *[infer_model(m) for m in fallback_models]]
self._model_name = f'FallbackModel[{", ".join(model.model_name for model in self.models)}]'

async def request(
self,
messages: list[ModelMessage],
model_settings: ModelSettings | None,
model_request_parameters: ModelRequestParameters,
) -> tuple[ModelResponse, Usage]:
"""Try each model in sequence until one succeeds."""
errors: list[ModelStatusError] = []

for model in self.models:
try:
return await model.request(messages, model_settings, model_request_parameters)
except ModelStatusError as exc_info:
errors.append(exc_info)
continue

raise FallbackModelFailure(errors=errors)

@asynccontextmanager
async def request_stream(
self,
messages: list[ModelMessage],
model_settings: ModelSettings | None,
model_request_parameters: ModelRequestParameters,
) -> AsyncIterator[StreamedResponse]:
"""Try each model in sequence until one succeeds."""
errors: list[ModelStatusError] = []

for model in self.models:
async with AsyncExitStack() as stack:
try:
response = await stack.enter_async_context(
model.request_stream(messages, model_settings, model_request_parameters)
)
except ModelStatusError as exc_info:
errors.append(exc_info)
continue
yield response
return

raise FallbackModelFailure(errors=errors)

@property
def model_name(self) -> str:
"""The model name."""
return self._model_name

@property
def system(self) -> str | None:
"""The system / model provider, n/a for fallback models."""
return None
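The trickiest part of `request_stream` above is that it must enter each candidate's async context manager in turn and only yield once one opens successfully. A stdlib-only sketch of that same yield-inside-the-loop shape (`first_working` and `fake_stream` are illustrative names, not pydantic-ai API):

```python
import asyncio
from contextlib import AsyncExitStack, asynccontextmanager


@asynccontextmanager
async def fake_stream(name, fail):
    # Stands in for a model's request_stream: raises on entry if the
    # candidate is unavailable, otherwise yields a "stream".
    if fail:
        raise RuntimeError(f'{name} unavailable')
    yield f'stream from {name}'


@asynccontextmanager
async def first_working(candidates):
    errors = []
    for name, fail in candidates:
        async with AsyncExitStack() as stack:
            try:
                stream = await stack.enter_async_context(fake_stream(name, fail))
            except RuntimeError as exc:
                errors.append(exc)
                continue  # this candidate failed to open; try the next one
            yield stream  # success: hand the stream to the caller
            return  # then stop — never fall through to the failure path
    raise RuntimeError(f'all candidates failed: {errors}')


async def main():
    async with first_working([('primary', True), ('backup', False)]) as stream:
        return stream


print(asyncio.run(main()))
#> stream from backup
```

The `AsyncExitStack` keeps the successful candidate's context open for exactly as long as the caller holds the outer context, while a failed entry is unwound before the loop moves on.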