diff --git a/Makefile b/Makefile index 7b12e2a9f..4c6c75a4e 100644 --- a/Makefile +++ b/Makefile @@ -64,7 +64,7 @@ testcov: test ## Run tests and generate a coverage report .PHONY: update-examples update-examples: ## Update documentation examples - uv run -m pytest --update-examples + uv run -m pytest --update-examples tests/test_examples.py # `--no-strict` so you can build the docs without insiders packages .PHONY: docs diff --git a/docs/agents.md b/docs/agents.md index 19da9f3a2..818e2db88 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -62,13 +62,14 @@ print(result.data) ## Running Agents -There are three ways to run an agent: +There are four ways to run an agent: -1. [`agent.run()`][pydantic_ai.Agent.run] — a coroutine which returns a [`RunResult`][pydantic_ai.result.RunResult] containing a completed response -2. [`agent.run_sync()`][pydantic_ai.Agent.run_sync] — a plain, synchronous function which returns a [`RunResult`][pydantic_ai.result.RunResult] containing a completed response (internally, this just calls `loop.run_until_complete(self.run())`) -3. [`agent.run_stream()`][pydantic_ai.Agent.run_stream] — a coroutine which returns a [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult], which contains methods to stream a response as an async iterable +1. [`agent.run()`][pydantic_ai.Agent.run] — a coroutine which returns a [`RunResult`][pydantic_ai.agent.AgentRunResult] containing a completed response. +2. [`agent.run_sync()`][pydantic_ai.Agent.run_sync] — a plain, synchronous function which returns a [`RunResult`][pydantic_ai.agent.AgentRunResult] containing a completed response (internally, this just calls `loop.run_until_complete(self.run())`). +3. [`agent.run_stream()`][pydantic_ai.Agent.run_stream] — a coroutine which returns a [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult], which contains methods to stream a response as an async iterable. +4. [`agent.iter()`][pydantic_ai.Agent.iter] — a context manager which returns an [`AgentRun`][pydantic_ai.agent.AgentRun], an async-iterable over the nodes of the agent's underlying [`Graph`][pydantic_graph.graph.Graph]. -Here's a simple example demonstrating all three: +Here's a simple example demonstrating the first three: ```python {title="run_agent.py"} from pydantic_ai import Agent @@ -94,6 +95,131 @@ _(This example is complete, it can be run "as is" — you'll need to add `asynci You can also pass messages from previous runs to continue a conversation or provide context, as described in [Messages and Chat History](message-history.md). +### Iterating Over an Agent's Graph + +Under the hood, each `Agent` in PydanticAI uses **pydantic-graph** to manage its execution flow. **pydantic-graph** is a generic, type-centric library for building and running finite state machines in Python. It doesn't actually depend on PydanticAI — you can use it standalone for workflows that have nothing to do with GenAI — but PydanticAI makes use of it to orchestrate the handling of model requests and model responses in an agent's run. + +In many scenarios, you don't need to worry about pydantic-graph at all; calling `agent.run(...)` simply traverses the underlying graph from start to finish. However, if you need deeper insight or control — for example to capture each tool invocation, or to inject your own logic at specific stages — PydanticAI exposes the lower-level iteration process via [`Agent.iter`][pydantic_ai.Agent.iter]. 
This method returns an [`AgentRun`][pydantic_ai.agent.AgentRun], which you can async-iterate over, or manually drive node-by-node via the [`next`][pydantic_ai.agent.AgentRun.next] method. Once the agent's graph returns an [`End`][pydantic_graph.nodes.End], you have the final result along with a detailed history of all steps. + +#### `async for` iteration + +Here's an example of using `async for` with `iter` to record each node the agent executes: + +```python {title="agent_iter_async_for.py"} +from pydantic_ai import Agent + +agent = Agent('openai:gpt-4o') + + +async def main(): + nodes = [] + # Begin an AgentRun, which is an async-iterable over the nodes of the agent's graph + with agent.iter('What is the capital of France?') as agent_run: + async for node in agent_run: + # Each node represents a step in the agent's execution + nodes.append(node) + print(nodes) + """ + [ + ModelRequestNode( + request=ModelRequest( + parts=[ + UserPromptPart( + content='What is the capital of France?', + timestamp=datetime.datetime(...), + part_kind='user-prompt', + ) + ], + kind='request', + ) + ), + HandleResponseNode( + model_response=ModelResponse( + parts=[TextPart(content='Paris', part_kind='text')], + model_name='function:model_logic', + timestamp=datetime.datetime(...), + kind='response', + ) + ), + End(data=FinalResult(data='Paris', tool_name=None)), + ] + """ + print(agent_run.result.data) + #> Paris +``` + +- The `AgentRun` is an async iterator that yields each node (`BaseNode` or `End`) in the flow. +- The run ends when an `End` node is returned. + +#### Using `.next(...)` manually + +You can also drive the iteration manually by passing the node you want to run next to the `AgentRun.next(...)` method. This allows you to inspect or modify the node before it executes or skip nodes based on your own logic, and to catch errors in `next()` more easily: + +```python {title="agent_iter_next.py"} +from pydantic_ai import Agent +from pydantic_graph import End + +agent = Agent('openai:gpt-4o') + + +async def main(): + with agent.iter('What is the capital of France?') as agent_run: + node = agent_run.next_node # (1)! + + all_nodes = [node] + + # Drive the iteration manually: + while not isinstance(node, End): # (2)! + node = await agent_run.next(node) # (3)! + all_nodes.append(node) # (4)! + + print(all_nodes) + """ + [ + UserPromptNode( + user_prompt='What is the capital of France?', + system_prompts=(), + system_prompt_functions=[], + system_prompt_dynamic_functions={}, + ), + ModelRequestNode( + request=ModelRequest( + parts=[ + UserPromptPart( + content='What is the capital of France?', + timestamp=datetime.datetime(...), + part_kind='user-prompt', + ) + ], + kind='request', + ) + ), + HandleResponseNode( + model_response=ModelResponse( + parts=[TextPart(content='Paris', part_kind='text')], + model_name='function:model_logic', + timestamp=datetime.datetime(...), + kind='response', + ) + ), + End(data=FinalResult(data='Paris', tool_name=None)), + ] + """ +``` + +1. We start by grabbing the first node that will be run in the agent's graph. +2. The agent run is finished once an `End` node has been produced; instances of `End` cannot be passed to `next`. +3. When you call `await agent_run.next(node)`, it executes that node in the agent's graph, updates the run's history, and returns the *next* node to run. +4. You could also inspect or mutate the new `node` here as needed. + +#### Accessing usage and the final result + +You can retrieve usage statistics (tokens, requests, etc.) 
at any time from the [`AgentRun`][pydantic_ai.agent.AgentRun] object via `agent_run.usage()`. This method returns a [`Usage`][pydantic_ai.usage.Usage] object containing the usage data. + +Once the run finishes, `agent_run.final_result` becomes a [`AgentRunResult`][pydantic_ai.agent.AgentRunResult] object containing the final output (and related metadata). + +--- + ### Additional Configuration #### Usage Limits @@ -177,7 +303,7 @@ except UsageLimitExceeded as e: 2. This run will error after 3 requests, preventing the infinite tool calling. !!! note - This is especially relevant if you're registered a lot of tools, `request_limit` can be used to prevent the model from choosing to make too many of these calls. + This is especially relevant if you've registered many tools. The `request_limit` can be used to prevent the model from calling them in a loop too many times. #### Model (Run) Settings @@ -441,7 +567,7 @@ If models behave unexpectedly (e.g., the retry limit is exceeded, or their API r In these cases, [`capture_run_messages`][pydantic_ai.capture_run_messages] can be used to access the messages exchanged during the run to help diagnose the issue. -```python +```python {title="agent_model_errors.py"} from pydantic_ai import Agent, ModelRetry, UnexpectedModelBehavior, capture_run_messages agent = Agent('openai:gpt-4o') diff --git a/docs/api/agent.md b/docs/api/agent.md index 890c418ee..b26cfb58e 100644 --- a/docs/api/agent.md +++ b/docs/api/agent.md @@ -4,6 +4,8 @@ options: members: - Agent + - AgentRun + - AgentRunResult - EndStrategy - - RunResultData + - RunResultDataT - capture_run_messages diff --git a/docs/api/result.md b/docs/api/result.md index c22a52e24..8e6cef79e 100644 --- a/docs/api/result.md +++ b/docs/api/result.md @@ -2,4 +2,7 @@ ::: pydantic_ai.result options: - inherited_members: true + inherited_members: true + members: + - ResultDataT + - StreamedRunResult diff --git a/docs/graph.md b/docs/graph.md index fa1b87343..95a6c0637 100644 --- a/docs/graph.md +++ b/docs/graph.md @@ -16,12 +16,12 @@ Graphs and finite state machines (FSMs) are a powerful abstraction to model, exe Alongside PydanticAI, we've developed `pydantic-graph` — an async graph and state machine library for Python where nodes and edges are defined using type hints. -While this library is developed as part of PydanticAI; it has no dependency on `pydantic-ai` and can be considered as a pure graph-based state machine library. You may find it useful whether or not you're using PydanticAI or even building with GenAI. +While this library is developed as part of PydanticAI; it has no dependency on `pydantic-ai` and can be considered as a pure graph-based state machine library. You may find it useful whether or not you're using PydanticAI or even building with GenAI. -`pydantic-graph` is designed for advanced users and makes heavy use of Python generics and types hints. It is not designed to be as beginner-friendly as PydanticAI. +`pydantic-graph` is designed for advanced users and makes heavy use of Python generics and type hints. It is not designed to be as beginner-friendly as PydanticAI. !!! note "Very Early beta" - Graph support was [introduced](https://github.com/pydantic/pydantic-ai/pull/528) in v0.0.19 and is in very earlier beta. The API is subject to change. The documentation is incomplete. The implementation is incomplete. + Graph support was [introduced](https://github.com/pydantic/pydantic-ai/pull/528) in v0.0.19 and is in a very early beta. The API is subject to change. The documentation is incomplete. 
The implementation is incomplete. ## Installation @@ -33,7 +33,7 @@ pip/uv-add pydantic-graph ## Graph Types -`pydantic-graph` made up of a few key components: +`pydantic-graph` is made up of a few key components: ### GraphRunContext @@ -156,18 +156,18 @@ class Increment(BaseNode): # (2)! fives_graph = Graph(nodes=[DivisibleBy5, Increment]) # (3)! -result, history = fives_graph.run_sync(DivisibleBy5(4)) # (4)! -print(result) +result = fives_graph.run_sync(DivisibleBy5(4)) # (4)! +print(result.output) #> 5 # the full history is quite verbose (see below), so we'll just print the summary -print([item.data_snapshot() for item in history]) +print([item.data_snapshot() for item in result.history]) #> [DivisibleBy5(foo=4), Increment(foo=4), DivisibleBy5(foo=5), End(data=5)] ``` 1. The `DivisibleBy5` node is parameterized with `None` for the state param and `None` for the deps param as this graph doesn't use state or deps, and `int` as it can end the run. 2. The `Increment` node doesn't return `End`, so the `RunEndT` generic parameter is omitted, state can also be omitted as the graph doesn't use state. 3. The graph is created with a sequence of nodes. -4. The graph is run synchronously with [`run_sync`][pydantic_graph.graph.Graph.run_sync] the initial state `None` and the start node `DivisibleBy5(4)` are passed as arguments. +4. The graph is run synchronously with [`run_sync`][pydantic_graph.graph.Graph.run_sync]. The initial node is `DivisibleBy5(4)`. Because the graph doesn't use external state or deps, we don't pass `state` or `deps`. _(This example is complete, it can be run "as is" with Python 3.10+)_ @@ -295,17 +295,17 @@ async def main(): 2. A dictionary of products mapped to prices. 3. The `InsertCoin` node, [`BaseNode`][pydantic_graph.nodes.BaseNode] is parameterized with `MachineState` as that's the state used in this graph. 4. The `InsertCoin` node prompts the user to insert coins. We keep things simple by just entering a monetary amount as a float. Before you start thinking this is a toy too since it's using [rich's `Prompt.ask`][rich.prompt.PromptBase.ask] within nodes, see [below](#custom-control-flow) for how control flow can be managed when nodes require external input. -5. The `CoinsInserted` node; again this is a [`dataclass`][dataclasses.dataclass], in this case with one field `amount`, thus nodes calling `CoinsInserted` must provide an amount. +5. The `CoinsInserted` node; again this is a [`dataclass`][dataclasses.dataclass] with one field `amount`. 6. Update the user's balance with the amount inserted. 7. If the user has already selected a product, go to `Purchase`, otherwise go to `SelectProduct`. 8. In the `Purchase` node, look up the price of the product if the user entered a valid product. 9. If the user did enter a valid product, set the product in the state so we don't revisit `SelectProduct`. 10. If the balance is enough to purchase the product, adjust the balance to reflect the purchase and return [`End`][pydantic_graph.nodes.End] to end the graph. We're not using the run return type, so we call `End` with `None`. -11. If the balance is insufficient, to go `InsertCoin` to prompt the user to insert more coins. +11. If the balance is insufficient, go to `InsertCoin` to prompt the user to insert more coins. 12. If the product is invalid, go to `SelectProduct` to prompt the user to select a product again. -13. The graph is created by passing a list of nodes to [`Graph`][pydantic_graph.graph.Graph]. 
Order of nodes is not important, but will alter how [diagrams](#mermaid-diagrams) are displayed. +13. The graph is created by passing a list of nodes to [`Graph`][pydantic_graph.graph.Graph]. Order of nodes is not important, but it can affect how [diagrams](#mermaid-diagrams) are displayed. 14. Initialize the state. This will be passed to the graph run and mutated as the graph runs. -15. Run the graph with the initial state. Since the graph can be run from any node, we must pass the start node — in this case, `InsertCoin`. [`Graph.run`][pydantic_graph.graph.Graph.run] returns a tuple of the return value (`None`) in this case, and the [history][pydantic_graph.state.HistoryStep] of the graph run. +15. Run the graph with the initial state. Since the graph can be run from any node, we must pass the start node — in this case, `InsertCoin`. [`Graph.run`][pydantic_graph.graph.Graph.run] returns a [`GraphRunResult`][pydantic_graph.graph.GraphRunResult] that provides the final data and a history of the run. 16. The return type of the node's [`run`][pydantic_graph.nodes.BaseNode.run] method is important as it is used to determine the outgoing edges of the node. This information in turn is used to render [mermaid diagrams](#mermaid-diagrams) and is enforced at runtime to detect misbehavior as soon as possible. 17. The return type of `CoinsInserted`'s [`run`][pydantic_graph.nodes.BaseNode.run] method is a union, meaning multiple outgoing edges are possible. 18. Unlike other nodes, `Purchase` can end the run, so the [`RunEndT`][pydantic_graph.nodes.RunEndT] generic parameter must be set. In this case it's `None` since the graph run return type is `None`. @@ -464,8 +464,8 @@ async def main(): ) state = State(user) feedback_graph = Graph(nodes=(WriteEmail, Feedback)) - email, _ = await feedback_graph.run(WriteEmail(), state=state) - print(email) + result = await feedback_graph.run(WriteEmail(), state=state) + print(result.output) """ Email( subject='Welcome to our tech blog!', @@ -606,6 +606,7 @@ async def main(): Ask(), Answer(question='what is 1 + 1?', answer='2'), Evaluate(answer='2'), + End(data='Well done, 1 + 1 = 2'), ] """ return @@ -642,11 +643,107 @@ stateDiagram-v2 Reprimand --> Ask ``` -You maybe have noticed that although this examples transfers control flow out of the graph run, we're still using [rich's `Prompt.ask`][rich.prompt.PromptBase.ask] to get user input, with the process hanging while we wait for the user to enter a response. For an example of genuine out-of-process control flow, see the [question graph example](examples/question-graph.md). +You maybe have noticed that although this example transfers control flow out of the graph run, we're still using [rich's `Prompt.ask`][rich.prompt.PromptBase.ask] to get user input, with the process hanging while we wait for the user to enter a response. For an example of genuine out-of-process control flow, see the [question graph example](examples/question-graph.md). + +## Iterating Over a Graph + +### Using `Graph.iter` for `async for` iteration + +Sometimes you want direct control or insight into each node as the graph executes. The easiest way to do that is with the [`Graph.iter`][pydantic_graph.graph.Graph.iter] method, which returns a **context manager** that yields a [`GraphRun`][pydantic_graph.graph.GraphRun] object. The `GraphRun` is an async-iterable over the nodes of your graph, allowing you to record or modify them as they execute. 
+ +Here's an example: + +```python {title="count_down.py" noqa="I001" py="3.10"} +from __future__ import annotations as _annotations + +from dataclasses import dataclass +from pydantic_graph import Graph, BaseNode, End, GraphRunContext + + +@dataclass +class CountDownState: + counter: int + + +@dataclass +class CountDown(BaseNode[CountDownState]): + async def run(self, ctx: GraphRunContext[CountDownState]) -> CountDown | End[int]: + if ctx.state.counter <= 0: + return End(ctx.state.counter) + ctx.state.counter -= 1 + return CountDown() + + +count_down_graph = Graph(nodes=[CountDown]) + + +async def main(): + state = CountDownState(counter=3) + with count_down_graph.iter(CountDown(), state=state) as run: # (1)! + async for node in run: # (2)! + print('Node:', node) + #> Node: CountDown() + #> Node: CountDown() + #> Node: CountDown() + #> Node: End(data=0) + print('Final result:', run.result.output) # (3)! + #> Final result: 0 + print('History snapshots:', [step.data_snapshot() for step in run.history]) + """ + History snapshots: + [CountDown(), CountDown(), CountDown(), CountDown(), End(data=0)] + """ +``` + +1. `Graph.iter(...)` returns a [`GraphRun`][pydantic_graph.graph.GraphRun]. +2. Here, we step through each node as it is executed. +3. Once the graph returns an [`End`][pydantic_graph.nodes.End], the loop ends, and `run.final_result` becomes a [`GraphRunResult`][pydantic_graph.graph.GraphRunResult] containing the final outcome (`0` here). + +### Using `GraphRun.next(node)` manually + +Alternatively, you can drive iteration manually with the [`GraphRun.next`][pydantic_graph.graph.GraphRun.next] method, which allows you to pass in whichever node you want to run next. You can modify or selectively skip nodes this way. + +Below is a contrived example that stops whenever the counter is at 2, ignoring any node runs beyond that: + +```python {title="count_down_next.py" noqa="I001" py="3.10"} +from pydantic_graph import End +from count_down import CountDown, CountDownState, count_down_graph + + +async def main(): + state = CountDownState(counter=5) + with count_down_graph.iter(CountDown(), state=state) as run: + node = run.next_node # (1)! + while not isinstance(node, End): # (2)! + print('Node:', node) + #> Node: CountDown() + #> Node: CountDown() + #> Node: CountDown() + #> Node: CountDown() + if state.counter == 2: + break # (3)! + node = await run.next(node) # (4)! + + print(run.result) # (5)! + #> None + + for step in run.history: # (6)! + print('History Step:', step.data_snapshot(), step.state) + #> History Step: CountDown() CountDownState(counter=4) + #> History Step: CountDown() CountDownState(counter=3) + #> History Step: CountDown() CountDownState(counter=2) +``` + +1. We start by grabbing the first node that will be run in the agent's graph. +2. The agent run is finished once an `End` node has been produced; instances of `End` cannot be passed to `next`. +3. If the user decides to stop early, we break out of the loop. The graph run won't have a real final result in that case (`run.final_result` remains `None`). +4. At each step, we call `await run.next(node)` to run it and get the next node (or an `End`). +5. Because we did not continue the run until it finished, the `result` is not set. +6. The run's history is still populated with the steps we executed so far. 
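+
+A related pattern, since the `state` object you pass to `iter` is shared with the run, is to mutate the state between calls to `next`. Below is a minimal sketch that builds on the `count_down.py` example above; the file name `count_down_mutate.py` and the reuse of `CountDown`, `CountDownState` and `count_down_graph` are assumptions made for illustration:
+
+```python {title="count_down_mutate.py" py="3.10"}
+from pydantic_graph import End
+
+from count_down import CountDown, CountDownState, count_down_graph
+
+
+async def main():
+    state = CountDownState(counter=5)
+    with count_down_graph.iter(CountDown(), state=state) as run:
+        node = run.next_node
+        while not isinstance(node, End):
+            # The state is shared with the run, so mutating it here changes
+            # what the next `CountDown` node sees when it executes.
+            if state.counter > 2:
+                state.counter = 2  # jump ahead so the run finishes sooner
+            node = await run.next(node)
+        print('Final node:', node)
+        #> Final node: End(data=0)
+```
+
+Because the graph only ever reads the counter from `ctx.state`, changing it between steps is equivalent to having started the run closer to the end; the same idea can be used to skip work in larger graphs.
+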
## Dependency Injection -As with PydanticAI, `pydantic-graph` supports dependency injection via a generic parameter on [`Graph`][pydantic_graph.graph.Graph] and [`BaseNode`][pydantic_graph.nodes.BaseNode], and the [`GraphRunContext.deps`][pydantic_graph.nodes.GraphRunContext.deps] fields. +As with PydanticAI, `pydantic-graph` supports dependency injection via a generic parameter on [`Graph`][pydantic_graph.graph.Graph] and [`BaseNode`][pydantic_graph.nodes.BaseNode], and the [`GraphRunContext.deps`][pydantic_graph.nodes.GraphRunContext.deps] field. As an example of dependency injection, let's modify the `DivisibleBy5` example [above](#graph) to use a [`ProcessPoolExecutor`][concurrent.futures.ProcessPoolExecutor] to run the compute load in a separate process (this is a contrived example, `ProcessPoolExecutor` wouldn't actually improve performance in this example): @@ -666,12 +763,12 @@ class GraphDeps: @dataclass -class DivisibleBy5(BaseNode[None, None, int]): +class DivisibleBy5(BaseNode[None, GraphDeps, int]): foo: int async def run( self, - ctx: GraphRunContext, + ctx: GraphRunContext[None, GraphDeps], ) -> Increment | End[int]: if self.foo % 5 == 0: return End(self.foo) @@ -680,10 +777,10 @@ class DivisibleBy5(BaseNode[None, None, int]): @dataclass -class Increment(BaseNode): +class Increment(BaseNode[None, GraphDeps]): foo: int - async def run(self, ctx: GraphRunContext) -> DivisibleBy5: + async def run(self, ctx: GraphRunContext[None, GraphDeps]) -> DivisibleBy5: loop = asyncio.get_running_loop() compute_result = await loop.run_in_executor( ctx.deps.executor, @@ -701,11 +798,11 @@ fives_graph = Graph(nodes=[DivisibleBy5, Increment]) async def main(): with ProcessPoolExecutor() as executor: deps = GraphDeps(executor) - result, history = await fives_graph.run(DivisibleBy5(3), deps=deps) - print(result) + result = await fives_graph.run(DivisibleBy5(3), deps=deps) + print(result.output) #> 5 # the full history is quite verbose (see below), so we'll just print the summary - print([item.data_snapshot() for item in history]) + print([item.data_snapshot() for item in result.history]) """ [ DivisibleBy5(foo=3), @@ -779,7 +876,7 @@ question_graph.mermaid_save('image.png', highlighted_nodes=[Answer]) _(This example is not complete and cannot be run directly)_ -Would generate and image that looks like this: +This would generate an image that looks like this: ```mermaid --- @@ -809,7 +906,7 @@ You can specify the direction of the state diagram using one of the following va - `'RL'`: Right to left, the diagram flows horizontally from right to left. - `'BT'`: Bottom to top, the diagram flows vertically from bottom to top. -Here is an example of how to do this using 'Left to Right' (LR) instead of the default 'Top to Bottom' (TB) +Here is an example of how to do this using 'Left to Right' (LR) instead of the default 'Top to Bottom' (TB): ```py {title="vending_machine_diagram.py" py="3.10"} from vending_machine import InsertCoin, vending_machine_graph diff --git a/docs/message-history.md b/docs/message-history.md index d538112f8..1fad6f54c 100644 --- a/docs/message-history.md +++ b/docs/message-history.md @@ -6,12 +6,12 @@ PydanticAI provides access to messages exchanged during an agent run. These mess After running an agent, you can access the messages exchanged during that run from the `result` object. 
-Both [`RunResult`][pydantic_ai.result.RunResult] +Both [`RunResult`][pydantic_ai.agent.AgentRunResult] (returned by [`Agent.run`][pydantic_ai.Agent.run], [`Agent.run_sync`][pydantic_ai.Agent.run_sync]) and [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult] (returned by [`Agent.run_stream`][pydantic_ai.Agent.run_stream]) have the following methods: -* [`all_messages()`][pydantic_ai.result.RunResult.all_messages]: returns all messages, including messages from prior runs. There's also a variant that returns JSON bytes, [`all_messages_json()`][pydantic_ai.result.RunResult.all_messages_json]. -* [`new_messages()`][pydantic_ai.result.RunResult.new_messages]: returns only the messages from the current run. There's also a variant that returns JSON bytes, [`new_messages_json()`][pydantic_ai.result.RunResult.new_messages_json]. +* [`all_messages()`][pydantic_ai.agent.AgentRunResult.all_messages]: returns all messages, including messages from prior runs. There's also a variant that returns JSON bytes, [`all_messages_json()`][pydantic_ai.agent.AgentRunResult.all_messages_json]. +* [`new_messages()`][pydantic_ai.agent.AgentRunResult.new_messages]: returns only the messages from the current run. There's also a variant that returns JSON bytes, [`new_messages_json()`][pydantic_ai.agent.AgentRunResult.new_messages_json]. !!! info "StreamedRunResult and complete messages" On [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult], the messages returned from these methods will only include the final result message once the stream has finished. @@ -25,7 +25,7 @@ and [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult] (returned by [`A **Note:** The final result message will NOT be added to result messages if you use [`.stream_text(delta=True)`][pydantic_ai.result.StreamedRunResult.stream_text] since in this case the result content is never built as one string. -Example of accessing methods on a [`RunResult`][pydantic_ai.result.RunResult] : +Example of accessing methods on a [`RunResult`][pydantic_ai.agent.AgentRunResult] : ```python {title="run_result_messages.py" hl_lines="10 28"} from pydantic_ai import Agent diff --git a/docs/multi-agent-applications.md b/docs/multi-agent-applications.md index 94913d1c4..002dd3c55 100644 --- a/docs/multi-agent-applications.md +++ b/docs/multi-agent-applications.md @@ -18,7 +18,7 @@ Since agents are stateless and designed to be global, you do not need to include You'll generally want to pass [`ctx.usage`][pydantic_ai.RunContext.usage] to the [`usage`][pydantic_ai.Agent.run] keyword argument of the delegate agent run so usage within that run counts towards the total usage of the parent agent run. !!! note "Multiple models" - Agent delegation doesn't need to use the same model for each agent. If you choose to use different models within a run, calculating the monetary cost from the final [`result.usage()`][pydantic_ai.result.RunResult.usage] of the run will not be possible, but you can still use [`UsageLimits`][pydantic_ai.usage.UsageLimits] to avoid unexpected costs. + Agent delegation doesn't need to use the same model for each agent. If you choose to use different models within a run, calculating the monetary cost from the final [`result.usage()`][pydantic_ai.agent.AgentRunResult.usage] of the run will not be possible, but you can still use [`UsageLimits`][pydantic_ai.usage.UsageLimits] to avoid unexpected costs. ```python {title="agent_delegation_simple.py"} from pydantic_ai import Agent, RunContext @@ -62,7 +62,7 @@ Usage( 1. 
The "parent" or controlling agent. 2. The "delegate" agent, which is called from within a tool of the parent agent. 3. Call the delegate agent from within a tool of the parent agent. -4. Pass the usage from the parent agent to the delegate agent so the final [`result.usage()`][pydantic_ai.result.RunResult.usage] includes the usage from both agents. +4. Pass the usage from the parent agent to the delegate agent so the final [`result.usage()`][pydantic_ai.agent.AgentRunResult.usage] includes the usage from both agents. 5. Since the function returns `#!python list[str]`, and the `result_type` of `joke_generation_agent` is also `#!python list[str]`, we can simply return `#!python r.data` from the tool. _(This example is complete, it can be run "as is")_ diff --git a/docs/results.md b/docs/results.md index e4e8a8c63..678048014 100644 --- a/docs/results.md +++ b/docs/results.md @@ -1,5 +1,5 @@ Results are the final values returned from [running an agent](agents.md#running-agents). -The result values are wrapped in [`RunResult`][pydantic_ai.result.RunResult] and [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult] so you can access other data like [usage][pydantic_ai.usage.Usage] of the run and [message history](message-history.md#accessing-messages-from-results) +The result values are wrapped in [`AgentRunResult`][pydantic_ai.agent.AgentRunResult] and [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult] so you can access other data like [usage][pydantic_ai.usage.Usage] of the run and [message history](message-history.md#accessing-messages-from-results) Both `RunResult` and `StreamedRunResult` are generic in the data they wrap, so typing information about the data returned by the agent is preserved. diff --git a/pydantic_ai_slim/pydantic_ai/__init__.py b/pydantic_ai_slim/pydantic_ai/__init__.py index 6f28e3047..1b77a420b 100644 --- a/pydantic_ai_slim/pydantic_ai/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/__init__.py @@ -1,11 +1,15 @@ from importlib.metadata import version -from .agent import Agent, capture_run_messages +from .agent import Agent, EndStrategy, HandleResponseNode, ModelRequestNode, UserPromptNode, capture_run_messages from .exceptions import AgentRunError, ModelRetry, UnexpectedModelBehavior, UsageLimitExceeded, UserError from .tools import RunContext, Tool __all__ = ( 'Agent', + 'EndStrategy', + 'HandleResponseNode', + 'ModelRequestNode', + 'UserPromptNode', 'capture_run_messages', 'RunContext', 'Tool', diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 60a5b3f97..b080acfc0 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -3,7 +3,7 @@ import asyncio import dataclasses from abc import ABC -from collections.abc import AsyncIterator, Iterator, Sequence +from collections.abc import AsyncIterator, Iterator from contextlib import asynccontextmanager, contextmanager from contextvars import ContextVar from dataclasses import field @@ -32,6 +32,16 @@ ToolDefinition, ) +__all__ = ( + 'GraphAgentState', + 'GraphAgentDeps', + 'UserPromptNode', + 'ModelRequestNode', + 'HandleResponseNode', + 'build_run_context', + 'capture_run_messages', +) + _logfire = logfire_api.Logfire(otel_scope='pydantic-ai') # while waiting for https://github.com/pydantic/logfire/issues/745 @@ -56,21 +66,6 @@ ResultT = TypeVar('ResultT') -@dataclasses.dataclass -class MarkFinalResult(Generic[ResultDataT]): - """Marker class to indicate that the result is the final result. 
- - This allows us to use `isinstance`, which wouldn't be possible if we were returning `ResultDataT` directly. - - It also avoids problems in the case where the result type is itself `None`, but is set. - """ - - data: ResultDataT - """The final result data.""" - tool_name: str | None - """Name of the final result tool, None if the result is a string.""" - - @dataclasses.dataclass class GraphAgentState: """State kept across the execution of the agent graph.""" @@ -113,17 +108,22 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]): @dataclasses.dataclass -class BaseUserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT], ABC): +class UserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT], ABC): user_prompt: str system_prompts: tuple[str, ...] system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT]] system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[DepsT]] + async def run( + self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]] + ) -> ModelRequestNode[DepsT, NodeRunEndT]: + return ModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx)) + async def _get_first_message( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]] ) -> _messages.ModelRequest: - run_context = _build_run_context(ctx) + run_context = build_run_context(ctx) history, next_message = await self._prepare_messages(self.user_prompt, ctx.state.message_history, run_context) ctx.state.message_history = history run_context.messages = history @@ -188,29 +188,13 @@ async def _sys_parts(self, run_context: RunContext[DepsT]) -> list[_messages.Mod return messages -@dataclasses.dataclass -class UserPromptNode(BaseUserPromptNode[DepsT, NodeRunEndT]): - async def run( - self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]] - ) -> ModelRequestNode[DepsT, NodeRunEndT]: - return ModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx)) - - -@dataclasses.dataclass -class StreamUserPromptNode(BaseUserPromptNode[DepsT, NodeRunEndT]): - async def run( - self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]] - ) -> StreamModelRequestNode[DepsT, NodeRunEndT]: - return StreamModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx)) - - async def _prepare_request_parameters( ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], ) -> models.ModelRequestParameters: """Build tools and create an agent model.""" function_tool_defs: list[ToolDefinition] = [] - run_context = _build_run_context(ctx) + run_context = build_run_context(ctx) async def add_tool(tool: Tool[DepsT]) -> None: ctx = run_context.replace_with(retry=tool.current_retry, tool_name=tool.name) @@ -222,7 +206,7 @@ async def add_tool(tool: Tool[DepsT]) -> None: result_schema = ctx.deps.result_schema return models.ModelRequestParameters( function_tools=function_tool_defs, - allow_text_result=_allow_text_result(result_schema), + allow_text_result=allow_text_result(result_schema), result_tools=result_schema.tool_defs() if result_schema is not None else [], ) @@ -233,9 +217,70 @@ class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], Nod request: _messages.ModelRequest + _result: HandleResponseNode[DepsT, NodeRunEndT] | None = field(default=None, repr=False) + _did_stream: bool = field(default=False, repr=False) + async def run( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] ) -> 
HandleResponseNode[DepsT, NodeRunEndT]: + if self._result is not None: + return self._result + + if self._did_stream: + # `self._result` gets set when exiting the `stream` contextmanager, so hitting this + # means that the stream was started but not finished before `run()` was called + raise exceptions.AgentRunError('You must finish streaming before calling run()') + + return await self._make_request(ctx) + + @asynccontextmanager + async def _stream( + self, + ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]], + ) -> AsyncIterator[models.StreamedResponse]: + # TODO: Consider changing this to return something more similar to a `StreamedRunResult`, then make it public + assert not self._did_stream, 'stream() should only be called once per node' + + model_settings, model_request_parameters = await self._prepare_request(ctx) + with _logfire.span('model request', run_step=ctx.state.run_step) as span: + async with ctx.deps.model.request_stream( + ctx.state.message_history, model_settings, model_request_parameters + ) as streamed_response: + self._did_stream = True + ctx.state.usage.incr(_usage.Usage(), requests=1) + yield streamed_response + # In case the user didn't manually consume the full stream, ensure it is fully consumed here, + # otherwise usage won't be properly counted: + async for _ in streamed_response: + pass + model_response = streamed_response.get() + request_usage = streamed_response.usage() + span.set_attribute('response', model_response) + span.set_attribute('usage', request_usage) + + self._finish_handling(ctx, model_response, request_usage) + assert self._result is not None # this should be set by the previous line + + async def _make_request( + self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] + ) -> HandleResponseNode[DepsT, NodeRunEndT]: + if self._result is not None: + return self._result + + model_settings, model_request_parameters = await self._prepare_request(ctx) + with _logfire.span('model request', run_step=ctx.state.run_step) as span: + model_response, request_usage = await ctx.deps.model.request( + ctx.state.message_history, model_settings, model_request_parameters + ) + ctx.state.usage.incr(_usage.Usage(), requests=1) + span.set_attribute('response', model_response) + span.set_attribute('usage', request_usage) + + return self._finish_handling(ctx, model_response, request_usage) + + async def _prepare_request( + self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] + ) -> tuple[ModelSettings | None, models.ModelRequestParameters]: ctx.state.message_history.append(self.request) # Check usage @@ -245,71 +290,124 @@ async def run( # Increment run_step ctx.state.run_step += 1 + model_settings = merge_model_settings(ctx.deps.model_settings, None) with _logfire.span('preparing model request params {run_step=}', run_step=ctx.state.run_step): model_request_parameters = await _prepare_request_parameters(ctx) + return model_settings, model_request_parameters - # Actually make the model request - model_settings = merge_model_settings(ctx.deps.model_settings, None) - with _logfire.span('model request') as span: - model_response, request_usage = await ctx.deps.model.request( - ctx.state.message_history, model_settings, model_request_parameters - ) - span.set_attribute('response', model_response) - span.set_attribute('usage', request_usage) - + def _finish_handling( + self, + ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], + response: _messages.ModelResponse, + usage: _usage.Usage, 
+ ) -> HandleResponseNode[DepsT, NodeRunEndT]: # Update usage - ctx.state.usage.incr(request_usage, requests=1) + ctx.state.usage.incr(usage, requests=0) if ctx.deps.usage_limits: ctx.deps.usage_limits.check_tokens(ctx.state.usage) # Append the model response to state.message_history - ctx.state.message_history.append(model_response) - return HandleResponseNode(model_response) + ctx.state.message_history.append(response) + + # Set the `_result` attribute since we can't use `return` in an async iterator + self._result = HandleResponseNode(response) + + return self._result @dataclasses.dataclass class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT]): - """Process e response from a model, decide whether to end the run or make a new request.""" + """Process a model response, and decide whether to end the run or make a new request.""" model_response: _messages.ModelResponse + _events_iterator: AsyncIterator[_messages.HandleResponseEvent] | None = field(default=None, repr=False) + _next_node: ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]] | None = field( + default=None, repr=False + ) + _tool_responses: list[_messages.ModelRequestPart] = field(default_factory=list, repr=False) + async def run( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] - ) -> Union[ModelRequestNode[DepsT, NodeRunEndT], FinalResultNode[DepsT, NodeRunEndT]]: # noqa UP007 + ) -> Union[ModelRequestNode[DepsT, NodeRunEndT], End[result.FinalResult[NodeRunEndT]]]: # noqa UP007 + async with self.stream(ctx): + pass + + assert (next_node := self._next_node) is not None, 'the stream should set `self._next_node` before it ends' + return next_node + + @asynccontextmanager + async def stream( + self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]] + ) -> AsyncIterator[AsyncIterator[_messages.HandleResponseEvent]]: + """Process the model response and yield events for the start and end of each function tool call.""" with _logfire.span('handle model response', run_step=ctx.state.run_step) as handle_span: - texts: list[str] = [] - tool_calls: list[_messages.ToolCallPart] = [] - for part in self.model_response.parts: - if isinstance(part, _messages.TextPart): - # ignore empty content for text parts, see #437 - if part.content: - texts.append(part.content) - elif isinstance(part, _messages.ToolCallPart): - tool_calls.append(part) + stream = self._run_stream(ctx) + yield stream + + # Run the stream to completion if it was not finished: + async for _event in stream: + pass + + # Set the next node based on the final state of the stream + next_node = self._next_node + if isinstance(next_node, End): + handle_span.set_attribute('result', next_node.data) + handle_span.message = 'handle model response -> final result' + elif tool_responses := self._tool_responses: + # TODO: We could drop `self._tool_responses` if we drop this set_attribute + # I'm thinking it might be better to just create a span for the handling of each tool + # than to set an attribute here. 
+ handle_span.set_attribute('tool_responses', tool_responses) + tool_responses_str = ' '.join(r.part_kind for r in tool_responses) + handle_span.message = f'handle model response -> {tool_responses_str}' + + async def _run_stream( + self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]] + ) -> AsyncIterator[_messages.HandleResponseEvent]: + if self._events_iterator is None: + # Ensure that the stream is only run once + + async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]: + texts: list[str] = [] + tool_calls: list[_messages.ToolCallPart] = [] + for part in self.model_response.parts: + if isinstance(part, _messages.TextPart): + # ignore empty content for text parts, see #437 + if part.content: + texts.append(part.content) + elif isinstance(part, _messages.ToolCallPart): + tool_calls.append(part) + else: + assert_never(part) + + # At the moment, we prioritize at least executing tool calls if they are present. + # In the future, we'd consider making this configurable at the agent or run level. + # This accounts for cases like anthropic returns that might contain a text response + # and a tool call response, where the text response just indicates the tool call will happen. + if tool_calls: + async for event in self._handle_tool_calls(ctx, tool_calls): + yield event + elif texts: + # No events are emitted during the handling of text responses, so we don't need to yield anything + self._next_node = await self._handle_text_response(ctx, texts) else: - assert_never(part) - - # At the moment, we prioritize at least executing tool calls if they are present. - # In the future, we'd consider making this configurable at the agent or run level. - # This accounts for cases like anthropic returns that might contain a text response - # and a tool call response, where the text response just indicates the tool call will happen. 
- if tool_calls: - return await self._handle_tool_calls_response(ctx, tool_calls, handle_span) - elif texts: - return await self._handle_text_response(ctx, texts, handle_span) - else: - raise exceptions.UnexpectedModelBehavior('Received empty model response') + raise exceptions.UnexpectedModelBehavior('Received empty model response') - async def _handle_tool_calls_response( + self._events_iterator = _run_stream() + + async for event in self._events_iterator: + yield event + + async def _handle_tool_calls( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], tool_calls: list[_messages.ToolCallPart], - handle_span: logfire_api.LogfireSpan, - ): + ) -> AsyncIterator[_messages.HandleResponseEvent]: result_schema = ctx.deps.result_schema # first look for the result tool call - final_result: MarkFinalResult[NodeRunEndT] | None = None + final_result: result.FinalResult[NodeRunEndT] | None = None parts: list[_messages.ModelRequestPart] = [] if result_schema is not None: if match := result_schema.find_tool(tool_calls): @@ -323,33 +421,51 @@ async def _handle_tool_calls_response( ctx.state.increment_retries(ctx.deps.max_result_retries) parts.append(e.tool_retry) else: - final_result = MarkFinalResult(result_data, call.tool_name) + final_result = result.FinalResult(result_data, call.tool_name) # Then build the other request parts based on end strategy - tool_responses = await _process_function_tools(tool_calls, final_result and final_result.tool_name, ctx) + tool_responses: list[_messages.ModelRequestPart] = self._tool_responses + async for event in process_function_tools( + tool_calls, final_result and final_result.tool_name, ctx, tool_responses + ): + yield event if final_result: - handle_span.set_attribute('result', final_result.data) - handle_span.message = 'handle model response -> final result' - return FinalResultNode[DepsT, NodeRunEndT](final_result, tool_responses) + self._next_node = self._handle_final_result(ctx, final_result, tool_responses) else: if tool_responses: - handle_span.set_attribute('tool_responses', tool_responses) - tool_responses_str = ' '.join(r.part_kind for r in tool_responses) - handle_span.message = f'handle model response -> {tool_responses_str}' parts.extend(tool_responses) - return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=parts)) + self._next_node = ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=parts)) + + def _handle_final_result( + self, + ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], + final_result: result.FinalResult[NodeRunEndT], + tool_responses: list[_messages.ModelRequestPart], + ) -> End[result.FinalResult[NodeRunEndT]]: + run_span = ctx.deps.run_span + usage = ctx.state.usage + messages = ctx.state.message_history + + # For backwards compatibility, append a new ModelRequest using the tool returns and retries + if tool_responses: + messages.append(_messages.ModelRequest(parts=tool_responses)) + + run_span.set_attribute('usage', usage) + run_span.set_attribute('all_messages', messages) + + # End the run with self.data + return End(final_result) async def _handle_text_response( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], texts: list[str], - handle_span: logfire_api.LogfireSpan, - ): + ) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]: result_schema = ctx.deps.result_schema text = '\n\n'.join(texts) - if _allow_text_result(result_schema): + if allow_text_result(result_schema): result_data_input 
= cast(NodeRunEndT, text) try: result_data = await _validate_result(result_data_input, ctx, None) @@ -357,9 +473,8 @@ async def _handle_text_response( ctx.state.increment_retries(ctx.deps.max_result_retries) return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry])) else: - handle_span.set_attribute('result', result_data) - handle_span.message = 'handle model response -> final result' - return FinalResultNode[DepsT, NodeRunEndT](MarkFinalResult(result_data, None)) + # The following cast is safe because we know `str` is an allowed result type + return self._handle_final_result(ctx, result.FinalResult(result_data, tool_name=None), []) else: ctx.state.increment_retries(ctx.deps.max_result_retries) return ModelRequestNode[DepsT, NodeRunEndT]( @@ -373,166 +488,8 @@ async def _handle_text_response( ) -@dataclasses.dataclass -class StreamModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT]): - """Make a request to the model using the last message in state.message_history (or a specified request).""" - - request: _messages.ModelRequest - _result: StreamModelRequestNode[DepsT, NodeRunEndT] | End[result.StreamedRunResult[DepsT, NodeRunEndT]] | None = ( - field(default=None, repr=False) - ) - - async def run( - self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] - ) -> Union[StreamModelRequestNode[DepsT, NodeRunEndT], End[result.StreamedRunResult[DepsT, NodeRunEndT]]]: # noqa UP007 - if self._result is not None: - return self._result - - async with self.run_to_result(ctx) as final_node: - return final_node - - @asynccontextmanager - async def run_to_result( - self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] - ) -> AsyncIterator[StreamModelRequestNode[DepsT, NodeRunEndT] | End[result.StreamedRunResult[DepsT, NodeRunEndT]]]: - result_schema = ctx.deps.result_schema - - ctx.state.message_history.append(self.request) - - # Check usage - if ctx.deps.usage_limits: - ctx.deps.usage_limits.check_before_request(ctx.state.usage) - - # Increment run_step - ctx.state.run_step += 1 - - with _logfire.span('preparing model and tools {run_step=}', run_step=ctx.state.run_step): - model_request_parameters = await _prepare_request_parameters(ctx) - - # Actually make the model request - model_settings = merge_model_settings(ctx.deps.model_settings, None) - with _logfire.span('model request {run_step=}', run_step=ctx.state.run_step) as model_req_span: - async with ctx.deps.model.request_stream( - ctx.state.message_history, model_settings, model_request_parameters - ) as streamed_response: - ctx.state.usage.requests += 1 - model_req_span.set_attribute('response_type', streamed_response.__class__.__name__) - # We want to end the "model request" span here, but we can't exit the context manager - # in the traditional way - model_req_span.__exit__(None, None, None) - - with _logfire.span('handle model response') as handle_span: - received_text = False - - async for maybe_part_event in streamed_response: - if isinstance(maybe_part_event, _messages.PartStartEvent): - new_part = maybe_part_event.part - if isinstance(new_part, _messages.TextPart): - received_text = True - if _allow_text_result(result_schema): - handle_span.message = 'handle model response -> final result' - streamed_run_result = _build_streamed_run_result(streamed_response, None, ctx) - self._result = End(streamed_run_result) - yield self._result - return - elif isinstance(new_part, _messages.ToolCallPart): - if result_schema is not None and 
(match := result_schema.find_tool([new_part])): - call, _ = match - handle_span.message = 'handle model response -> final result' - streamed_run_result = _build_streamed_run_result( - streamed_response, call.tool_name, ctx - ) - self._result = End(streamed_run_result) - yield self._result - return - else: - assert_never(new_part) - - tasks: list[asyncio.Task[_messages.ModelRequestPart]] = [] - parts: list[_messages.ModelRequestPart] = [] - model_response = streamed_response.get() - if not model_response.parts: - raise exceptions.UnexpectedModelBehavior('Received empty model response') - ctx.state.message_history.append(model_response) - - run_context = _build_run_context(ctx) - for p in model_response.parts: - if isinstance(p, _messages.ToolCallPart): - if tool := ctx.deps.function_tools.get(p.tool_name): - tasks.append(asyncio.create_task(tool.run(p, run_context), name=p.tool_name)) - else: - parts.append(_unknown_tool(p.tool_name, ctx)) - - if received_text and not tasks and not parts: - # Can only get here if self._allow_text_result returns `False` for the provided result_schema - ctx.state.increment_retries(ctx.deps.max_result_retries) - self._result = StreamModelRequestNode[DepsT, NodeRunEndT]( - _messages.ModelRequest( - parts=[ - _messages.RetryPromptPart( - content='Plain text responses are not permitted, please call one of the functions instead.', - ) - ] - ) - ) - yield self._result - return - - with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]): - task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks) - parts.extend(task_results) - - next_request = _messages.ModelRequest(parts=parts) - if any(isinstance(part, _messages.RetryPromptPart) for part in parts): - try: - ctx.state.increment_retries(ctx.deps.max_result_retries) - except: - # TODO: This is janky, so I think we should probably change it, but how? 
- ctx.state.message_history.append(next_request) - raise - - handle_span.set_attribute('tool_responses', parts) - tool_responses_str = ' '.join(r.part_kind for r in parts) - handle_span.message = f'handle model response -> {tool_responses_str}' - # the model_response should have been fully streamed by now, we can add its usage - streamed_response_usage = streamed_response.usage() - run_context.usage.incr(streamed_response_usage) - ctx.deps.usage_limits.check_tokens(run_context.usage) - self._result = StreamModelRequestNode[DepsT, NodeRunEndT](next_request) - yield self._result - return - - -@dataclasses.dataclass -class FinalResultNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], MarkFinalResult[NodeRunEndT]]): - """Produce the final result of the run.""" - - data: MarkFinalResult[NodeRunEndT] - """The final result data.""" - extra_parts: list[_messages.ModelRequestPart] = dataclasses.field(default_factory=list) - - async def run( - self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] - ) -> End[MarkFinalResult[NodeRunEndT]]: - run_span = ctx.deps.run_span - usage = ctx.state.usage - messages = ctx.state.message_history - - # TODO: For backwards compatibility, append a new ModelRequest using the tool returns and retries - if self.extra_parts: - messages.append(_messages.ModelRequest(parts=self.extra_parts)) - - # TODO: Set this attribute somewhere - # handle_span = self.handle_model_response_span - # handle_span.set_attribute('final_data', self.data) - run_span.set_attribute('usage', usage) - run_span.set_attribute('all_messages', messages) - - # End the run with self.data - return End(self.data) - - -def _build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]: +def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]: + """Build a `RunContext` object from the current agent graph run context.""" return RunContext[DepsT]( deps=ctx.deps.user_deps, model=ctx.deps.model, @@ -543,76 +500,31 @@ def _build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[Deps ) -def _build_streamed_run_result( - result_stream: models.StreamedResponse, - result_tool_name: str | None, - ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], -) -> result.StreamedRunResult[DepsT, NodeRunEndT]: - new_message_index = ctx.deps.new_message_index - result_schema = ctx.deps.result_schema - run_span = ctx.deps.run_span - usage_limits = ctx.deps.usage_limits - messages = ctx.state.message_history - run_context = _build_run_context(ctx) - - async def on_complete(): - """Called when the stream has completed. - - The model response will have been added to messages by now - by `StreamedRunResult._marked_completed`. - """ - last_message = messages[-1] - assert isinstance(last_message, _messages.ModelResponse) - tool_calls = [part for part in last_message.parts if isinstance(part, _messages.ToolCallPart)] - parts = await _process_function_tools( - tool_calls, - result_tool_name, - ctx, - ) - # TODO: Should we do something here related to the retry count? - # Maybe we should move the incrementing of the retry count to where we actually make a request? 
- # if any(isinstance(part, _messages.RetryPromptPart) for part in parts): - # ctx.state.increment_retries(ctx.deps.max_result_retries) - if parts: - messages.append(_messages.ModelRequest(parts)) - run_span.set_attribute('all_messages', messages) - - return result.StreamedRunResult[DepsT, NodeRunEndT]( - messages, - new_message_index, - usage_limits, - result_stream, - result_schema, - run_context, - ctx.deps.result_validators, - result_tool_name, - on_complete, - ) - - -async def _process_function_tools( +async def process_function_tools( tool_calls: list[_messages.ToolCallPart], result_tool_name: str | None, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], -) -> list[_messages.ModelRequestPart]: - """Process function (non-result) tool calls in parallel. + output_parts: list[_messages.ModelRequestPart], +) -> AsyncIterator[_messages.HandleResponseEvent]: + """Process function (i.e., non-result) tool calls in parallel. Also add stub return parts for any other tools that need it. - """ - parts: list[_messages.ModelRequestPart] = [] - tasks: list[asyncio.Task[_messages.ToolReturnPart | _messages.RetryPromptPart]] = [] + Because async iterators can't have return values, we use `output_parts` as an output argument. + """ stub_function_tools = bool(result_tool_name) and ctx.deps.end_strategy == 'early' result_schema = ctx.deps.result_schema # we rely on the fact that if we found a result, it's the first result tool in the last found_used_result_tool = False - run_context = _build_run_context(ctx) + run_context = build_run_context(ctx) + calls_to_run: list[tuple[Tool[DepsT], _messages.ToolCallPart]] = [] + call_index_to_event_id: dict[int, str] = {} for call in tool_calls: if call.tool_name == result_tool_name and not found_used_result_tool: found_used_result_tool = True - parts.append( + output_parts.append( _messages.ToolReturnPart( tool_name=call.tool_name, content='Final result processed.', @@ -621,7 +533,7 @@ async def _process_function_tools( ) elif tool := ctx.deps.function_tools.get(call.tool_name): if stub_function_tools: - parts.append( + output_parts.append( _messages.ToolReturnPart( tool_name=call.tool_name, content='Tool not executed - a final result was already processed.', @@ -629,33 +541,47 @@ async def _process_function_tools( ) ) else: - tasks.append(asyncio.create_task(tool.run(call, run_context), name=call.tool_name)) + event = _messages.FunctionToolCallEvent(call) + yield event + call_index_to_event_id[len(calls_to_run)] = event.call_id + calls_to_run.append((tool, call)) elif result_schema is not None and call.tool_name in result_schema.tools: # if tool_name is in _result_schema, it means we found a result tool but an error occurred in # validation, we don't add another part here if result_tool_name is not None: - parts.append( - _messages.ToolReturnPart( - tool_name=call.tool_name, - content='Result tool not used - a final result was already processed.', - tool_call_id=call.tool_call_id, - ) + part = _messages.ToolReturnPart( + tool_name=call.tool_name, + content='Result tool not used - a final result was already processed.', + tool_call_id=call.tool_call_id, ) + output_parts.append(part) else: - parts.append(_unknown_tool(call.tool_name, ctx)) + output_parts.append(_unknown_tool(call.tool_name, ctx)) + + if not calls_to_run: + return # Run all tool tasks in parallel - if tasks: - with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]): - task_results: Sequence[_messages.ToolReturnPart | _messages.RetryPromptPart] = await 
asyncio.gather(*tasks) - for result in task_results: - if isinstance(result, _messages.ToolReturnPart): - parts.append(result) - elif isinstance(result, _messages.RetryPromptPart): - parts.append(result) + results_by_index: dict[int, _messages.ModelRequestPart] = {} + with _logfire.span('running {tools=}', tools=[call.tool_name for _, call in calls_to_run]): + # TODO: Should we wrap each individual tool call in a dedicated span? + tasks = [asyncio.create_task(tool.run(call, run_context), name=call.tool_name) for tool, call in calls_to_run] + pending = tasks + while pending: + done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) + for task in done: + index = tasks.index(task) + result = task.result() + yield _messages.FunctionToolResultEvent(result, call_id=call_index_to_event_id[index]) + if isinstance(result, (_messages.ToolReturnPart, _messages.RetryPromptPart)): + results_by_index[index] = result else: assert_never(result) - return parts + + # We append the results at the end, rather than as they are received, to retain a consistent ordering + # This is mostly just to simplify testing + for k in sorted(results_by_index): + output_parts.append(results_by_index[k]) def _unknown_tool( @@ -681,12 +607,13 @@ async def _validate_result( tool_call: _messages.ToolCallPart | None, ) -> T: for validator in ctx.deps.result_validators: - run_context = _build_run_context(ctx) + run_context = build_run_context(ctx) result_data = await validator.validate(result_data, tool_call, run_context) return result_data -def _allow_text_result(result_schema: _result.ResultSchema[Any] | None) -> bool: +def allow_text_result(result_schema: _result.ResultSchema[Any] | None) -> bool: + """Check if the result schema allows text results.""" return result_schema is None or result_schema.allow_text_result @@ -740,35 +667,18 @@ def get_captured_run_messages() -> _RunMessages: def build_agent_graph( name: str | None, deps_type: type[DepsT], result_type: type[ResultT] -) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], MarkFinalResult[ResultT]]: - # We'll define the known node classes: +) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[ResultT]]: + """Build the execution [Graph][pydantic_graph.Graph] for a given agent.""" nodes = ( UserPromptNode[DepsT], ModelRequestNode[DepsT], HandleResponseNode[DepsT], - FinalResultNode[DepsT, ResultT], ) - graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], MarkFinalResult[ResultT]]( + graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[ResultT]]( nodes=nodes, name=name or 'Agent', state_type=GraphAgentState, - run_end_type=MarkFinalResult[result_type], + run_end_type=result.FinalResult[result_type], auto_instrument=False, ) return graph - - -def build_agent_stream_graph( - name: str | None, deps_type: type[DepsT], result_type: type[ResultT] | None -) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.StreamedRunResult[DepsT, Any]]: - nodes = [ - StreamUserPromptNode[DepsT, result.StreamedRunResult[DepsT, ResultT]], - StreamModelRequestNode[DepsT, result.StreamedRunResult[DepsT, ResultT]], - ] - graph = Graph[GraphAgentState, GraphAgentDeps[DepsT, Any], result.StreamedRunResult[DepsT, Any]]( - nodes=nodes, - name=name or 'Agent', - state_type=GraphAgentState, - run_end_type=result.StreamedRunResult[DepsT, result_type], - ) - return graph diff --git a/pydantic_ai_slim/pydantic_ai/_utils.py b/pydantic_ai_slim/pydantic_ai/_utils.py index b2e01b1af..667727306 100644 --- 
a/pydantic_ai_slim/pydantic_ai/_utils.py +++ b/pydantic_ai_slim/pydantic_ai/_utils.py @@ -85,7 +85,7 @@ async def group_by_temporal( ) -> AsyncIterator[AsyncIterable[list[T]]]: """Group items from an async iterable into lists based on time interval between them. - Effectively debouncing the iterator. + Effectively, this debounces the iterator. This returns a context manager usable as an iterator so any pending tasks can be cancelled if an error occurs during iteration. diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 3501833d2..392443d73 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -5,14 +5,14 @@ import inspect from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager +from copy import deepcopy from types import FrameType from typing import Any, Callable, Generic, cast, final, overload import logfire_api from typing_extensions import TypeVar, deprecated -from pydantic_graph import Graph, GraphRunContext, HistoryStep -from pydantic_graph.nodes import End +from pydantic_graph import BaseNode, End, Graph, GraphRun, GraphRunContext from . import ( _agent_graph, @@ -25,8 +25,7 @@ result, usage as _usage, ) -from ._agent_graph import EndStrategy, capture_run_messages # imported for re-export -from .result import ResultDataT +from .result import FinalResult, ResultDataT, StreamedRunResult from .settings import ModelSettings, merge_model_settings from .tools import ( AgentDepsT, @@ -40,7 +39,24 @@ ToolPrepareFunc, ) -__all__ = 'Agent', 'capture_run_messages', 'EndStrategy' +# Re-exporting like this improves auto-import behavior in PyCharm +capture_run_messages = _agent_graph.capture_run_messages +EndStrategy = _agent_graph.EndStrategy +HandleResponseNode = _agent_graph.HandleResponseNode +ModelRequestNode = _agent_graph.ModelRequestNode +UserPromptNode = _agent_graph.UserPromptNode + + +__all__ = ( + 'Agent', + 'AgentRun', + 'AgentRunResult', + 'capture_run_messages', + 'EndStrategy', + 'HandleResponseNode', + 'ModelRequestNode', + 'UserPromptNode', +) _logfire = logfire_api.Logfire(otel_scope='pydantic-ai') @@ -214,7 +230,7 @@ async def run( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, - ) -> result.RunResult[ResultDataT]: ... + ) -> AgentRunResult[ResultDataT]: ... @overload async def run( @@ -229,23 +245,26 @@ async def run( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, - ) -> result.RunResult[RunResultDataT]: ... + ) -> AgentRunResult[RunResultDataT]: ... async def run( self, user_prompt: str, *, + result_type: type[RunResultDataT] | None = None, message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | None = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, - result_type: type[RunResultDataT] | None = None, infer_name: bool = True, - ) -> result.RunResult[Any]: + ) -> AgentRunResult[Any]: """Run the agent with a user prompt in async mode. + This method builds an internal agent graph (using system prompts, tools and result schemas) and then + runs the graph to completion. The result of the run is returned. 
+ Example: ```python from pydantic_ai import Agent @@ -253,15 +272,115 @@ async def run( agent = Agent('openai:gpt-4o') async def main(): - result = await agent.run('What is the capital of France?') - print(result.data) + agent_run = await agent.run('What is the capital of France?') + print(agent_run.data) #> Paris ``` Args: + user_prompt: User input to start/continue the conversation. result_type: Custom result type to use for this run, `result_type` may only be used if the agent has no result validators since result validators would expect an argument that matches the agent's result type. + message_history: History of the conversation so far. + model: Optional model to use for this run, required if `model` was not set when creating the agent. + deps: Optional dependencies to use for this run. + model_settings: Optional settings to use for this model's request. + usage_limits: Optional limits on model request count or token usage. + usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. + infer_name: Whether to try to infer the agent name from the call frame if it's not set. + + Returns: + The result of the run. + """ + if infer_name and self.name is None: + self._infer_name(inspect.currentframe()) + with self.iter( + user_prompt=user_prompt, + result_type=result_type, + message_history=message_history, + model=model, + deps=deps, + model_settings=model_settings, + usage_limits=usage_limits, + usage=usage, + ) as agent_run: + async for _ in agent_run: + pass + + assert (final_result := agent_run.result) is not None, 'The graph run did not finish properly' + return final_result + + @contextmanager + def iter( + self, + user_prompt: str, + *, + result_type: type[RunResultDataT] | None = None, + message_history: list[_messages.ModelMessage] | None = None, + model: models.Model | models.KnownModelName | None = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: _usage.UsageLimits | None = None, + usage: _usage.Usage | None = None, + infer_name: bool = True, + ) -> Iterator[AgentRun[AgentDepsT, Any]]: + """A contextmanager which can be used to iterate over the agent graph's nodes as they are executed. + + This method builds an internal agent graph (using system prompts, tools and result schemas) and then returns an + `AgentRun` object. The `AgentRun` can be used to async-iterate over the nodes of the graph as they are + executed. This is the API to use if you want to consume the outputs coming from each LLM model response, or the + stream of events coming from the execution of tools. + + The `AgentRun` also provides methods to access the full message history, new messages, and usage statistics, + and the final result of the run once it has completed. + + For more details, see the documentation of `AgentRun`. 
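To make the event-consumption angle concrete, here's a minimal sketch of filtering the iterated nodes by type, using the node classes re-exported from `pydantic_ai.agent` above (it assumes an `openai:gpt-4o` model is configured, as in the other examples; the counters are purely illustrative):

```python
from pydantic_ai import Agent
from pydantic_ai.agent import HandleResponseNode, ModelRequestNode

agent = Agent('openai:gpt-4o')


async def main():
    model_requests = 0
    model_responses = 0
    with agent.iter('What is the capital of France?') as agent_run:
        async for node in agent_run:
            # Branch on the node type to hook custom logic into specific stages of the run
            if isinstance(node, ModelRequestNode):
                model_requests += 1
            elif isinstance(node, HandleResponseNode):
                model_responses += 1
    print(model_requests, model_responses, agent_run.result.data)
```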
+ + Example: + ```python + from pydantic_ai import Agent + + agent = Agent('openai:gpt-4o') + + async def main(): + nodes = [] + with agent.iter('What is the capital of France?') as agent_run: + async for node in agent_run: + nodes.append(node) + print(nodes) + ''' + [ + ModelRequestNode( + request=ModelRequest( + parts=[ + UserPromptPart( + content='What is the capital of France?', + timestamp=datetime.datetime(...), + part_kind='user-prompt', + ) + ], + kind='request', + ) + ), + HandleResponseNode( + model_response=ModelResponse( + parts=[TextPart(content='Paris', part_kind='text')], + model_name='function:model_logic', + timestamp=datetime.datetime(...), + kind='response', + ) + ), + End(data=FinalResult(data='Paris', tool_name=None)), + ] + ''' + print(agent_run.result.data) + #> Paris + ``` + + Args: user_prompt: User input to start/continue the conversation. + result_type: Custom result type to use for this run, `result_type` may only be used if the agent has no + result validators since result validators would expect an argument that matches the agent's result type. message_history: History of the conversation so far. model: Optional model to use for this run, required if `model` was not set when creating the agent. deps: Optional dependencies to use for this run. @@ -305,54 +424,44 @@ async def main(): model_settings = merge_model_settings(self.model_settings, model_settings) usage_limits = usage_limits or _usage.UsageLimits() - with _logfire.span( + # Build the deps object for the graph + run_span = _logfire.span( '{agent_name} run {prompt=}', prompt=user_prompt, agent=self, model_name=model_used.model_name if model_used else 'no-model', agent_name=self.name or 'agent', - ) as run_span: - # Build the deps object for the graph - graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunResultDataT]( - user_deps=deps, - prompt=user_prompt, - new_message_index=new_message_index, - model=model_used, - model_settings=model_settings, - usage_limits=usage_limits, - max_result_retries=self._max_result_retries, - end_strategy=self.end_strategy, - result_schema=result_schema, - result_tools=self._result_schema.tool_defs() if self._result_schema else [], - result_validators=result_validators, - function_tools=self._function_tools, - run_span=run_span, - ) - - start_node = _agent_graph.UserPromptNode[AgentDepsT]( - user_prompt=user_prompt, - system_prompts=self._system_prompts, - system_prompt_functions=self._system_prompt_functions, - system_prompt_dynamic_functions=self._system_prompt_dynamic_functions, - ) - - # Actually run - end_result, _ = await graph.run( - start_node, - state=state, - deps=graph_deps, - infer_name=False, - ) - - # Build final run result - # We don't do any advanced checking if the data is actually from a final result or not - return result.RunResult( - state.message_history, - new_message_index, - end_result.data, - end_result.tool_name, - state.usage, ) + graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunResultDataT]( + user_deps=deps, + prompt=user_prompt, + new_message_index=new_message_index, + model=model_used, + model_settings=model_settings, + usage_limits=usage_limits, + max_result_retries=self._max_result_retries, + end_strategy=self.end_strategy, + result_schema=result_schema, + result_tools=self._result_schema.tool_defs() if self._result_schema else [], + result_validators=result_validators, + function_tools=self._function_tools, + run_span=run_span, + ) + start_node = _agent_graph.UserPromptNode[AgentDepsT]( + user_prompt=user_prompt, + 
system_prompts=self._system_prompts, + system_prompt_functions=self._system_prompt_functions, + system_prompt_dynamic_functions=self._system_prompt_dynamic_functions, + ) + + with graph.iter( + start_node, + state=state, + deps=graph_deps, + infer_name=False, + span=run_span, + ) as graph_run: + yield AgentRun(graph_run) @overload def run_sync( @@ -366,7 +475,7 @@ def run_sync( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, - ) -> result.RunResult[ResultDataT]: ... + ) -> AgentRunResult[ResultDataT]: ... @overload def run_sync( @@ -381,7 +490,7 @@ def run_sync( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, - ) -> result.RunResult[RunResultDataT]: ... + ) -> AgentRunResult[RunResultDataT]: ... def run_sync( self, @@ -395,8 +504,8 @@ def run_sync( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, - ) -> result.RunResult[Any]: - """Run the agent with a user prompt synchronously. + ) -> AgentRunResult[Any]: + """Synchronously run the agent with a user prompt. This is a convenience method that wraps [`self.run`][pydantic_ai.Agent.run] with `loop.run_until_complete(...)`. You therefore can't use this method inside async code or if there's an active event loop. @@ -413,9 +522,9 @@ def run_sync( ``` Args: + user_prompt: User input to start/continue the conversation. result_type: Custom result type to use for this run, `result_type` may only be used if the agent has no result validators since result validators would expect an argument that matches the agent's result type. - user_prompt: User input to start/continue the conversation. message_history: History of the conversation so far. model: Optional model to use for this run, required if `model` was not set when creating the agent. deps: Optional dependencies to use for this run. @@ -474,7 +583,7 @@ def run_stream( ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunResultDataT]]: ... @asynccontextmanager - async def run_stream( + async def run_stream( # noqa C901 self, user_prompt: str, *, @@ -502,9 +611,9 @@ async def main(): ``` Args: + user_prompt: User input to start/continue the conversation. result_type: Custom result type to use for this run, `result_type` may only be used if the agent has no result validators since result validators would expect an argument that matches the agent's result type. - user_prompt: User input to start/continue the conversation. message_history: History of the conversation so far. model: Optional model to use for this run, required if `model` was not set when creating the agent. deps: Optional dependencies to use for this run. @@ -516,94 +625,104 @@ async def main(): Returns: The result of the run. """ + # TODO: We need to deprecate this now that we have the `iter` method. + # Before that, though, we should add an event for when we reach the final result of the stream. 
if infer_name and self.name is None: # f_back because `asynccontextmanager` adds one frame if frame := inspect.currentframe(): # pragma: no branch self._infer_name(frame.f_back) - model_used = self._get_model(model) - deps = self._get_deps(deps) - new_message_index = len(message_history) if message_history else 0 - result_schema: _result.ResultSchema[RunResultDataT] | None = self._prepare_result_schema(result_type) - - # Build the graph - graph = self._build_stream_graph(result_type) - - # Build the initial state - graph_state = _agent_graph.GraphAgentState( - message_history=message_history[:] if message_history else [], - usage=usage or _usage.Usage(), - retries=0, - run_step=0, - ) - - # We consider it a user error if a user tries to restrict the result type while having a result validator that - # may change the result type from the restricted type to something else. Therefore, we consider the following - # typecast reasonable, even though it is possible to violate it with otherwise-type-checked code. - result_validators = cast(list[_result.ResultValidator[AgentDepsT, RunResultDataT]], self._result_validators) - - # TODO: Instead of this, copy the function tools to ensure they don't share current_retry state between agent - # runs. Requires some changes to `Tool` to make them copyable though. - for v in self._function_tools.values(): - v.current_retry = 0 - - model_settings = merge_model_settings(self.model_settings, model_settings) - usage_limits = usage_limits or _usage.UsageLimits() - - with _logfire.span( - '{agent_name} run stream {prompt=}', - prompt=user_prompt, - agent=self, - model_name=model_used.model_name if model_used else 'no-model', - agent_name=self.name or 'agent', - ) as run_span: - # Build the deps object for the graph - graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunResultDataT]( - user_deps=deps, - prompt=user_prompt, - new_message_index=new_message_index, - model=model_used, - model_settings=model_settings, - usage_limits=usage_limits, - max_result_retries=self._max_result_retries, - end_strategy=self.end_strategy, - result_schema=result_schema, - result_tools=self._result_schema.tool_defs() if self._result_schema else [], - result_validators=result_validators, - function_tools=self._function_tools, - run_span=run_span, - ) - - start_node = _agent_graph.StreamUserPromptNode[AgentDepsT]( - user_prompt=user_prompt, - system_prompts=self._system_prompts, - system_prompt_functions=self._system_prompt_functions, - system_prompt_dynamic_functions=self._system_prompt_dynamic_functions, - ) - - # Actually run - node = start_node - history: list[HistoryStep[_agent_graph.GraphAgentState, RunResultDataT]] = [] + yielded = False + with self.iter( + user_prompt, + result_type=result_type, + message_history=message_history, + model=model, + deps=deps, + model_settings=model_settings, + usage_limits=usage_limits, + usage=usage, + infer_name=False, + ) as agent_run: + first_node = agent_run.next_node # start with the first node + assert isinstance(first_node, _agent_graph.UserPromptNode) # the first node should be a user prompt node + node: BaseNode[Any, Any, Any] = cast(BaseNode[Any, Any, Any], first_node) while True: - if isinstance(node, _agent_graph.StreamModelRequestNode): - node = cast( - _agent_graph.StreamModelRequestNode[ - AgentDepsT, result.StreamedRunResult[AgentDepsT, RunResultDataT] - ], - node, - ) - async with node.run_to_result(GraphRunContext(graph_state, graph_deps)) as r: - if isinstance(r, End): - yield r.data + if isinstance(node, 
_agent_graph.ModelRequestNode): + node = cast(_agent_graph.ModelRequestNode[AgentDepsT, Any], node) + graph_ctx = agent_run.ctx + async with node._stream(graph_ctx) as streamed_response: # pyright: ignore[reportPrivateUsage] + + async def stream_to_final( + s: models.StreamedResponse, + ) -> FinalResult[models.StreamedResponse] | None: + result_schema = graph_ctx.deps.result_schema + async for maybe_part_event in streamed_response: + if isinstance(maybe_part_event, _messages.PartStartEvent): + new_part = maybe_part_event.part + if isinstance(new_part, _messages.TextPart): + if _agent_graph.allow_text_result(result_schema): + return FinalResult(s, None) + elif isinstance(new_part, _messages.ToolCallPart): + if result_schema is not None and (match := result_schema.find_tool([new_part])): + call, _ = match + return FinalResult(s, call.tool_name) + return None + + final_result_details = await stream_to_final(streamed_response) + if final_result_details is not None: + if yielded: + raise exceptions.AgentRunError('Agent run produced final results') + yielded = True + + messages = graph_ctx.state.message_history.copy() + + async def on_complete() -> None: + """Called when the stream has completed. + + The model response will have been added to messages by now + by `StreamedRunResult._marked_completed`. + """ + last_message = messages[-1] + assert isinstance(last_message, _messages.ModelResponse) + tool_calls = [ + part for part in last_message.parts if isinstance(part, _messages.ToolCallPart) + ] + + parts: list[_messages.ModelRequestPart] = [] + async for _event in _agent_graph.process_function_tools( + tool_calls, + final_result_details.tool_name, + graph_ctx, + parts, + ): + pass + # TODO: Should we do something here related to the retry count? + # Maybe we should move the incrementing of the retry count to where we actually make a request? 
+ # if any(isinstance(part, _messages.RetryPromptPart) for part in parts): + # ctx.state.increment_retries(ctx.deps.max_result_retries) + if parts: + messages.append(_messages.ModelRequest(parts)) + + yield StreamedRunResult( + messages, + graph_ctx.deps.new_message_index, + graph_ctx.deps.usage_limits, + streamed_response, + graph_ctx.deps.result_schema, + _agent_graph.build_run_context(graph_ctx), + graph_ctx.deps.result_validators, + final_result_details.tool_name, + on_complete, + ) break - assert not isinstance(node, End) # the previous line should be hit first - node = await graph.next( - node, - history, - state=graph_state, - deps=graph_deps, - infer_name=False, - ) + next_node = await agent_run.next(node) + if not isinstance(next_node, BaseNode): + raise exceptions.AgentRunError('Should have produced a StreamedRunResult before getting here') + node = cast(BaseNode[Any, Any, Any], next_node) + + if not yielded: + raise exceptions.AgentRunError('Agent run finished without producing a final result') @contextmanager def override( @@ -1039,14 +1158,9 @@ def last_run_messages(self) -> list[_messages.ModelMessage]: def _build_graph( self, result_type: type[RunResultDataT] | None - ) -> Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], Any]: + ) -> Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[Any]]: return _agent_graph.build_agent_graph(self.name, self._deps_type, result_type or self.result_type) - def _build_stream_graph( - self, result_type: type[RunResultDataT] | None - ) -> Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], Any]: - return _agent_graph.build_agent_stream_graph(self.name, self._deps_type, result_type or self.result_type) - def _prepare_result_schema( self, result_type: type[RunResultDataT] | None ) -> _result.ResultSchema[RunResultDataT] | None: @@ -1058,3 +1172,314 @@ def _prepare_result_schema( ) else: return self._result_schema # pyright: ignore[reportReturnType] + + +@dataclasses.dataclass(repr=False) +class AgentRun(Generic[AgentDepsT, ResultDataT]): + """A stateful, async-iterable run of an [`Agent`][pydantic_ai.agent.Agent]. + + You generally obtain an `AgentRun` instance by calling `with my_agent.iter(...) as agent_run:`. + + Once you have an instance, you can use it to iterate through the run's nodes as they execute. When an + [`End`][pydantic_graph.nodes.End] is reached, the run finishes and [`result`][pydantic_ai.agent.AgentRun.result] + becomes available. + + Example: + ```python + from pydantic_ai import Agent + + agent = Agent('openai:gpt-4o') + + async def main(): + nodes = [] + # Iterate through the run, recording each node along the way: + with agent.iter('What is the capital of France?') as agent_run: + async for node in agent_run: + nodes.append(node) + print(nodes) + ''' + [ + ModelRequestNode( + request=ModelRequest( + parts=[ + UserPromptPart( + content='What is the capital of France?', + timestamp=datetime.datetime(...), + part_kind='user-prompt', + ) + ], + kind='request', + ) + ), + HandleResponseNode( + model_response=ModelResponse( + parts=[TextPart(content='Paris', part_kind='text')], + model_name='function:model_logic', + timestamp=datetime.datetime(...), + kind='response', + ) + ), + End(data=FinalResult(data='Paris', tool_name=None)), + ] + ''' + print(agent_run.result.data) + #> Paris + ``` + + You can also manually drive the iteration using the [`next`][pydantic_ai.agent.AgentRun.next] method for + more granular control. 
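Because the run object stays live between nodes, intermediate state can also be inspected while iterating. A rough sketch of watching accumulated usage and message history grow (same `openai:gpt-4o` assumption as the other examples; the printed fields are illustrative only):

```python
from pydantic_ai import Agent

agent = Agent('openai:gpt-4o')


async def main():
    with agent.iter('What is the capital of France?') as agent_run:
        async for node in agent_run:
            # `usage()` and `ctx.state` reflect everything processed so far in the run
            print(type(node).__name__, agent_run.usage().requests, len(agent_run.ctx.state.message_history))
```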
+ """ + + _graph_run: GraphRun[ + _agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[ResultDataT] + ] + + @property + def ctx(self) -> GraphRunContext[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any]]: + """The current context of the agent run.""" + return GraphRunContext[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any]]( + self._graph_run.state, self._graph_run.deps + ) + + @property + def next_node( + self, + ) -> ( + BaseNode[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[ResultDataT]] + | End[FinalResult[ResultDataT]] + ): + """The next node that will be run in the agent graph. + + This is the next node that will be used during async iteration, or if a node is not passed to `self.next(...)`. + """ + return self._graph_run.next_node + + @property + def result(self) -> AgentRunResult[ResultDataT] | None: + """The final result of the run if it has ended, otherwise `None`. + + Once the run returns an [`End`][pydantic_graph.nodes.End] node, `result` is populated + with an [`AgentRunResult`][pydantic_ai.agent.AgentRunResult]. + """ + graph_run_result = self._graph_run.result + if graph_run_result is None: + return None + return AgentRunResult( + graph_run_result.output.data, + graph_run_result.output.tool_name, + graph_run_result.state, + self._graph_run.deps.new_message_index, + ) + + def __aiter__( + self, + ) -> AsyncIterator[ + BaseNode[ + _agent_graph.GraphAgentState, + _agent_graph.GraphAgentDeps[AgentDepsT, Any], + FinalResult[ResultDataT], + ] + | End[FinalResult[ResultDataT]] + ]: + """Provide async-iteration over the nodes in the agent run.""" + return self + + async def __anext__( + self, + ) -> ( + BaseNode[ + _agent_graph.GraphAgentState, + _agent_graph.GraphAgentDeps[AgentDepsT, Any], + FinalResult[ResultDataT], + ] + | End[FinalResult[ResultDataT]] + ): + """Advance to the next node automatically based on the last returned node.""" + return await self._graph_run.__anext__() + + async def next( + self, + node: BaseNode[ + _agent_graph.GraphAgentState, + _agent_graph.GraphAgentDeps[AgentDepsT, Any], + FinalResult[ResultDataT], + ], + ) -> ( + BaseNode[ + _agent_graph.GraphAgentState, + _agent_graph.GraphAgentDeps[AgentDepsT, Any], + FinalResult[ResultDataT], + ] + | End[FinalResult[ResultDataT]] + ): + """Manually drive the agent run by passing in the node you want to run next. + + This lets you inspect or mutate the node before continuing execution, or skip certain nodes + under dynamic conditions. The agent run should be stopped when you return an [`End`][pydantic_graph.nodes.End] + node. 
+ + Example: + ```python + from pydantic_ai import Agent + from pydantic_graph import End + + agent = Agent('openai:gpt-4o') + + async def main(): + with agent.iter('What is the capital of France?') as agent_run: + next_node = agent_run.next_node # start with the first node + nodes = [next_node] + while not isinstance(next_node, End): + next_node = await agent_run.next(next_node) + nodes.append(next_node) + # Once `next_node` is an End, we've finished: + print(nodes) + ''' + [ + UserPromptNode( + user_prompt='What is the capital of France?', + system_prompts=(), + system_prompt_functions=[], + system_prompt_dynamic_functions={}, + ), + ModelRequestNode( + request=ModelRequest( + parts=[ + UserPromptPart( + content='What is the capital of France?', + timestamp=datetime.datetime(...), + part_kind='user-prompt', + ) + ], + kind='request', + ) + ), + HandleResponseNode( + model_response=ModelResponse( + parts=[TextPart(content='Paris', part_kind='text')], + model_name='function:model_logic', + timestamp=datetime.datetime(...), + kind='response', + ) + ), + End(data=FinalResult(data='Paris', tool_name=None)), + ] + ''' + print('Final result:', agent_run.result.data) + #> Final result: Paris + ``` + + Args: + node: The node to run next in the graph. + + Returns: + The next node returned by the graph logic, or an [`End`][pydantic_graph.nodes.End] node if + the run has completed. + """ + # Note: It might be nice to expose a synchronous interface for iteration, but we shouldn't do it + # on this class, or else IDEs won't warn you if you accidentally use `for` instead of `async for` to iterate. + return await self._graph_run.next(node) + + def usage(self) -> _usage.Usage: + """Get usage statistics for the run so far, including token usage, model requests, and so on.""" + return self._graph_run.state.usage + + def __repr__(self) -> str: + result = self._graph_run.result + result_repr = '' if result is None else repr(result.output) + return f'<{type(self).__name__} result={result_repr} usage={self.usage()}>' + + +@dataclasses.dataclass +class AgentRunResult(Generic[ResultDataT]): + """The final result of an agent run.""" + + data: ResultDataT # TODO: rename this to output. I'm putting this off for now mostly to reduce the size of the diff + + _result_tool_name: str | None = dataclasses.field(repr=False) + _state: _agent_graph.GraphAgentState = dataclasses.field(repr=False) + _new_message_index: int = dataclasses.field(repr=False) + + def _set_result_tool_return(self, return_content: str) -> list[_messages.ModelMessage]: + """Set return content for the result tool. + + Useful if you want to continue the conversation and want to set the response to the result tool call. + """ + if not self._result_tool_name: + raise ValueError('Cannot set result tool return content when the return type is `str`.') + messages = deepcopy(self._state.message_history) + last_message = messages[-1] + for part in last_message.parts: + if isinstance(part, _messages.ToolReturnPart) and part.tool_name == self._result_tool_name: + part.content = return_content + return messages + raise LookupError(f'No tool call found with tool name {self._result_tool_name!r}.') + + def all_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]: + """Return the history of _messages. + + Args: + result_tool_return_content: The return content of the tool call to set in the last message. 
+ This provides a convenient way to modify the content of the result tool call if you want to continue + the conversation and want to set the response to the result tool call. If `None`, the last message will + not be modified. + + Returns: + List of messages. + """ + if result_tool_return_content is not None: + return self._set_result_tool_return(result_tool_return_content) + else: + return self._state.message_history + + def all_messages_json(self, *, result_tool_return_content: str | None = None) -> bytes: + """Return all messages from [`all_messages`][pydantic_ai.agent.AgentRunResult.all_messages] as JSON bytes. + + Args: + result_tool_return_content: The return content of the tool call to set in the last message. + This provides a convenient way to modify the content of the result tool call if you want to continue + the conversation and want to set the response to the result tool call. If `None`, the last message will + not be modified. + + Returns: + JSON bytes representing the messages. + """ + return _messages.ModelMessagesTypeAdapter.dump_json( + self.all_messages(result_tool_return_content=result_tool_return_content) + ) + + def new_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]: + """Return new messages associated with this run. + + Messages from older runs are excluded. + + Args: + result_tool_return_content: The return content of the tool call to set in the last message. + This provides a convenient way to modify the content of the result tool call if you want to continue + the conversation and want to set the response to the result tool call. If `None`, the last message will + not be modified. + + Returns: + List of new messages. + """ + return self.all_messages(result_tool_return_content=result_tool_return_content)[self._new_message_index :] + + def new_messages_json(self, *, result_tool_return_content: str | None = None) -> bytes: + """Return new messages from [`new_messages`][pydantic_ai.agent.AgentRunResult.new_messages] as JSON bytes. + + Args: + result_tool_return_content: The return content of the tool call to set in the last message. + This provides a convenient way to modify the content of the result tool call if you want to continue + the conversation and want to set the response to the result tool call. If `None`, the last message will + not be modified. + + Returns: + JSON bytes representing the new messages. 
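These message accessors make it straightforward to chain runs into a conversation: pass one run's messages as the `message_history` of the next. A minimal sketch (prompts and model name are illustrative):

```python
from pydantic_ai import Agent

agent = Agent('openai:gpt-4o')


async def main():
    first = await agent.run('What is the capital of France?')
    # Feed the messages from the first run back in so the follow-up question has context
    second = await agent.run('How many people live there?', message_history=first.new_messages())
    print(second.data)
```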
+ """ + return _messages.ModelMessagesTypeAdapter.dump_json( + self.new_messages(result_tool_return_content=result_tool_return_content) + ) + + def usage(self) -> _usage.Usage: + """Return the usage of the whole run.""" + return self._state.usage diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index d3001bf52..c6775c838 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -1,5 +1,6 @@ from __future__ import annotations as _annotations +import uuid from dataclasses import dataclass, field, replace from datetime import datetime from typing import Annotated, Any, Literal, Union, cast, overload @@ -445,3 +446,33 @@ class PartDeltaEvent: ModelResponseStreamEvent = Annotated[Union[PartStartEvent, PartDeltaEvent], pydantic.Discriminator('event_kind')] """An event in the model response stream, either starting a new part or applying a delta to an existing one.""" + + +@dataclass +class FunctionToolCallEvent: + """An event indicating the start to a call to a function tool.""" + + part: ToolCallPart + """The (function) tool call to make.""" + call_id: str = field(init=False) + """An ID used for matching details about the call to its result. If present, defaults to the part's tool_call_id.""" + event_kind: Literal['function_tool_call'] = 'function_tool_call' + """Event type identifier, used as a discriminator.""" + + def __post_init__(self): + self.call_id = self.part.tool_call_id or str(uuid.uuid4()) + + +@dataclass +class FunctionToolResultEvent: + """An event indicating the result of a function tool call.""" + + result: ToolReturnPart | RetryPromptPart + """The result of the call to the function tool.""" + call_id: str + """An ID used to match the result to its original call.""" + event_kind: Literal['function_tool_result'] = 'function_tool_result' + """Event type identifier, used as a discriminator.""" + + +HandleResponseEvent = Annotated[Union[FunctionToolCallEvent, FunctionToolResultEvent], pydantic.Discriminator('kind')] diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index 9c694a859..eef023c97 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -234,6 +234,8 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: This method should be implemented by subclasses to translate the vendor-specific stream of events into pydantic_ai-format events. + + It should use the `_parts_manager` to handle deltas, and should update the `_usage` attributes as it goes. 
""" raise NotImplementedError() # noinspection PyUnreachableCode diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index 4b69c634c..7646de5bf 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -1,8 +1,7 @@ from __future__ import annotations as _annotations -from abc import ABC, abstractmethod from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable -from copy import deepcopy +from copy import copy from dataclasses import dataclass, field from datetime import datetime from typing import Generic, Union, cast @@ -14,7 +13,7 @@ from .tools import AgentDepsT, RunContext from .usage import Usage, UsageLimits -__all__ = 'ResultDataT', 'ResultDataT_inv', 'ResultValidatorFunc', 'RunResult', 'StreamedRunResult' +__all__ = 'ResultDataT', 'ResultDataT_inv', 'ResultValidatorFunc' T = TypeVar('T') @@ -53,15 +52,34 @@ @dataclass -class _BaseRunResult(ABC, Generic[ResultDataT]): - """Base type for results. - - You should not import or use this type directly, instead use its subclasses `RunResult` and `StreamedRunResult`. - """ +class StreamedRunResult(Generic[AgentDepsT, ResultDataT]): + """Result of a streamed run that returns structured data via a tool call.""" _all_messages: list[_messages.ModelMessage] _new_message_index: int + _usage_limits: UsageLimits | None + _stream_response: models.StreamedResponse + _result_schema: _result.ResultSchema[ResultDataT] | None + _run_ctx: RunContext[AgentDepsT] + _result_validators: list[_result.ResultValidator[AgentDepsT, ResultDataT]] + _result_tool_name: str | None + _on_complete: Callable[[], Awaitable[None]] + + _initial_run_ctx_usage: Usage = field(init=False) + is_complete: bool = field(default=False, init=False) + """Whether the stream has all been received. + + This is set to `True` when one of + [`stream`][pydantic_ai.result.StreamedRunResult.stream], + [`stream_text`][pydantic_ai.result.StreamedRunResult.stream_text], + [`stream_structured`][pydantic_ai.result.StreamedRunResult.stream_structured] or + [`get_data`][pydantic_ai.result.StreamedRunResult.get_data] completes. + """ + + def __post_init__(self): + self._initial_run_ctx_usage = copy(self._run_ctx.usage) + def all_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]: """Return the history of _messages. @@ -80,7 +98,7 @@ def all_messages(self, *, result_tool_return_content: str | None = None) -> list return self._all_messages def all_messages_json(self, *, result_tool_return_content: str | None = None) -> bytes: - """Return all messages from [`all_messages`][pydantic_ai.result._BaseRunResult.all_messages] as JSON bytes. + """Return all messages from [`all_messages`][pydantic_ai.result.StreamedRunResult.all_messages] as JSON bytes. Args: result_tool_return_content: The return content of the tool call to set in the last message. @@ -112,7 +130,7 @@ def new_messages(self, *, result_tool_return_content: str | None = None) -> list return self.all_messages(result_tool_return_content=result_tool_return_content)[self._new_message_index :] def new_messages_json(self, *, result_tool_return_content: str | None = None) -> bytes: - """Return new messages from [`new_messages`][pydantic_ai.result._BaseRunResult.new_messages] as JSON bytes. + """Return new messages from [`new_messages`][pydantic_ai.result.StreamedRunResult.new_messages] as JSON bytes. Args: result_tool_return_content: The return content of the tool call to set in the last message. 
@@ -127,78 +145,6 @@ def new_messages_json(self, *, result_tool_return_content: str | None = None) -> self.new_messages(result_tool_return_content=result_tool_return_content) ) - @abstractmethod - def usage(self) -> Usage: - raise NotImplementedError() - - -@dataclass -class RunResult(_BaseRunResult[ResultDataT]): - """Result of a non-streamed run.""" - - data: ResultDataT - """Data from the final response in the run.""" - _result_tool_name: str | None - _usage: Usage - - def usage(self) -> Usage: - """Return the usage of the whole run.""" - return self._usage - - def all_messages(self, *, result_tool_return_content: str | None = None) -> list[_messages.ModelMessage]: - """Return the history of _messages. - - Args: - result_tool_return_content: The return content of the tool call to set in the last message. - This provides a convenient way to modify the content of the result tool call if you want to continue - the conversation and want to set the response to the result tool call. If `None`, the last message will - not be modified. - - Returns: - List of messages. - """ - if result_tool_return_content is not None: - return self._set_result_tool_return(result_tool_return_content) - else: - return self._all_messages - - def _set_result_tool_return(self, return_content: str) -> list[_messages.ModelMessage]: - """Set return content for the result tool. - - Useful if you want to continue the conversation and want to set the response to the result tool call. - """ - if not self._result_tool_name: - raise ValueError('Cannot set result tool return content when the return type is `str`.') - messages = deepcopy(self._all_messages) - last_message = messages[-1] - for part in last_message.parts: - if isinstance(part, _messages.ToolReturnPart) and part.tool_name == self._result_tool_name: - part.content = return_content - return messages - raise LookupError(f'No tool call found with tool name {self._result_tool_name!r}.') - - -@dataclass -class StreamedRunResult(_BaseRunResult[ResultDataT], Generic[AgentDepsT, ResultDataT]): - """Result of a streamed run that returns structured data via a tool call.""" - - _usage_limits: UsageLimits | None - _stream_response: models.StreamedResponse - _result_schema: _result.ResultSchema[ResultDataT] | None - _run_ctx: RunContext[AgentDepsT] - _result_validators: list[_result.ResultValidator[AgentDepsT, ResultDataT]] - _result_tool_name: str | None - _on_complete: Callable[[], Awaitable[None]] - is_complete: bool = field(default=False, init=False) - """Whether the stream has all been received. - - This is set to `True` when one of - [`stream`][pydantic_ai.result.StreamedRunResult.stream], - [`stream_text`][pydantic_ai.result.StreamedRunResult.stream_text], - [`stream_structured`][pydantic_ai.result.StreamedRunResult.stream_structured] or - [`get_data`][pydantic_ai.result.StreamedRunResult.get_data] completes. - """ - async def stream(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[ResultDataT]: """Stream the response as an async iterable. @@ -234,61 +180,17 @@ async def stream_text(self, *, delta: bool = False, debounce_by: float | None = if self._result_schema and not self._result_schema.allow_text_result: raise exceptions.UserError('stream_text() can only be used with text responses') - usage_checking_stream = _get_usage_checking_stream_response( - self._stream_response, self._usage_limits, self.usage - ) - - # Define a "merged" version of the iterator that will yield items that have already been retrieved - # and items that we receive while streaming. 
We define a dedicated async iterator for this so we can - # pass the combined stream to the group_by_temporal function within `_stream_text_deltas` below. - async def _stream_text_deltas_ungrouped() -> AsyncIterator[tuple[str, int]]: - # if the response currently has any parts with content, yield those before streaming - msg = self._stream_response.get() - for i, part in enumerate(msg.parts): - if isinstance(part, _messages.TextPart) and part.content: - yield part.content, i - - async for event in usage_checking_stream: - if ( - isinstance(event, _messages.PartStartEvent) - and isinstance(event.part, _messages.TextPart) - and event.part.content - ): - yield event.part.content, event.index - elif ( - isinstance(event, _messages.PartDeltaEvent) - and isinstance(event.delta, _messages.TextPartDelta) - and event.delta.content_delta - ): - yield event.delta.content_delta, event.index - - async def _stream_text_deltas() -> AsyncIterator[str]: - async with _utils.group_by_temporal(_stream_text_deltas_ungrouped(), debounce_by) as group_iter: - async for items in group_iter: - yield ''.join([content for content, _ in items]) - with _logfire.span('response stream text') as lf_span: if delta: - async for text in _stream_text_deltas(): + async for text in self._stream_response_text(delta=delta, debounce_by=debounce_by): yield text else: - # a quick benchmark shows it's faster to build up a string with concat when we're - # yielding at each step - deltas: list[str] = [] combined_validated_text = '' - async for text in _stream_text_deltas(): - deltas.append(text) - combined_text = ''.join(deltas) - combined_validated_text = await self._validate_text_result(combined_text) + async for text in self._stream_response_text(delta=delta, debounce_by=debounce_by): + combined_validated_text = await self._validate_text_result(text) yield combined_validated_text - lf_span.set_attribute('combined_text', combined_validated_text) - await self._marked_completed( - _messages.ModelResponse( - parts=[_messages.TextPart(combined_validated_text)], - model_name=self._stream_response.model_name, - ) - ) + await self._marked_completed(self._stream_response.get()) async def stream_structured( self, *, debounce_by: float | None = 0.1 @@ -303,10 +205,6 @@ async def stream_structured( Returns: An async iterable of the structured response message and whether that is the last message. """ - usage_checking_stream = _get_usage_checking_stream_response( - self._stream_response, self._usage_limits, self.usage - ) - with _logfire.span('response stream structured') as lf_span: # if the message currently has any parts with content, yield before streaming msg = self._stream_response.get() @@ -315,15 +213,14 @@ async def stream_structured( yield msg, False break - async with _utils.group_by_temporal(usage_checking_stream, debounce_by) as group_iter: - async for _events in group_iter: - msg = self._stream_response.get() - yield msg, False - msg = self._stream_response.get() - yield msg, True - # TODO: Should this now be `final_response` instead of `structured_response`? - lf_span.set_attribute('structured_response', msg) - await self._marked_completed(msg) + async for msg in self._stream_response_structured(debounce_by=debounce_by): + yield msg, False + + msg = self._stream_response.get() + yield msg, True + + lf_span.set_attribute('structured_response', msg) + await self._marked_completed(msg) async def get_data(self) -> ResultDataT: """Stream the whole response, validate and return it.""" @@ -343,7 +240,7 @@ def usage(self) -> Usage: !!! 
note This won't return the full usage until the stream is finished. """ - return self._run_ctx.usage + self._stream_response.usage() + return self._initial_run_ctx_usage + self._stream_response.usage() def timestamp(self) -> datetime: """Get the timestamp of the response.""" @@ -391,6 +288,71 @@ async def _marked_completed(self, message: _messages.ModelResponse) -> None: self._all_messages.append(message) await self._on_complete() + async def _stream_response_structured( + self, *, debounce_by: float | None = 0.1 + ) -> AsyncIterator[_messages.ModelResponse]: + async with _utils.group_by_temporal(self._stream_response, debounce_by) as group_iter: + async for _items in group_iter: + yield self._stream_response.get() + + async def _stream_response_text( + self, *, delta: bool = False, debounce_by: float | None = 0.1 + ) -> AsyncIterator[str]: + """Stream the response as an async iterable of text.""" + + # Define a "merged" version of the iterator that will yield items that have already been retrieved + # and items that we receive while streaming. We define a dedicated async iterator for this so we can + # pass the combined stream to the group_by_temporal function within `_stream_text_deltas` below. + async def _stream_text_deltas_ungrouped() -> AsyncIterator[tuple[str, int]]: + # yields tuples of (text_content, part_index) + # we don't currently make use of the part_index, but in principle this may be useful + # so we retain it here for now to make possible future refactors simpler + msg = self._stream_response.get() + for i, part in enumerate(msg.parts): + if isinstance(part, _messages.TextPart) and part.content: + yield part.content, i + + async for event in self._stream_response: + if ( + isinstance(event, _messages.PartStartEvent) + and isinstance(event.part, _messages.TextPart) + and event.part.content + ): + yield event.part.content, event.index + elif ( + isinstance(event, _messages.PartDeltaEvent) + and isinstance(event.delta, _messages.TextPartDelta) + and event.delta.content_delta + ): + yield event.delta.content_delta, event.index + + async def _stream_text_deltas() -> AsyncIterator[str]: + async with _utils.group_by_temporal(_stream_text_deltas_ungrouped(), debounce_by) as group_iter: + async for items in group_iter: + # Note: we are currently just dropping the part index on the group here + yield ''.join([content for content, _ in items]) + + if delta: + async for text in _stream_text_deltas(): + yield text + else: + # a quick benchmark shows it's faster to build up a string with concat when we're + # yielding at each step + deltas: list[str] = [] + async for text in _stream_text_deltas(): + deltas.append(text) + yield ''.join(deltas) + + +@dataclass +class FinalResult(Generic[ResultDataT]): + """Marker class storing the final result of an agent run and associated metadata.""" + + data: ResultDataT + """The final result data.""" + tool_name: str | None + """Name of the final result tool; `None` if the result came from unstructured text content.""" + def _get_usage_checking_stream_response( stream_response: AsyncIterable[_messages.ModelResponseStreamEvent], diff --git a/pydantic_graph/README.md b/pydantic_graph/README.md index 15a4062e0..29b43cca9 100644 --- a/pydantic_graph/README.md +++ b/pydantic_graph/README.md @@ -50,10 +50,10 @@ class Increment(BaseNode): fives_graph = Graph(nodes=[DivisibleBy5, Increment]) -result, history = fives_graph.run_sync(DivisibleBy5(4)) -print(result) +result = fives_graph.run_sync(DivisibleBy5(4)) +print(result.output) #> 5 # the full history is 
quite verbose (see below), so we'll just print the summary -print([item.data_snapshot() for item in history]) +print([item.data_snapshot() for item in result.history]) #> [DivisibleBy5(foo=4), Increment(foo=4), DivisibleBy5(foo=5), End(data=5)] ``` diff --git a/pydantic_graph/pydantic_graph/__init__.py b/pydantic_graph/pydantic_graph/__init__.py index d4c6074e1..079325f59 100644 --- a/pydantic_graph/pydantic_graph/__init__.py +++ b/pydantic_graph/pydantic_graph/__init__.py @@ -1,10 +1,12 @@ from .exceptions import GraphRuntimeError, GraphSetupError -from .graph import Graph +from .graph import Graph, GraphRun, GraphRunResult from .nodes import BaseNode, Edge, End, GraphRunContext from .state import EndStep, HistoryStep, NodeStep __all__ = ( 'Graph', + 'GraphRun', + 'GraphRunResult', 'BaseNode', 'End', 'GraphRunContext', diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index a670c3d39..a8f3897d4 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -3,17 +3,17 @@ import asyncio import inspect import types -from collections.abc import Sequence -from contextlib import ExitStack +from collections.abc import AsyncIterator, Iterator, Sequence +from contextlib import ExitStack, contextmanager from dataclasses import dataclass, field from functools import cached_property -from pathlib import Path from time import perf_counter from typing import TYPE_CHECKING, Annotated, Any, Callable, Generic, TypeVar import logfire_api import pydantic import typing_extensions +from logfire_api import LogfireSpan from . import _utils, exceptions, mermaid from .nodes import BaseNode, DepsT, End, GraphRunContext, NodeDef, RunEndT @@ -30,7 +30,7 @@ logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),) -__all__ = ('Graph',) +__all__ = ('Graph', 'GraphRun', 'GraphRunResult') _logfire = logfire_api.Logfire(otel_scope='pydantic-graph') @@ -133,7 +133,8 @@ async def run( state: StateT = None, deps: DepsT = None, infer_name: bool = True, - ) -> tuple[T, list[HistoryStep[StateT, T]]]: + span: LogfireSpan | None = None, + ) -> GraphRunResult[StateT, T]: """Run the graph from a starting node until it ends. Args: @@ -142,9 +143,11 @@ async def run( state: The initial state of the graph. deps: The dependencies of the graph. infer_name: Whether to infer the graph name from the calling frame. + span: The span to use for the graph run. If not provided, a span will be created depending on the value of + the `_auto_instrument` field. Returns: - The result type from ending the run and the history of the run. + A `GraphRunResult` containing information about the run, including its final result. 
Here's an example of running the graph from [above][pydantic_graph.graph.Graph]: @@ -153,50 +156,84 @@ async def run( async def main(): state = MyState(1) - _, history = await never_42_graph.run(Increment(), state=state) + graph_run_result = await never_42_graph.run(Increment(), state=state) print(state) #> MyState(number=2) - print(len(history)) + print(len(graph_run_result.history)) #> 3 state = MyState(41) - _, history = await never_42_graph.run(Increment(), state=state) + graph_run_result = await never_42_graph.run(Increment(), state=state) print(state) #> MyState(number=43) - print(len(history)) + print(len(graph_run_result.history)) #> 5 ``` """ if infer_name and self.name is None: self._infer_name(inspect.currentframe()) - history: list[HistoryStep[StateT, T]] = [] - with ExitStack() as stack: - run_span: logfire_api.LogfireSpan | None = None - if self._auto_instrument: - run_span = stack.enter_context( - _logfire.span( - '{graph_name} run {start=}', - graph_name=self.name or 'graph', - start=start_node, - ) - ) + with self.iter(start_node, state=state, deps=deps, infer_name=infer_name, span=span) as graph_run: + async for _node in graph_run: + pass + + final_result = graph_run.result + assert final_result is not None, 'GraphRun should have a final result' + return final_result + + @contextmanager + def iter( + self: Graph[StateT, DepsT, T], + start_node: BaseNode[StateT, DepsT, T], + *, + state: StateT = None, + deps: DepsT = None, + infer_name: bool = True, + span: LogfireSpan | None = None, + ) -> Iterator[GraphRun[StateT, DepsT, T]]: + """A contextmanager which can be used to iterate over the graph's nodes as they are executed. + + This method returns a `GraphRun` object which can be used to async-iterate over the nodes of this `Graph` as + they are executed. This is the API to use if you want to record or interact with the nodes as the graph + execution unfolds. + + The `GraphRun` can also be used to manually drive the graph execution by calling + [`GraphRun.next`][pydantic_graph.graph.GraphRun.next]. + + The `GraphRun` provides access to the full run history, state, deps, and the final result of the run once + it has completed. - next_node = start_node - while True: - next_node = await self.next(next_node, history, state=state, deps=deps, infer_name=False) - if isinstance(next_node, End): - history.append(EndStep(result=next_node)) - if run_span is not None: - run_span.set_attribute('history', history) - return next_node.data, history - elif not isinstance(next_node, BaseNode): - if TYPE_CHECKING: - typing_extensions.assert_never(next_node) - else: - raise exceptions.GraphRuntimeError( - f'Invalid node return type: `{type(next_node).__name__}`. Expected `BaseNode` or `End`.' - ) + For more details, see the API documentation of [`GraphRun`][pydantic_graph.graph.GraphRun]. + + Args: + start_node: the first node to run. Since the graph definition doesn't define the entry point in the graph, + you need to provide the starting node. + state: The initial state of the graph. + deps: The dependencies of the graph. + infer_name: Whether to infer the graph name from the calling frame. + span: The span to use for the graph run. If not provided, a new span will be created. + + Yields: + A GraphRun that can be async iterated over to drive the graph to completion. 
+ """ + if infer_name and self.name is None: + self._infer_name(inspect.currentframe()) + + if self._auto_instrument and span is None: + span = logfire_api.span('run graph {graph.name}', graph=self) + + with ExitStack() as stack: + if span is not None: + stack.enter_context(span) + yield GraphRun[StateT, DepsT, T]( + self, + start_node, + history=[], + state=state, + deps=deps, + auto_instrument=self._auto_instrument, + span=span, + ) def run_sync( self: Graph[StateT, DepsT, T], @@ -205,8 +242,8 @@ def run_sync( state: StateT = None, deps: DepsT = None, infer_name: bool = True, - ) -> tuple[T, list[HistoryStep[StateT, T]]]: - """Run the graph synchronously. + ) -> GraphRunResult[StateT, T]: + """Synchronously run the graph. This is a convenience method that wraps [`self.run`][pydantic_graph.Graph.run] with `loop.run_until_complete(...)`. You therefore can't use this method inside async code or if there's an active event loop. @@ -250,6 +287,12 @@ async def next( """ if infer_name and self.name is None: self._infer_name(inspect.currentframe()) + + if isinstance(node, End): + # While technically this is not compatible with the documented method signature, it's an easy mistake to + # make, and we should eagerly provide a more helpful error message than you'd get otherwise. + raise exceptions.GraphRuntimeError(f'Cannot call `next` with an `End` node: {node!r}.') + node_id = node.get_id() if node_id not in self.node_defs: raise exceptions.GraphRuntimeError(f'Node `{node}` is not in the graph.') @@ -266,6 +309,17 @@ async def next( history.append( NodeStep(state=state, node=node, start_ts=start_ts, duration=duration, snapshot_state=self.snapshot_state) ) + + if isinstance(next_node, End): + history.append(EndStep(result=next_node)) + elif not isinstance(next_node, BaseNode): + if TYPE_CHECKING: + typing_extensions.assert_never(next_node) + else: + raise exceptions.GraphRuntimeError( + f'Invalid node return type: `{type(next_node).__name__}`. Expected `BaseNode` or `End`.' + ) + return next_node def dump_history( @@ -336,7 +390,7 @@ def mermaid_code( Here's an example of generating a diagram for the graph from [above][pydantic_graph.graph.Graph]: - ```py {title="never_42.py" py="3.10"} + ```py {title="mermaid_never_42.py" py="3.10"} from never_42 import Increment, never_42_graph print(never_42_graph.mermaid_code(start_node=Increment)) @@ -510,3 +564,187 @@ def _infer_name(self, function_frame: types.FrameType | None) -> None: if item is self: self.name = name return + + +class GraphRun(Generic[StateT, DepsT, RunEndT]): + """A stateful, async-iterable run of a [`Graph`][pydantic_graph.graph.Graph]. + + You typically get a `GraphRun` instance from calling + `with [my_graph.iter(...)][pydantic_graph.graph.Graph.iter] as graph_run:`. That gives you the ability to iterate + through nodes as they run, either by `async for` iteration or by repeatedly calling `.next(...)`. 
+ + Here's an example of iterating over the graph from [above][pydantic_graph.graph.Graph]: + ```py {title="iter_never_42.py" noqa="I001" py="3.10"} + from copy import deepcopy + from never_42 import Increment, MyState, never_42_graph + + async def main(): + state = MyState(1) + with never_42_graph.iter(Increment(), state=state) as graph_run: + node_states = [(graph_run.next_node, deepcopy(graph_run.state))] + async for node in graph_run: + node_states.append((node, deepcopy(graph_run.state))) + print(node_states) + ''' + [ + (Increment(), MyState(number=1)), + (Check42(), MyState(number=2)), + (End(data=2), MyState(number=2)), + ] + ''' + + state = MyState(41) + with never_42_graph.iter(Increment(), state=state) as graph_run: + node_states = [(graph_run.next_node, deepcopy(graph_run.state))] + async for node in graph_run: + node_states.append((node, deepcopy(graph_run.state))) + print(node_states) + ''' + [ + (Increment(), MyState(number=41)), + (Check42(), MyState(number=42)), + (Increment(), MyState(number=42)), + (Check42(), MyState(number=43)), + (End(data=43), MyState(number=43)), + ] + ''' + ``` + + See the [`GraphRun.next` documentation][pydantic_graph.graph.GraphRun.next] for an example of how to manually + drive the graph run. + """ + + def __init__( + self, + graph: Graph[StateT, DepsT, RunEndT], + start_node: BaseNode[StateT, DepsT, RunEndT], + *, + history: list[HistoryStep[StateT, RunEndT]], + state: StateT, + deps: DepsT, + auto_instrument: bool, + span: LogfireSpan | None = None, + ): + """Create a new run for a given graph, starting at the specified node. + + Typically, you'll use [`Graph.iter`][pydantic_graph.graph.Graph.iter] rather than calling this directly. + + Args: + graph: The [`Graph`][pydantic_graph.graph.Graph] to run. + start_node: The node where execution will begin. + history: A list of [`HistoryStep`][pydantic_graph.state.HistoryStep] objects that describe + each step of the run. Usually starts empty; can be populated if resuming. + state: A shared state object or primitive (like a counter, dataclass, etc.) that is available + to all nodes via `ctx.state`. + deps: Optional dependencies that each node can access via `ctx.deps`, e.g. database connections, + configuration, or logging clients. + auto_instrument: Whether to automatically create instrumentation spans during the run. + span: An optional existing Logfire span to nest node-level spans under (advanced usage). + """ + self.graph = graph + self.history = history + self.state = state + self.deps = deps + self._auto_instrument = auto_instrument + self._span = span + + self._next_node: BaseNode[StateT, DepsT, RunEndT] | End[RunEndT] = start_node + + @property + def next_node(self) -> BaseNode[StateT, DepsT, RunEndT] | End[RunEndT]: + """The next node that will be run in the graph. + + This is the next node that will be used during async iteration, or if a node is not passed to `self.next(...)`. + """ + return self._next_node + + @property + def result(self) -> GraphRunResult[StateT, RunEndT] | None: + """The final result of the graph run if the run is completed, otherwise `None`.""" + if not isinstance(self._next_node, End): + return None # The GraphRun has not finished running + return GraphRunResult( + self._next_node.data, + state=self.state, + history=self.history, + ) + + async def next( + self: GraphRun[StateT, DepsT, T], node: BaseNode[StateT, DepsT, T] | None = None + ) -> BaseNode[StateT, DepsT, T] | End[T]: + """Manually drive the graph run by passing in the node you want to run next. 
+
+        This lets you inspect or mutate the node before continuing execution, or skip certain nodes
+        under dynamic conditions. The graph run should stop when you return an [`End`][pydantic_graph.nodes.End] node.
+
+        Here's an example of using `next` to drive the graph from [above][pydantic_graph.graph.Graph]:
+        ```py {title="next_never_42.py" noqa="I001" py="3.10"}
+        from copy import deepcopy
+        from pydantic_graph import End
+        from never_42 import Increment, MyState, never_42_graph
+
+        async def main():
+            state = MyState(48)
+            with never_42_graph.iter(Increment(), state=state) as graph_run:
+                next_node = graph_run.next_node  # start with the first node
+                node_states = [(next_node, deepcopy(graph_run.state))]
+
+                while not isinstance(next_node, End):
+                    if graph_run.state.number == 50:
+                        graph_run.state.number = 42
+                    next_node = await graph_run.next(next_node)
+                    node_states.append((next_node, deepcopy(graph_run.state)))
+
+                print(node_states)
+                '''
+                [
+                    (Increment(), MyState(number=48)),
+                    (Check42(), MyState(number=49)),
+                    (End(data=49), MyState(number=49)),
+                ]
+                '''
+        ```
+
+        Args:
+            node: The node to run next in the graph. If not specified, uses `self.next_node`, which is initialized to
+                the `start_node` of the run and updated each time a new node is returned.
+
+        Returns:
+            The next node returned by the graph logic, or an [`End`][pydantic_graph.nodes.End] node if
+            the run has completed.
+        """
+        if node is None:
+            if isinstance(self._next_node, End):
+                # Note: we could alternatively just return `self._next_node` here, but it's easier to start with an
+                # error and relax the behavior later, than vice versa.
+                raise exceptions.GraphRuntimeError('This graph run has already ended.')
+            node = self._next_node
+
+        history = self.history
+        state = self.state
+        deps = self.deps
+
+        self._next_node = await self.graph.next(node, history, state=state, deps=deps, infer_name=False)
+
+        return self._next_node
+
+    def __aiter__(self) -> AsyncIterator[BaseNode[StateT, DepsT, RunEndT] | End[RunEndT]]:
+        return self
+
+    async def __anext__(self) -> BaseNode[StateT, DepsT, RunEndT] | End[RunEndT]:
+        """Use the last returned node as the input to `Graph.next`."""
+        if isinstance(self._next_node, End):
+            raise StopAsyncIteration
+        return await self.next(self._next_node)
+
+    def __repr__(self) -> str:
+        return f'<GraphRun graph={self.graph.name or "<unnamed>"} step={len(self.history) + 1}>'
+
+
+@dataclass
+class GraphRunResult(Generic[StateT, RunEndT]):
+    """The final result of running a graph."""
+
+    output: RunEndT
+    state: StateT
+    history: list[HistoryStep[StateT, RunEndT]] = field(repr=False)
diff --git a/pydantic_graph/pydantic_graph/nodes.py b/pydantic_graph/pydantic_graph/nodes.py
index b43391ffe..f28106c97 100644
--- a/pydantic_graph/pydantic_graph/nodes.py
+++ b/pydantic_graph/pydantic_graph/nodes.py
@@ -28,6 +28,8 @@ class GraphRunContext(Generic[StateT, DepsT]):
     """Context for a graph."""
 
+    # TODO: Can we get rid of this struct and just pass both these things around..?
+ state: StateT """The state of the graph.""" deps: DepsT diff --git a/pyproject.toml b/pyproject.toml index 73d82d39b..e2dba8002 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -189,4 +189,4 @@ skip = '.git*,*.svg,*.lock,*.css' check-hidden = true # Ignore "formatting" like **L**anguage ignore-regex = '\*\*[A-Z]\*\*[a-z]+\b' -# ignore-words-list = '' +ignore-words-list = 'asend' diff --git a/tests/graph/test_graph.py b/tests/graph/test_graph.py index ebd254a37..91b7a4400 100644 --- a/tests/graph/test_graph.py +++ b/tests/graph/test_graph.py @@ -57,11 +57,11 @@ async def run(self, ctx: GraphRunContext) -> Union[String2Length, End[int]]: # assert my_graph.name is None assert my_graph._get_state_type() is type(None) assert my_graph._get_run_end_type() is int - result, history = await my_graph.run(Float2String(3.14)) + result = await my_graph.run(Float2String(3.14)) # len('3.14') * 2 == 8 - assert result == 8 + assert result.output == 8 assert my_graph.name == 'my_graph' - assert history == snapshot( + assert result.history == snapshot( [ NodeStep( state=None, @@ -84,10 +84,10 @@ async def run(self, ctx: GraphRunContext) -> Union[String2Length, End[int]]: # EndStep(result=End(data=8), ts=IsNow(tz=timezone.utc)), ] ) - result, history = await my_graph.run(Float2String(3.14159)) + result = await my_graph.run(Float2String(3.14159)) # len('3.14159') == 7, 21 * 2 == 42 - assert result == 42 - assert history == snapshot( + assert result.output == 42 + assert result.history == snapshot( [ NodeStep( state=None, @@ -122,7 +122,7 @@ async def run(self, ctx: GraphRunContext) -> Union[String2Length, End[int]]: # EndStep(result=End(data=42), ts=IsNow(tz=timezone.utc)), ] ) - assert [e.data_snapshot() for e in history] == snapshot( + assert [e.data_snapshot() for e in result.history] == snapshot( [ Float2String(input_data=3.14159), String2Length(input_data='3.14159'), @@ -320,10 +320,10 @@ async def run(self, ctx: GraphRunContext[None, Deps]) -> End[int]: return End(123) g = Graph(nodes=(Foo, Bar)) - result, history = await g.run(Foo(), deps=Deps(1, 2)) + result = await g.run(Foo(), deps=Deps(1, 2)) - assert result == 123 - assert history == snapshot( + assert result.output == 123 + assert result.history == snapshot( [ NodeStep(state=None, node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), NodeStep(state=None, node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), diff --git a/tests/graph/test_history.py b/tests/graph/test_history.py index 2508a5347..da4bcd0d7 100644 --- a/tests/graph/test_history.py +++ b/tests/graph/test_history.py @@ -46,16 +46,17 @@ async def run(self, ctx: GraphRunContext[MyState]) -> End[int]: ], ) async def test_dump_load_history(graph: Graph[MyState, None, int]): - result, history = await graph.run(Foo(), state=MyState(1, '')) - assert result == snapshot(4) - assert history == snapshot( + result = await graph.run(Foo(), state=MyState(1, '')) + assert result.output == snapshot(4) + assert result.state == snapshot(MyState(x=2, y='y')) + assert result.history == snapshot( [ NodeStep(state=MyState(x=2, y=''), node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), NodeStep(state=MyState(x=2, y='y'), node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), EndStep(result=End(4), ts=IsNow(tz=timezone.utc)), ] ) - history_json = graph.dump_history(history) + history_json = graph.dump_history(result.history) assert json.loads(history_json) == snapshot( [ { @@ -76,7 +77,7 @@ async def test_dump_load_history(graph: Graph[MyState, None, int]): 
] ) history_loaded = graph.load_history(history_json) - assert history == history_loaded + assert result.history == history_loaded custom_history = [ { diff --git a/tests/graph/test_mermaid.py b/tests/graph/test_mermaid.py index 9f76d93cd..46fb88992 100644 --- a/tests/graph/test_mermaid.py +++ b/tests/graph/test_mermaid.py @@ -58,9 +58,9 @@ async def run(self, ctx: GraphRunContext) -> Annotated[End[None], Edge(label='eg async def test_run_graph(): - result, history = await graph1.run(Foo()) - assert result is None - assert history == snapshot( + result = await graph1.run(Foo()) + assert result.output is None + assert result.history == snapshot( [ NodeStep( state=None, diff --git a/tests/graph/test_state.py b/tests/graph/test_state.py index fbb570cf0..77435a1b8 100644 --- a/tests/graph/test_state.py +++ b/tests/graph/test_state.py @@ -36,9 +36,9 @@ async def run(self, ctx: GraphRunContext[MyState]) -> End[str]: assert graph._get_state_type() is MyState assert graph._get_run_end_type() is str state = MyState(1, '') - result, history = await graph.run(Foo(), state=state) - assert result == snapshot('x=2 y=y') - assert history == snapshot( + result = await graph.run(Foo(), state=state) + assert result.output == snapshot('x=2 y=y') + assert result.history == snapshot( [ NodeStep( state=MyState(x=2, y=''), diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index aa3260be0..1e4281e08 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -816,6 +816,10 @@ async def test_stream_text_heterogeneous(get_gemini_client: GetGeminiClient): m = GeminiModel('gemini-1.5-flash', http_client=gemini_client) agent = Agent(m) + @agent.tool_plain() + def get_location(loc_name: str) -> str: + return f'Location for {loc_name}' + async with agent.run_stream('Hello') as result: data = await result.get_data() diff --git a/tests/models/test_mistral.py b/tests/models/test_mistral.py index dbc68c4b9..68a8d3b94 100644 --- a/tests/models/test_mistral.py +++ b/tests/models/test_mistral.py @@ -1633,7 +1633,7 @@ async def get_location(loc_name: str) -> str: ModelResponse( parts=[TextPart(content='final response')], model_name='mistral-large-latest', - timestamp=IsNow(tz=timezone.utc), + timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), ), ] ) diff --git a/tests/test_agent.py b/tests/test_agent.py index 1a1091959..7d4b41c7f 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -27,7 +27,7 @@ from pydantic_ai.models import cached_async_http_client from pydantic_ai.models.function import AgentInfo, FunctionModel from pydantic_ai.models.test import TestModel -from pydantic_ai.result import RunResult, Usage +from pydantic_ai.result import Usage from pydantic_ai.tools import ToolDefinition from .conftest import IsNow, TestEnv @@ -534,36 +534,36 @@ async def ret_a(x: str) -> str: # if we pass new_messages, system prompt is inserted before the message_history messages result2 = agent.run_sync('Hello again', message_history=result1.new_messages()) - assert result2 == snapshot( - RunResult( - _all_messages=[ - ModelRequest( - parts=[ - SystemPromptPart(content='Foobar'), - UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)), - ] - ), - ModelResponse( - parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'})], - model_name='test', - timestamp=IsNow(tz=timezone.utc), - ), - ModelRequest( - parts=[ToolReturnPart(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc))] - ), - ModelResponse( - parts=[TextPart(content='{"ret_a":"a-apple"}')], 
model_name='test', timestamp=IsNow(tz=timezone.utc) - ), - ModelRequest(parts=[UserPromptPart(content='Hello again', timestamp=IsNow(tz=timezone.utc))]), - ModelResponse( - parts=[TextPart(content='{"ret_a":"a-apple"}')], model_name='test', timestamp=IsNow(tz=timezone.utc) - ), - ], - _new_message_index=4, - data='{"ret_a":"a-apple"}', - _result_tool_name=None, - _usage=Usage(requests=1, request_tokens=55, response_tokens=13, total_tokens=68, details=None), - ) + assert result2.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + SystemPromptPart(content='Foobar'), + UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)), + ] + ), + ModelResponse( + parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'})], + model_name='test', + timestamp=IsNow(tz=timezone.utc), + ), + ModelRequest( + parts=[ToolReturnPart(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc))] + ), + ModelResponse( + parts=[TextPart(content='{"ret_a":"a-apple"}')], model_name='test', timestamp=IsNow(tz=timezone.utc) + ), + ModelRequest(parts=[UserPromptPart(content='Hello again', timestamp=IsNow(tz=timezone.utc))]), + ModelResponse( + parts=[TextPart(content='{"ret_a":"a-apple"}')], model_name='test', timestamp=IsNow(tz=timezone.utc) + ), + ] + ) + assert result2._new_message_index == snapshot(4) # pyright: ignore[reportPrivateUsage] + assert result2.data == snapshot('{"ret_a":"a-apple"}') + assert result2._result_tool_name == snapshot(None) # pyright: ignore[reportPrivateUsage] + assert result2.usage() == snapshot( + Usage(requests=1, request_tokens=55, response_tokens=13, total_tokens=68, details=None) ) new_msg_part_kinds = [(m.kind, [p.part_kind for p in m.parts]) for m in result2.all_messages()] assert new_msg_part_kinds == snapshot( @@ -582,36 +582,36 @@ async def ret_a(x: str) -> str: # so only one system prompt result3 = agent.run_sync('Hello again', message_history=result1.all_messages()) # same as result2 except for datetimes - assert result3 == snapshot( - RunResult( - _all_messages=[ - ModelRequest( - parts=[ - SystemPromptPart(content='Foobar'), - UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)), - ] - ), - ModelResponse( - parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'})], - model_name='test', - timestamp=IsNow(tz=timezone.utc), - ), - ModelRequest( - parts=[ToolReturnPart(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc))] - ), - ModelResponse( - parts=[TextPart(content='{"ret_a":"a-apple"}')], model_name='test', timestamp=IsNow(tz=timezone.utc) - ), - ModelRequest(parts=[UserPromptPart(content='Hello again', timestamp=IsNow(tz=timezone.utc))]), - ModelResponse( - parts=[TextPart(content='{"ret_a":"a-apple"}')], model_name='test', timestamp=IsNow(tz=timezone.utc) - ), - ], - _new_message_index=4, - data='{"ret_a":"a-apple"}', - _result_tool_name=None, - _usage=Usage(requests=1, request_tokens=55, response_tokens=13, total_tokens=68, details=None), - ) + assert result3.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + SystemPromptPart(content='Foobar'), + UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)), + ] + ), + ModelResponse( + parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'})], + model_name='test', + timestamp=IsNow(tz=timezone.utc), + ), + ModelRequest( + parts=[ToolReturnPart(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc))] + ), + ModelResponse( + parts=[TextPart(content='{"ret_a":"a-apple"}')], model_name='test', timestamp=IsNow(tz=timezone.utc) + ), + 
ModelRequest(parts=[UserPromptPart(content='Hello again', timestamp=IsNow(tz=timezone.utc))]), + ModelResponse( + parts=[TextPart(content='{"ret_a":"a-apple"}')], model_name='test', timestamp=IsNow(tz=timezone.utc) + ), + ] + ) + assert result3._new_message_index == snapshot(4) # pyright: ignore[reportPrivateUsage] + assert result3.data == snapshot('{"ret_a":"a-apple"}') + assert result3._result_tool_name == snapshot(None) # pyright: ignore[reportPrivateUsage] + assert result3.usage() == snapshot( + Usage(requests=1, request_tokens=55, response_tokens=13, total_tokens=68, details=None) ) @@ -666,63 +666,63 @@ async def ret_a(x: str) -> str: ) result2 = agent.run_sync('Hello again', message_history=result1.new_messages()) - assert result2 == snapshot( - RunResult( - data=Response(a=0), - _all_messages=[ - ModelRequest( - parts=[ - SystemPromptPart(content='Foobar'), - UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)), - ], - ), - ModelResponse( - parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'})], - model_name='test', - timestamp=IsNow(tz=timezone.utc), - ), - ModelRequest( - parts=[ToolReturnPart(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc))], - ), - ModelResponse( - parts=[ToolCallPart(tool_name='final_result', args={'a': 0})], - model_name='test', - timestamp=IsNow(tz=timezone.utc), - ), - ModelRequest( - parts=[ - ToolReturnPart( - tool_name='final_result', - content='Final result processed.', - timestamp=IsNow(tz=timezone.utc), - ), - ], - ), - # second call, notice no repeated system prompt - ModelRequest( - parts=[ - UserPromptPart(content='Hello again', timestamp=IsNow(tz=timezone.utc)), - ], - ), - ModelResponse( - parts=[ToolCallPart(tool_name='final_result', args={'a': 0})], - model_name='test', - timestamp=IsNow(tz=timezone.utc), - ), - ModelRequest( - parts=[ - ToolReturnPart( - tool_name='final_result', - content='Final result processed.', - timestamp=IsNow(tz=timezone.utc), - ), - ] - ), - ], - _new_message_index=5, - _result_tool_name='final_result', - _usage=Usage(requests=1, request_tokens=59, response_tokens=13, total_tokens=72, details=None), - ) + assert result2.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + SystemPromptPart(content='Foobar'), + UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)), + ], + ), + ModelResponse( + parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'})], + model_name='test', + timestamp=IsNow(tz=timezone.utc), + ), + ModelRequest( + parts=[ToolReturnPart(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc))], + ), + ModelResponse( + parts=[ToolCallPart(tool_name='final_result', args={'a': 0})], + model_name='test', + timestamp=IsNow(tz=timezone.utc), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='final_result', + content='Final result processed.', + timestamp=IsNow(tz=timezone.utc), + ), + ], + ), + # second call, notice no repeated system prompt + ModelRequest( + parts=[ + UserPromptPart(content='Hello again', timestamp=IsNow(tz=timezone.utc)), + ], + ), + ModelResponse( + parts=[ToolCallPart(tool_name='final_result', args={'a': 0})], + model_name='test', + timestamp=IsNow(tz=timezone.utc), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='final_result', + content='Final result processed.', + timestamp=IsNow(tz=timezone.utc), + ), + ] + ), + ] + ) + assert result2.data == snapshot(Response(a=0)) + assert result2._new_message_index == snapshot(5) # pyright: ignore[reportPrivateUsage] + assert result2._result_tool_name == 
snapshot('final_result') # pyright: ignore[reportPrivateUsage] + assert result2.usage() == snapshot( + Usage(requests=1, request_tokens=59, response_tokens=13, total_tokens=72, details=None) ) new_msg_part_kinds = [(m.kind, [p.part_kind for p in m.parts]) for m in result2.all_messages()] assert new_msg_part_kinds == snapshot( diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 1725c4d36..f95be4c13 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -335,14 +335,18 @@ async def test_call_tool_wrong_name(): async def stream_structured_function(_messages: list[ModelMessage], _: AgentInfo) -> AsyncIterator[DeltaToolCalls]: yield {0: DeltaToolCall(name='foobar', json_args='{}')} - agent = Agent(FunctionModel(stream_function=stream_structured_function), result_type=tuple[str, int]) + agent = Agent( + FunctionModel(stream_function=stream_structured_function), + result_type=tuple[str, int], + retries=0, + ) @agent.tool_plain async def ret_a(x: str) -> str: # pragma: no cover return x with capture_run_messages() as messages: - with pytest.raises(UnexpectedModelBehavior, match=r'Exceeded maximum retries \(1\) for result validation'): + with pytest.raises(UnexpectedModelBehavior, match=r'Exceeded maximum retries \(0\) for result validation'): async with agent.run_stream('hello'): pass # pragma: no cover @@ -354,14 +358,6 @@ async def ret_a(x: str) -> str: # pragma: no cover model_name='function:stream_structured_function', timestamp=IsNow(tz=timezone.utc), ), - ModelRequest( - parts=[ - RetryPromptPart( - content="Unknown tool name: 'foobar'. Available tools: ret_a, final_result", - timestamp=IsNow(tz=timezone.utc), - ) - ] - ), ] ) diff --git a/tests/test_usage_limits.py b/tests/test_usage_limits.py index e1d0234e0..ba00a3f01 100644 --- a/tests/test_usage_limits.py +++ b/tests/test_usage_limits.py @@ -76,34 +76,38 @@ async def test_streamed_text_limits() -> None: async def ret_a(x: str) -> str: return f'{x}-apple' - async with test_agent.run_stream('Hello', usage_limits=UsageLimits(response_tokens_limit=10)) as result: - assert test_agent.name == 'test_agent' - assert not result.is_complete - assert result.all_messages() == snapshot( - [ - ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), - ModelResponse( - parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'})], - model_name='test', - timestamp=IsNow(tz=timezone.utc), - ), - ModelRequest( - parts=[ToolReturnPart(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc))] - ), - ] - ) - assert result.usage() == snapshot( - Usage( - requests=2, - request_tokens=103, - response_tokens=5, - total_tokens=108, + succeeded = False + + with pytest.raises( + UsageLimitExceeded, match=re.escape('Exceeded the response_tokens_limit of 10 (response_tokens=11)') + ): + async with test_agent.run_stream('Hello', usage_limits=UsageLimits(response_tokens_limit=10)) as result: + assert test_agent.name == 'test_agent' + assert not result.is_complete + assert result.all_messages() == snapshot( + [ + ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), + ModelResponse( + parts=[ToolCallPart(tool_name='ret_a', args={'x': 'a'})], + model_name='test', + timestamp=IsNow(tz=timezone.utc), + ), + ModelRequest( + parts=[ToolReturnPart(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc))] + ), + ] ) - ) - with pytest.raises( - UsageLimitExceeded, match=re.escape('Exceeded the response_tokens_limit of 10 (response_tokens=11)') - ): - 
await result.get_data() + assert result.usage() == snapshot( + Usage( + requests=2, + request_tokens=103, + response_tokens=5, + total_tokens=108, + ) + ) + succeeded = True + + assert succeeded def test_usage_so_far() -> None: diff --git a/tests/typed_agent.py b/tests/typed_agent.py index fdf9f1a25..280d0795a 100644 --- a/tests/typed_agent.py +++ b/tests/typed_agent.py @@ -8,7 +8,7 @@ from typing_extensions import assert_type from pydantic_ai import Agent, ModelRetry, RunContext, Tool -from pydantic_ai.result import RunResult +from pydantic_ai.agent import AgentRunResult from pydantic_ai.tools import ToolDefinition @@ -139,7 +139,7 @@ async def result_validator_wrong(ctx: RunContext[int], result: str) -> str: def run_sync() -> None: result = typed_agent.run_sync('testing', deps=MyDeps(foo=1, bar=2)) - assert_type(result, RunResult[str]) + assert_type(result, AgentRunResult[str]) assert_type(result.data, str) @@ -176,7 +176,7 @@ class Bar: def run_sync3() -> None: result = union_agent.run_sync('testing') - assert_type(result, RunResult[Union[Foo, Bar]]) + assert_type(result, AgentRunResult[Union[Foo, Bar]]) assert_type(result.data, Union[Foo, Bar]) diff --git a/tests/typed_graph.py b/tests/typed_graph.py index d0b6a02b7..4540ac608 100644 --- a/tests/typed_graph.py +++ b/tests/typed_graph.py @@ -109,6 +109,6 @@ def run_g5() -> None: g5.run_sync(A()) # pyright: ignore[reportArgumentType] g5.run_sync(A(), state=MyState(x=1)) # pyright: ignore[reportArgumentType] g5.run_sync(A(), deps=MyDeps(y='y')) # pyright: ignore[reportArgumentType] - answer, history = g5.run_sync(A(), state=MyState(x=1), deps=MyDeps(y='y')) - assert_type(answer, int) - assert_type(history, list[HistoryStep[MyState, int]]) + result = g5.run_sync(A(), state=MyState(x=1), deps=MyDeps(y='y')) + assert_type(result.output, int) + assert_type(result.history, list[HistoryStep[MyState, int]])
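
For illustration only (not part of the patch), here is a minimal sketch of how the API introduced above is used end to end. `CountState`, `CountDown`, and `count_graph` are hypothetical names invented for this example; the `Graph.run`, `Graph.iter`, `GraphRun`, and `GraphRunResult` calls assume the behaviour added in this diff.

```python
from __future__ import annotations

from dataclasses import dataclass

from pydantic_graph import BaseNode, End, Graph, GraphRunContext


@dataclass
class CountState:
    remaining: int


@dataclass
class CountDown(BaseNode[CountState, None, int]):
    async def run(self, ctx: GraphRunContext[CountState]) -> CountDown | End[int]:
        # Decrement the shared state until it reaches zero, then end the run.
        if ctx.state.remaining <= 0:
            return End(ctx.state.remaining)
        ctx.state.remaining -= 1
        return CountDown()


count_graph = Graph(nodes=(CountDown,))


async def main() -> None:
    # High level: per this patch, `run` returns a GraphRunResult rather than a
    # `(result, history)` tuple, with `output`, `state`, and `history` fields.
    result = await count_graph.run(CountDown(), state=CountState(remaining=2))
    print(result.output)        # final value carried by End(...)
    print(len(result.history))  # one NodeStep per executed node plus the EndStep
    print(result.state)         # final state is also available on the result

    # Low level: `iter` is a context manager yielding a GraphRun that can be
    # async-iterated node by node, or driven manually via `graph_run.next(...)`.
    with count_graph.iter(CountDown(), state=CountState(remaining=2)) as graph_run:
        async for node in graph_run:
            print(node)  # CountDown() ... then End(data=0)
        assert graph_run.result is not None
        print(graph_run.result.output)


if __name__ == '__main__':
    import asyncio

    asyncio.run(main())
```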