From 9ae0309f31bf15577f379214373b9d27c8b39c5e Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Fri, 20 Dec 2024 17:42:48 +0000 Subject: [PATCH] add a default to `ResultData`, some related cleanup (#512) Co-authored-by: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com> --- docs/api/result.md | 5 ----- pydantic_ai_examples/sql_gen.py | 2 +- pydantic_ai_slim/pydantic_ai/_result.py | 4 ++-- pydantic_ai_slim/pydantic_ai/result.py | 25 +++++++++++++++++++++---- pydantic_ai_slim/pydantic_ai/tools.py | 16 ---------------- tests/typed_agent.py | 19 ++++++++++++++++++- 6 files changed, 42 insertions(+), 29 deletions(-) diff --git a/docs/api/result.md b/docs/api/result.md index d4310881..c22a52e2 100644 --- a/docs/api/result.md +++ b/docs/api/result.md @@ -3,8 +3,3 @@ ::: pydantic_ai.result options: inherited_members: true - members: - - ResultData - - RunResult - - StreamedRunResult - - Usage diff --git a/pydantic_ai_examples/sql_gen.py b/pydantic_ai_examples/sql_gen.py index 0d23b5e7..f636fffe 100644 --- a/pydantic_ai_examples/sql_gen.py +++ b/pydantic_ai_examples/sql_gen.py @@ -73,7 +73,7 @@ class InvalidRequest(BaseModel): Response: TypeAlias = Union[Success, InvalidRequest] -agent = Agent( +agent: Agent[Deps, Response] = Agent( 'gemini-1.5-flash', # Type ignore while we wait for PEP-0747, nonetheless unions will work fine everywhere else result_type=Response, # type: ignore diff --git a/pydantic_ai_slim/pydantic_ai/_result.py b/pydantic_ai_slim/pydantic_ai/_result.py index 8a54ed7b..5ecb127d 100644 --- a/pydantic_ai_slim/pydantic_ai/_result.py +++ b/pydantic_ai_slim/pydantic_ai/_result.py @@ -12,8 +12,8 @@ from . import _utils, messages as _messages from .exceptions import ModelRetry -from .result import ResultData -from .tools import AgentDeps, ResultValidatorFunc, RunContext, ToolDefinition +from .result import ResultData, ResultValidatorFunc +from .tools import AgentDeps, RunContext, ToolDefinition @dataclass diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index bfe0b54f..9cba5f0a 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -4,9 +4,10 @@ from collections.abc import AsyncIterator, Awaitable, Callable from dataclasses import dataclass, field from datetime import datetime -from typing import Generic, TypeVar, cast +from typing import Generic, Union, cast import logfire_api +from typing_extensions import TypeVar from . import _result, _utils, exceptions, messages as _messages, models from .settings import UsageLimits @@ -14,21 +15,37 @@ __all__ = ( 'ResultData', + 'ResultValidatorFunc', 'Usage', 'RunResult', 'StreamedRunResult', ) -ResultData = TypeVar('ResultData') +ResultData = TypeVar('ResultData', default=str) """Type variable for the result data of a run.""" +ResultValidatorFunc = Union[ + Callable[[RunContext[AgentDeps], ResultData], ResultData], + Callable[[RunContext[AgentDeps], ResultData], Awaitable[ResultData]], + Callable[[ResultData], ResultData], + Callable[[ResultData], Awaitable[ResultData]], +] +""" +A function that always takes `ResultData` and returns `ResultData` and: + +* may or may not take [`RunContext`][pydantic_ai.tools.RunContext] as a first argument +* may or may not be async + +Usage `ResultValidatorFunc[AgentDeps, ResultData]`. +""" + _logfire = logfire_api.Logfire(otel_scope='pydantic-ai') @dataclass class Usage: - """LLM usage associated to a request or run. + """LLM usage associated with a request or run. Responsibility for calculating usage is on the model; PydanticAI simply sums the usage information across requests. @@ -36,7 +53,7 @@ class Usage: """ requests: int = 0 - """Number of requests made.""" + """Number of requests made to the LLM API.""" request_tokens: int | None = None """Tokens used in processing requests.""" response_tokens: int | None = None diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index 262e1f80..72a66f5b 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -16,7 +16,6 @@ __all__ = ( 'AgentDeps', 'RunContext', - 'ResultValidatorFunc', 'SystemPromptFunc', 'ToolFuncContext', 'ToolFuncPlain', @@ -73,21 +72,6 @@ def replace_with( Usage `SystemPromptFunc[AgentDeps]`. """ -ResultData = TypeVar('ResultData') - -ResultValidatorFunc = Union[ - Callable[[RunContext[AgentDeps], ResultData], ResultData], - Callable[[RunContext[AgentDeps], ResultData], Awaitable[ResultData]], - Callable[[ResultData], ResultData], - Callable[[ResultData], Awaitable[ResultData]], -] -""" -A function that always takes `ResultData` and returns `ResultData`, -but may or maybe not take `CallInfo` as a first argument, and may or may not be async. - -Usage `ResultValidator[AgentDeps, ResultData]`. -""" - ToolFuncContext = Callable[Concatenate[RunContext[AgentDeps], ToolParams], Any] """A tool function that takes `RunContext` as the first argument. diff --git a/tests/typed_agent.py b/tests/typed_agent.py index 8cb7ebec..38710377 100644 --- a/tests/typed_agent.py +++ b/tests/typed_agent.py @@ -3,7 +3,7 @@ from collections.abc import Awaitable, Iterator from contextlib import contextmanager from dataclasses import dataclass -from typing import Callable, Union, assert_type +from typing import Callable, TypeAlias, Union, assert_type from pydantic_ai import Agent, ModelRetry, RunContext, Tool from pydantic_ai.result import RunResult @@ -178,6 +178,13 @@ def run_sync3() -> None: assert_type(result.data, Union[Foo, Bar]) +MyUnion: TypeAlias = 'Foo | Bar' +union_agent2: Agent[None, MyUnion] = Agent( + result_type=MyUnion, # type: ignore[arg-type] +) +assert_type(union_agent2, Agent[None, MyUnion]) + + def foobar_ctx(ctx: RunContext[int], x: str, y: int) -> str: return f'{x} {y}' @@ -225,3 +232,13 @@ async def prepare_greet(ctx: RunContext[str], tool_def: ToolDefinition) -> ToolD result = greet_agent.run_sync('testing...', deps='human') assert result.data == '{"greet":"hello a"}' + +MYPY = False +if not MYPY: + default_agent = Agent() + assert_type(default_agent, Agent[None, str]) + assert_type(default_agent, Agent[None]) + +partial_agent: Agent[MyDeps] = Agent(deps_type=MyDeps) +assert_type(partial_agent, Agent[MyDeps, str]) +assert_type(partial_agent, Agent[MyDeps])