Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug related to handling multiple result tools #926

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 34 additions & 18 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class MarkFinalResult(Generic[ResultDataT]):
"""The final result data."""
tool_name: str | None
"""Name of the final result tool, None if the result is a string."""
tool_call_id: str | None


@dataclasses.dataclass
Expand Down Expand Up @@ -312,8 +313,7 @@ async def _handle_tool_calls_response(
final_result: MarkFinalResult[NodeRunEndT] | None = None
parts: list[_messages.ModelRequestPart] = []
if result_schema is not None:
if match := result_schema.find_tool(tool_calls):
call, result_tool = match
for call, result_tool in result_schema.find_tool(tool_calls):
try:
result_data = result_tool.validate(call)
result_data = await _validate_result(result_data, ctx, call)
Expand All @@ -323,10 +323,13 @@ async def _handle_tool_calls_response(
ctx.state.increment_retries(ctx.deps.max_result_retries)
parts.append(e.tool_retry)
else:
final_result = MarkFinalResult(result_data, call.tool_name)
final_result = MarkFinalResult(result_data, call.tool_name, call.tool_call_id)
break

# Then build the other request parts based on end strategy
tool_responses = await _process_function_tools(tool_calls, final_result and final_result.tool_name, ctx)
tool_responses = await _process_function_tools(
tool_calls, final_result and final_result.tool_name, final_result and final_result.tool_call_id, ctx
)

if final_result:
handle_span.set_attribute('result', final_result.data)
Expand Down Expand Up @@ -359,7 +362,7 @@ async def _handle_text_response(
else:
handle_span.set_attribute('result', result_data)
handle_span.message = 'handle model response -> final result'
return FinalResultNode[DepsT, NodeRunEndT](MarkFinalResult(result_data, None))
return FinalResultNode[DepsT, NodeRunEndT](MarkFinalResult(result_data, None, None))
else:
ctx.state.increment_retries(ctx.deps.max_result_retries)
return ModelRequestNode[DepsT, NodeRunEndT](
Expand Down Expand Up @@ -392,7 +395,7 @@ async def run(
return final_node

@asynccontextmanager
async def run_to_result(
async def run_to_result( # noqa C901
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
) -> AsyncIterator[StreamModelRequestNode[DepsT, NodeRunEndT] | End[result.StreamedRunResult[DepsT, NodeRunEndT]]]:
result_schema = ctx.deps.result_schema
Expand Down Expand Up @@ -431,20 +434,21 @@ async def run_to_result(
received_text = True
if _allow_text_result(result_schema):
handle_span.message = 'handle model response -> final result'
streamed_run_result = _build_streamed_run_result(streamed_response, None, ctx)
streamed_run_result = _build_streamed_run_result(streamed_response, None, None, ctx)
self._result = End(streamed_run_result)
yield self._result
return
elif isinstance(new_part, _messages.ToolCallPart):
if result_schema is not None and (match := result_schema.find_tool([new_part])):
call, _ = match
handle_span.message = 'handle model response -> final result'
streamed_run_result = _build_streamed_run_result(
streamed_response, call.tool_name, ctx
)
self._result = End(streamed_run_result)
yield self._result
return
if result_schema is not None:
for call, _ in result_schema.find_tool([new_part]):
# Note: this ignores anything after the first tool call
handle_span.message = 'handle model response -> final result'
streamed_run_result = _build_streamed_run_result(
streamed_response, call.tool_name, call.tool_call_id, ctx
)
self._result = End(streamed_run_result)
yield self._result
return
else:
assert_never(new_part)

Expand Down Expand Up @@ -546,6 +550,7 @@ def _build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[Deps
def _build_streamed_run_result(
result_stream: models.StreamedResponse,
result_tool_name: str | None,
result_tool_call_id: str | None,
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
) -> result.StreamedRunResult[DepsT, NodeRunEndT]:
new_message_index = ctx.deps.new_message_index
Expand All @@ -567,6 +572,7 @@ async def on_complete():
parts = await _process_function_tools(
tool_calls,
result_tool_name,
result_tool_call_id,
ctx,
)
# TODO: Should we do something here related to the retry count?
Expand All @@ -593,6 +599,7 @@ async def on_complete():
async def _process_function_tools(
tool_calls: list[_messages.ToolCallPart],
result_tool_name: str | None,
result_tool_call_id: str | None,
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
) -> list[_messages.ModelRequestPart]:
"""Process function (non-result) tool calls in parallel.
Expand All @@ -610,7 +617,11 @@ async def _process_function_tools(
run_context = _build_run_context(ctx)

for call in tool_calls:
if call.tool_name == result_tool_name and not found_used_result_tool:
if (
call.tool_name == result_tool_name
and call.tool_call_id == result_tool_call_id
and not found_used_result_tool
):
found_used_result_tool = True
parts.append(
_messages.ToolReturnPart(
Expand All @@ -634,10 +645,15 @@ async def _process_function_tools(
# if tool_name is in _result_schema, it means we found a result tool but an error occurred in
# validation, we don't add another part here
if result_tool_name is not None:
if found_used_result_tool:
content = 'Result tool not used - a final result was already processed.'
else:
# TODO: Include information about the validation failure, and/or merge this with the ModelRetry part
content = 'Result tool not used - result failed validation.'
parts.append(
_messages.ToolReturnPart(
tool_name=call.tool_name,
content='Result tool not used - a final result was already processed.',
content=content,
tool_call_id=call.tool_call_id,
)
)
Expand Down
6 changes: 3 additions & 3 deletions pydantic_ai_slim/pydantic_ai/_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import inspect
import sys
import types
from collections.abc import Awaitable, Iterable
from collections.abc import Awaitable, Iterable, Iterator
from dataclasses import dataclass, field
from typing import Any, Callable, Generic, Literal, Union, cast, get_args, get_origin

Expand Down Expand Up @@ -127,12 +127,12 @@ def find_named_tool(
def find_tool(
self,
parts: Iterable[_messages.ModelResponsePart],
) -> tuple[_messages.ToolCallPart, ResultTool[ResultDataT]] | None:
) -> Iterator[tuple[_messages.ToolCallPart, ResultTool[ResultDataT]]]:
"""Find a tool that matches one of the calls."""
for part in parts:
if isinstance(part, _messages.ToolCallPart):
if result := self.tools.get(part.tool_name):
return part, result
yield part, result

def tool_names(self) -> list[str]:
"""Return the names of the tools."""
Expand Down
36 changes: 36 additions & 0 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1185,6 +1185,42 @@ def regular_tool(x: int) -> int:
tool_returns = [m for m in result.all_messages() if isinstance(m, ToolReturnPart)]
assert tool_returns == snapshot([])

def test_multiple_final_result_are_validated_correctly(self):
"""Tests that if multiple final results are returned, but one fails validation, the other is used."""

# Stub model: returns two calls to the 'final_result' tool in one response.
# The first payload is deliberately malformed ('bad_value' instead of 'value')
# so schema validation rejects it; the second is well-formed and should win.
def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
assert info.result_tools is not None
return ModelResponse(
parts=[
ToolCallPart('final_result', {'bad_value': 'first'}, tool_call_id='first'),
ToolCallPart('final_result', {'value': 'second'}, tool_call_id='second'),
]
)

# NOTE(review): end_strategy='early' presumably stops further tool processing
# once a final result is accepted — confirm against Agent docs.
agent = Agent(FunctionModel(return_model), result_type=self.ResultType, end_strategy='early')
result = agent.run_sync('test multiple final results')

# Verify the result came from the second final tool
assert result.data.value == 'second'

# Verify we got appropriate tool returns: the failed first call is answered
# with an explanatory ToolReturnPart (not silently dropped), and the second
# call — matched by tool_call_id, not just tool_name — is the processed one.
assert result.new_messages()[-1].parts == snapshot(
[
ToolReturnPart(
tool_name='final_result',
tool_call_id='first',
content='Result tool not used - result failed validation.',
timestamp=IsNow(tz=timezone.utc),
),
ToolReturnPart(
tool_name='final_result',
content='Final result processed.',
timestamp=IsNow(tz=timezone.utc),
tool_call_id='second',
),
]
)


async def test_model_settings_override() -> None:
def return_settings(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
Expand Down