diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py
index 2cfa270e0..7674597aa 100644
--- a/src/agents/_run_impl.py
+++ b/src/agents/_run_impl.py
@@ -74,6 +74,7 @@
     MCPToolApprovalRequest,
     Tool,
 )
+from .tool_context import ToolContext
 from .tracing import (
     SpanError,
     Trace,
@@ -539,23 +540,24 @@ async def run_single_tool(
             func_tool: FunctionTool, tool_call: ResponseFunctionToolCall
         ) -> Any:
             with function_span(func_tool.name) as span_fn:
+                tool_context = ToolContext.from_agent_context(context_wrapper, tool_call.call_id)
                 if config.trace_include_sensitive_data:
                     span_fn.span_data.input = tool_call.arguments
                 try:
                     _, _, result = await asyncio.gather(
-                        hooks.on_tool_start(context_wrapper, agent, func_tool),
+                        hooks.on_tool_start(tool_context, agent, func_tool),
                         (
-                            agent.hooks.on_tool_start(context_wrapper, agent, func_tool)
+                            agent.hooks.on_tool_start(tool_context, agent, func_tool)
                             if agent.hooks
                             else _coro.noop_coroutine()
                         ),
-                        func_tool.on_invoke_tool(context_wrapper, tool_call.arguments),
+                        func_tool.on_invoke_tool(tool_context, tool_call.arguments),
                     )
 
                     await asyncio.gather(
-                        hooks.on_tool_end(context_wrapper, agent, func_tool, result),
+                        hooks.on_tool_end(tool_context, agent, func_tool, result),
                         (
-                            agent.hooks.on_tool_end(context_wrapper, agent, func_tool, result)
+                            agent.hooks.on_tool_end(tool_context, agent, func_tool, result)
                             if agent.hooks
                             else _coro.noop_coroutine()
                         ),
diff --git a/src/agents/function_schema.py b/src/agents/function_schema.py
index 0e5868965..188566970 100644
--- a/src/agents/function_schema.py
+++ b/src/agents/function_schema.py
@@ -13,6 +13,7 @@
 from .exceptions import UserError
 from .run_context import RunContextWrapper
 from .strict_schema import ensure_strict_json_schema
+from .tool_context import ToolContext
 
 
 @dataclass
@@ -237,21 +238,21 @@ def function_schema(
         ann = type_hints.get(first_name, first_param.annotation)
         if ann != inspect._empty:
             origin = get_origin(ann) or ann
-            if origin is RunContextWrapper:
+            if origin is RunContextWrapper or origin is ToolContext:
                 takes_context = True  # Mark that the function takes context
             else:
                 filtered_params.append((first_name, first_param))
         else:
             filtered_params.append((first_name, first_param))
 
-    # For parameters other than the first, raise error if any use RunContextWrapper.
+    # For parameters other than the first, raise error if any use RunContextWrapper or ToolContext.
     for name, param in params[1:]:
         ann = type_hints.get(name, param.annotation)
         if ann != inspect._empty:
             origin = get_origin(ann) or ann
-            if origin is RunContextWrapper:
+            if origin is RunContextWrapper or origin is ToolContext:
                 raise UserError(
-                    f"RunContextWrapper param found at non-first position in function"
+                    f"RunContextWrapper/ToolContext param found at non-first position in function"
                     f" {func.__name__}"
                 )
         filtered_params.append((name, param))
diff --git a/src/agents/tool.py b/src/agents/tool.py
index fd5a21c89..3cd948aa0 100644
--- a/src/agents/tool.py
+++ b/src/agents/tool.py
@@ -20,6 +20,7 @@
 from .items import RunItem
 from .logger import logger
 from .run_context import RunContextWrapper
+from .tool_context import ToolContext
 from .tracing import SpanError
 from .util import _error_tracing
 from .util._types import MaybeAwaitable
@@ -28,8 +29,13 @@
 
 ToolFunctionWithoutContext = Callable[ToolParams, Any]
 ToolFunctionWithContext = Callable[Concatenate[RunContextWrapper[Any], ToolParams], Any]
+ToolFunctionWithToolContext = Callable[Concatenate[ToolContext, ToolParams], Any]
 
-ToolFunction = Union[ToolFunctionWithoutContext[ToolParams], ToolFunctionWithContext[ToolParams]]
+ToolFunction = Union[
+    ToolFunctionWithoutContext[ToolParams],
+    ToolFunctionWithContext[ToolParams],
+    ToolFunctionWithToolContext[ToolParams],
+]
 
 
 @dataclass
@@ -59,7 +65,7 @@ class FunctionTool:
     params_json_schema: dict[str, Any]
     """The JSON schema for the tool's parameters."""
 
-    on_invoke_tool: Callable[[RunContextWrapper[Any], str], Awaitable[Any]]
+    on_invoke_tool: Callable[[ToolContext[Any], str], Awaitable[Any]]
     """A function that invokes the tool with the given context and parameters. The params passed
     are:
     1. The tool run context.
@@ -330,7 +336,7 @@ def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool:
             strict_json_schema=strict_mode,
         )
 
-        async def _on_invoke_tool_impl(ctx: RunContextWrapper[Any], input: str) -> Any:
+        async def _on_invoke_tool_impl(ctx: ToolContext[Any], input: str) -> Any:
             try:
                 json_data: dict[str, Any] = json.loads(input) if input else {}
             except Exception as e:
@@ -379,7 +385,7 @@ async def _on_invoke_tool_impl(ctx: RunContextWrapper[Any], input: str) -> Any:
 
             return result
 
-        async def _on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> Any:
+        async def _on_invoke_tool(ctx: ToolContext[Any], input: str) -> Any:
             try:
                 return await _on_invoke_tool_impl(ctx, input)
             except Exception as e:
diff --git a/src/agents/tool_context.py b/src/agents/tool_context.py
new file mode 100644
index 000000000..31b5efd6d
--- /dev/null
+++ b/src/agents/tool_context.py
@@ -0,0 +1,37 @@
+from dataclasses import dataclass, field, fields
+from typing import Any
+
+from .run_context import RunContextWrapper, TContext
+
+
+def _assert_must_pass_tool_call_id() -> str:
+    """Default factory for `tool_call_id`; fails loudly when the ID is omitted."""
+    raise ValueError("tool_call_id must be passed to ToolContext")
+
+
+@dataclass
+class ToolContext(RunContextWrapper[TContext]):
+    """The context of a tool call."""
+
+    # `dataclasses.KW_ONLY` only exists on Python 3.10+. A raising default factory keeps the
+    # field required in practice while still letting it follow any defaulted base-class fields
+    # (assumed to be why KW_ONLY was needed; confirm against RunContextWrapper).
+    tool_call_id: str = field(default_factory=_assert_must_pass_tool_call_id)
+    """The ID of the tool call."""
+
+    @classmethod
+    def from_agent_context(
+        cls, context: RunContextWrapper[TContext], tool_call_id: str
+    ) -> "ToolContext[TContext]":
+        """
+        Create a ToolContext from a RunContextWrapper.
+
+        Args:
+            context: The run context whose init fields are copied onto the new ToolContext.
+            tool_call_id: The ID of the tool call this context is created for.
+        """
+        # Copy the values of the RunContextWrapper's init=True fields
+        base_values: dict[str, Any] = {
+            f.name: getattr(context, f.name) for f in fields(RunContextWrapper) if f.init
+        }
+        return cls(tool_call_id=tool_call_id, **base_values)
diff --git a/tests/test_function_tool.py b/tests/test_function_tool.py
index 0a57aea87..bc984bf13 100644
--- a/tests/test_function_tool.py
+++ b/tests/test_function_tool.py
@@ -7,6 +7,7 @@
 
 from agents import FunctionTool, ModelBehaviorError, RunContextWrapper, function_tool
 from agents.tool import default_tool_error_function
+from agents.tool_context import ToolContext
 
 
 def argless_function() -> str:
@@ -18,11 +19,11 @@ async def test_argless_function():
     tool = function_tool(argless_function)
     assert tool.name == "argless_function"
 
-    result = await tool.on_invoke_tool(RunContextWrapper(None), "")
+    result = await tool.on_invoke_tool(ToolContext(context=None, tool_call_id="1"), "")
     assert result == "ok"
 
 
-def argless_with_context(ctx: RunContextWrapper[str]) -> str:
+def argless_with_context(ctx: ToolContext[str]) -> str:
     return "ok"
 
 
@@ -31,11 +32,11 @@ async def test_argless_with_context():
     tool = function_tool(argless_with_context)
     assert tool.name == "argless_with_context"
 
-    result = await tool.on_invoke_tool(RunContextWrapper(None), "")
+    result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), "")
     assert result == "ok"
 
     # Extra JSON should not raise an error
-    result = await tool.on_invoke_tool(RunContextWrapper(None), '{"a": 1}')
+    result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"a": 1}')
     assert result == "ok"
 
 
@@ -48,15 +49,15 @@ async def test_simple_function():
     tool = function_tool(simple_function, failure_error_function=None)
     assert tool.name == "simple_function"
 
-    result = await tool.on_invoke_tool(RunContextWrapper(None), '{"a": 1}')
+    result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"a": 1}')
     assert result == 6
 
-    result = await tool.on_invoke_tool(RunContextWrapper(None), '{"a": 1, "b": 2}')
+    result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"a": 1, "b": 2}')
     assert result == 3
 
     # Missing required argument should raise an error
     with pytest.raises(ModelBehaviorError):
-        await tool.on_invoke_tool(RunContextWrapper(None), "")
+        await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), "")
 
 
 class Foo(BaseModel):
@@ -84,7 +85,7 @@ async def test_complex_args_function():
             "bar": Bar(x="hello", y=10),
         }
     )
-    result = await tool.on_invoke_tool(RunContextWrapper(None), valid_json)
+    result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), valid_json)
     assert result == "6 hello10 hello"
 
     valid_json = json.dumps(
@@ -93,7 +94,7 @@ async def test_complex_args_function():
             "bar": Bar(x="hello", y=10),
         }
     )
-    result = await tool.on_invoke_tool(RunContextWrapper(None), valid_json)
+    result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), valid_json)
     assert result == "3 hello10 hello"
 
     valid_json = json.dumps(
@@ -103,12 +104,12 @@ async def test_complex_args_function():
             "baz": "world",
         }
     )
-    result = await tool.on_invoke_tool(RunContextWrapper(None), valid_json)
+    result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), valid_json)
     assert result == "3 hello10 world"
 
     # Missing required argument should raise an error
     with pytest.raises(ModelBehaviorError):
-        await tool.on_invoke_tool(RunContextWrapper(None), '{"foo": {"a": 1}}')
+        await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"foo": {"a": 1}}')
 
 
 def test_function_config_overrides():
@@ -168,7 +169,7 @@ async def run_function(ctx: RunContextWrapper[Any], args: str) -> str:
         assert tool.params_json_schema[key] == value
     assert tool.strict_json_schema
 
-    result = await tool.on_invoke_tool(RunContextWrapper(None), '{"data": "hello"}')
+    result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"data": "hello"}')
     assert result == "hello_done"
 
     tool_not_strict = FunctionTool(
@@ -183,7 +184,7 @@ async def run_function(ctx: RunContextWrapper[Any], args: str) -> str:
     assert "additionalProperties" not in tool_not_strict.params_json_schema
 
     result = await tool_not_strict.on_invoke_tool(
-        RunContextWrapper(None), '{"data": "hello", "bar": "baz"}'
+        ToolContext(None, tool_call_id="1"), '{"data": "hello", "bar": "baz"}'
     )
     assert result == "hello_done"
 
@@ -194,7 +195,7 @@ def my_func(a: int, b: int = 5):
         raise ValueError("test")
 
     tool = function_tool(my_func)
-    ctx = RunContextWrapper(None)
+    ctx = ToolContext(None, tool_call_id="1")
 
     result = await tool.on_invoke_tool(ctx, "")
     assert "Invalid JSON" in str(result)
@@ -218,7 +219,7 @@ def custom_sync_error_function(ctx: RunContextWrapper[Any], error: Exception) ->
         return f"error_{error.__class__.__name__}"
 
     tool = function_tool(my_func, failure_error_function=custom_sync_error_function)
-    ctx = RunContextWrapper(None)
+    ctx = ToolContext(None, tool_call_id="1")
 
     result = await tool.on_invoke_tool(ctx, "")
     assert result == "error_ModelBehaviorError"
@@ -242,7 +243,7 @@ def custom_sync_error_function(ctx: RunContextWrapper[Any], error: Exception) ->
         return f"error_{error.__class__.__name__}"
 
     tool = function_tool(my_func, failure_error_function=custom_sync_error_function)
-    ctx = RunContextWrapper(None)
+    ctx = ToolContext(None, tool_call_id="1")
 
     result = await tool.on_invoke_tool(ctx, "")
     assert result == "error_ModelBehaviorError"
diff --git a/tests/test_function_tool_decorator.py b/tests/test_function_tool_decorator.py
index 3b52788fb..d334d8f84 100644
--- a/tests/test_function_tool_decorator.py
+++ b/tests/test_function_tool_decorator.py
@@ -7,6 +7,7 @@
 
 from agents import function_tool
 from agents.run_context import RunContextWrapper
+from agents.tool_context import ToolContext
 
 
 class DummyContext:
@@ -14,8 +15,8 @@ def __init__(self):
         self.data = "something"
 
 
-def ctx_wrapper() -> RunContextWrapper[DummyContext]:
-    return RunContextWrapper(DummyContext())
+def ctx_wrapper() -> ToolContext[DummyContext]:
+    return ToolContext(context=DummyContext(), tool_call_id="1")
 
 
 @function_tool
@@ -44,7 +45,7 @@ async def test_sync_no_context_with_args_invocation():
 
 
 @function_tool
-def sync_with_context(ctx: RunContextWrapper[DummyContext], name: str) -> str:
+def sync_with_context(ctx: ToolContext[DummyContext], name: str) -> str:
     return f"{name}_{ctx.context.data}"
 
 
@@ -71,7 +72,7 @@ async def test_async_no_context_invocation():
 
 
 @function_tool
-async def async_with_context(ctx: RunContextWrapper[DummyContext], prefix: str, num: int) -> str:
+async def async_with_context(ctx: ToolContext[DummyContext], prefix: str, num: int) -> str:
     await asyncio.sleep(0)
     return f"{prefix}-{num}-{ctx.context.data}"
 
diff --git a/tests/test_responses.py b/tests/test_responses.py
index 6b91bf8c6..74212f61b 100644
--- a/tests/test_responses.py
+++ b/tests/test_responses.py
@@ -49,10 +49,12 @@ def _foo() -> str:
     )
 
 
-def get_function_tool_call(name: str, arguments: str | None = None) -> ResponseOutputItem:
+def get_function_tool_call(
+    name: str, arguments: str | None = None, call_id: str | None = None
+) -> ResponseOutputItem:
     return ResponseFunctionToolCall(
         id="1",
-        call_id="2",
+        call_id=call_id or "2",
         type="function_call",
         name=name,
         arguments=arguments or "",
diff --git a/tests/test_run_step_execution.py b/tests/test_run_step_execution.py
index 6ae25fbd5..5ecff4a92 100644
--- a/tests/test_run_step_execution.py
+++ b/tests/test_run_step_execution.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import json
 from typing import Any
 
 import pytest
@@ -26,6 +27,8 @@
     RunImpl,
     SingleStepResult,
 )
+from agents.tool import function_tool
+from agents.tool_context import ToolContext
 
 from .test_responses import (
     get_final_output_message,
@@ -158,6 +161,40 @@ async def test_multiple_tool_calls():
     assert isinstance(result.next_step, NextStepRunAgain)
 
 
+@pytest.mark.asyncio
+async def test_multiple_tool_calls_with_tool_context():
+    async def _fake_tool(context: ToolContext[str], value: str) -> str:
+        return f"{value}-{context.tool_call_id}"
+
+    tool = function_tool(_fake_tool, name_override="fake_tool", failure_error_function=None)
+
+    agent = Agent(
+        name="test",
+        tools=[tool],
+    )
+    response = ModelResponse(
+        output=[
+            get_function_tool_call("fake_tool", json.dumps({"value": "123"}), call_id="1"),
+            get_function_tool_call("fake_tool", json.dumps({"value": "456"}), call_id="2"),
+        ],
+        usage=Usage(),
+        response_id=None,
+    )
+
+    result = await get_execute_result(agent, response)
+    assert result.original_input == "hello"
+
+    # 4 items: 2 tool calls, 2 tool call outputs
+    assert len(result.generated_items) == 4
+    assert isinstance(result.next_step, NextStepRunAgain)
+
+    items = result.generated_items
+    assert_item_is_function_tool_call(items[0], "fake_tool", json.dumps({"value": "123"}))
+    assert_item_is_function_tool_call(items[1], "fake_tool", json.dumps({"value": "456"}))
+    assert_item_is_function_tool_call_output(items[2], "123-1")
+    assert_item_is_function_tool_call_output(items[3], "456-2")
+
+
 @pytest.mark.asyncio
 async def test_handoff_output_leads_to_handoff_next_step():
     agent_1 = Agent(name="test_1")