Skip to content

Commit 77921bc

Browse files
committed
max auto function calls
1 parent dcb5e57 commit 77921bc

File tree

8 files changed

+460
-44
lines changed

8 files changed

+460
-44
lines changed

sdk/ai/azure-ai-projects/azure/ai/projects/aio/operations/_patch.py

+62-6
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import logging
1414
import os
1515
import time
16+
import json
1617
from pathlib import Path
1718
from typing import (
1819
IO,
@@ -664,6 +665,7 @@ class AgentsOperations(AgentsOperationsGenerated):
664665
def __init__(self, *args, **kwargs) -> None:
665666
super().__init__(*args, **kwargs)
666667
self._function_tool = _models.AsyncFunctionTool(set())
668+
self._function_tool_max_retry = 10
667669

668670
# pylint: disable=arguments-differ
669671
@overload
@@ -1622,6 +1624,7 @@ async def create_and_process_run(
16221624
)
16231625

16241626
# Monitor and process the run status
1627+
current_retry = 0
16251628
while run.status in [
16261629
RunStatus.QUEUED,
16271630
RunStatus.IN_PROGRESS,
@@ -1643,6 +1646,17 @@ async def create_and_process_run(
16431646
toolset.add(self._function_tool)
16441647
tool_outputs = await toolset.execute_tool_calls(tool_calls)
16451648

1649+
if self._has_errors_in_toolcalls_output(tool_outputs):
1650+
if current_retry >= self._function_tool_max_retry:
1651+
logging.warning(
1652+
f"Tool outputs contain errors - reaching max retry {self._function_tool_max_retry}"
1653+
)
1654+
await self.cancel_run(thread_id=thread_id, run_id=run.id)
1655+
break
1656+
else:
1657+
logging.warning(f"Tool outputs contain errors - retrying")
1658+
current_retry += 1
1659+
16461660
logging.info("Tool outputs: %s", tool_outputs)
16471661
if tool_outputs:
16481662
await self.submit_tool_outputs_to_run(
@@ -1653,6 +1667,25 @@ async def create_and_process_run(
16531667

16541668
return run
16551669

1670+
def _has_errors_in_toolcalls_output(self, tool_outputs: List[Dict]) -> bool:
1671+
"""
1672+
Check if any tool output contains an error.
1673+
1674+
:param List[Dict] tool_outputs: A list of tool outputs to check.
1675+
:return: True if any output contains an error, False otherwise.
1676+
:rtype: bool
1677+
"""
1678+
for tool_output in tool_outputs:
1679+
output = tool_output.get("output")
1680+
if isinstance(output, str):
1681+
try:
1682+
output_json = json.loads(output)
1683+
if "error" in output_json:
1684+
return True
1685+
except json.JSONDecodeError:
1686+
continue
1687+
return False
1688+
16561689
@overload
16571690
async def create_stream(
16581691
self,
@@ -2090,6 +2123,8 @@ async def create_stream( # pyright: ignore[reportInconsistentOverload]
20902123

20912124
if not event_handler:
20922125
event_handler = cast(_models.BaseAsyncAgentEventHandlerT, _models.AsyncAgentEventHandler())
2126+
if isinstance(event_handler, _models.AsyncAgentEventHandler):
2127+
event_handler.set_max_retry(self._function_tool_max_retry)
20932128

20942129
return _models.AsyncAgentRunStream(response_iterator, self._handle_submit_tool_outputs, event_handler)
20952130

@@ -2320,13 +2355,14 @@ async def submit_tool_outputs_to_stream( # pyright: ignore[reportInconsistentOv
23202355
event_handler.initialize(response_iterator, self._handle_submit_tool_outputs)
23212356

23222357
async def _handle_submit_tool_outputs(
2323-
self, run: _models.ThreadRun, event_handler: _models.BaseAsyncAgentEventHandler
2324-
) -> None:
2358+
self, run: _models.ThreadRun, event_handler: _models.BaseAsyncAgentEventHandler, submit_with_error: bool
2359+
) -> Any:
2360+
tool_outputs = []
23252361
if isinstance(run.required_action, _models.SubmitToolOutputsAction):
23262362
tool_calls = run.required_action.submit_tool_outputs.tool_calls
23272363
if not tool_calls:
23282364
logger.debug("No tool calls to execute.")
2329-
return
2365+
return tool_outputs
23302366

23312367
# We need tool set only if we are executing local function. In case if
23322368
# the tool is azure_function we just need to wait when it will be finished.
@@ -2338,11 +2374,20 @@ async def _handle_submit_tool_outputs(
23382374
toolset.add(self._function_tool)
23392375
tool_outputs = await toolset.execute_tool_calls(tool_calls)
23402376

2377+
if self._has_errors_in_toolcalls_output(tool_outputs):
2378+
if submit_with_error:
2379+
logging.warning(f"Tool outputs contain errors - retrying")
2380+
else:
2381+
logging.warning(f"Tool outputs contain errors - reaching max retry limit")
2382+
await self.cancel_run(thread_id=run.thread_id, run_id=run.id)
2383+
return tool_outputs
2384+
23412385
logger.info("Tool outputs: %s", tool_outputs)
23422386
if tool_outputs:
23432387
await self.submit_tool_outputs_to_stream(
23442388
thread_id=run.thread_id, run_id=run.id, tool_outputs=tool_outputs, event_handler=event_handler
23452389
)
2390+
return tool_outputs
23462391

23472392
# pylint: disable=arguments-differ
23482393
@overload
@@ -3128,25 +3173,31 @@ async def delete_agent(self, agent_id: str, **kwargs: Any) -> _models.AgentDelet
31283173
return await super().delete_agent(agent_id, **kwargs)
31293174

31303175
@overload
3131-
def enable_auto_function_calls(self, *, functions: Set[Callable[..., Any]]) -> None:
3176+
def enable_auto_function_calls(self, *, functions: Set[Callable[..., Any]], max_retry: int = 10) -> None:
31323177
"""Enables tool calls to be executed automatically during create_and_process_run or streaming.
31333178
If this is not set, functions must be called manually.
3179+
If automatic function calls fail, the agents will receive error messages allowing it to retry with another
3180+
function call or figure out the answer with its knowledge.
31343181
:keyword functions: A set of callable functions to be used as tools.
31353182
:type functions: Set[Callable[..., Any]]
31363183
"""
31373184

31383185
@overload
3139-
def enable_auto_function_calls(self, *, function_tool: _models.AsyncFunctionTool) -> None:
3186+
def enable_auto_function_calls(self, *, function_tool: _models.AsyncFunctionTool, max_retry: int = 10) -> None:
31403187
"""Enables tool calls to be executed automatically during create_and_process_run or streaming.
31413188
If this is not set, functions must be called manually.
3189+
If automatic function calls fail, the agents will receive error messages allowing it to retry with another
3190+
function call or figure out the answer with its knowledge.
31423191
:keyword function_tool: An AsyncFunctionTool object representing the tool to be used.
31433192
:type function_tool: Optional[_models.AsyncFunctionTool]
31443193
"""
31453194

31463195
@overload
3147-
def enable_auto_function_calls(self, *, toolset: _models.AsyncToolSet) -> None:
3196+
def enable_auto_function_calls(self, *, toolset: _models.AsyncToolSet, max_retry: int = 10) -> None:
31483197
"""Enables tool calls to be executed automatically during create_and_process_run or streaming.
31493198
If this is not set, functions must be called manually.
3199+
If automatic function calls fail, the agents will receive error messages allowing it to retry with another
3200+
function call or figure out the answer with its knowledge.
31503201
:keyword toolset: An AsyncToolSet object representing the set of tools to be used.
31513202
:type toolset: Optional[_models.AsyncToolSet]
31523203
"""
@@ -3157,9 +3208,12 @@ def enable_auto_function_calls(
31573208
functions: Optional[Set[Callable[..., Any]]] = None,
31583209
function_tool: Optional[_models.AsyncFunctionTool] = None,
31593210
toolset: Optional[_models.AsyncToolSet] = None,
3211+
max_retry: int = 10,
31603212
) -> None:
31613213
"""Enables tool calls to be executed automatically during create_and_process_run or streaming.
31623214
If this is not set, functions must be called manually.
3215+
If automatic function calls fail, the agents will receive error messages allowing it to retry with another
3216+
function call or figure out the answer with its knowledge.
31633217
:keyword functions: A set of callable functions to be used as tools.
31643218
:type functions: Set[Callable[..., Any]]
31653219
:keyword function_tool: An AsyncFunctionTool object representing the tool to be used.
@@ -3175,6 +3229,8 @@ def enable_auto_function_calls(
31753229
tool = toolset.get_tool(_models.AsyncFunctionTool)
31763230
self._function_tool = tool
31773231

3232+
self._function_tool_max_retry = max_retry
3233+
31783234

31793235
class _SyncCredentialWrapper(TokenCredential):
31803236
"""

sdk/ai/azure-ai-projects/azure/ai/projects/models/_patch.py

+65-15
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,26 @@
9898
StreamEventData = Union["MessageDeltaChunk", "ThreadMessage", ThreadRun, RunStep, str]
9999

100100

101+
def _has_errors_in_toolcalls_output(tool_outputs: List[Dict]) -> bool:
102+
"""
103+
Check if any tool output contains an error.
104+
105+
:param List[Dict] tool_outputs: A list of tool outputs to check.
106+
:return: True if any output contains an error, False otherwise.
107+
:rtype: bool
108+
"""
109+
for tool_output in tool_outputs:
110+
output = tool_output.get("output")
111+
if isinstance(output, str):
112+
try:
113+
output_json = json.loads(output)
114+
if "error" in output_json:
115+
return True
116+
except json.JSONDecodeError:
117+
continue
118+
return False
119+
120+
101121
def _filter_parameters(model_class: Type, parameters: Dict[str, Any]) -> Dict[str, Any]:
102122
"""
103123
Remove the parameters, non present in class public fields; return shallow copy of a dictionary.
@@ -734,7 +754,7 @@ def execute(self, tool_call: RequiredFunctionToolCall) -> Any:
734754
try:
735755
function, parsed_arguments = self._get_func_and_args(tool_call)
736756
return function(**parsed_arguments) if parsed_arguments else function()
737-
except TypeError as e:
757+
except Exception as e: # pylint: disable=broad-exception-caught
738758
error_message = f"Error executing function '{tool_call.function.name}': {e}"
739759
logging.error(error_message)
740760
# Return error message as JSON string back to agent in order to make possible self
@@ -745,13 +765,12 @@ def execute(self, tool_call: RequiredFunctionToolCall) -> Any:
745765
class AsyncFunctionTool(BaseFunctionTool):
746766

747767
async def execute(self, tool_call: RequiredFunctionToolCall) -> Any: # pylint: disable=invalid-overridden-method
748-
function, parsed_arguments = self._get_func_and_args(tool_call)
749-
750768
try:
769+
function, parsed_arguments = self._get_func_and_args(tool_call)
751770
if inspect.iscoroutinefunction(function):
752771
return await function(**parsed_arguments) if parsed_arguments else await function()
753772
return function(**parsed_arguments) if parsed_arguments else function()
754-
except TypeError as e:
773+
except Exception as e: # pylint: disable=broad-exception-caught
755774
error_message = f"Error executing function '{tool_call.function.name}': {e}"
756775
logging.error(error_message)
757776
# Return error message as JSON string back to agent in order to make possible self correction
@@ -1511,13 +1530,13 @@ class BaseAgentEventHandler(Iterator[T]):
15111530

15121531
def __init__(self) -> None:
15131532
self.response_iterator: Optional[Iterator[bytes]] = None
1514-
self.submit_tool_outputs: Optional[Callable[[ThreadRun, "BaseAgentEventHandler[T]"], None]] = None
1533+
self.submit_tool_outputs: Optional[Callable[[ThreadRun, "BaseAgentEventHandler[T]", bool], Any]]
15151534
self.buffer: Optional[bytes] = None
15161535

15171536
def initialize(
15181537
self,
15191538
response_iterator: Iterator[bytes],
1520-
submit_tool_outputs: Callable[[ThreadRun, "BaseAgentEventHandler[T]"], None],
1539+
submit_tool_outputs: Callable[[ThreadRun, "BaseAgentEventHandler[T]", bool], Any],
15211540
) -> None:
15221541
self.response_iterator = (
15231542
itertools.chain(self.response_iterator, response_iterator) if self.response_iterator else response_iterator
@@ -1569,17 +1588,33 @@ def until_done(self) -> None:
15691588

15701589

15711590
class AsyncAgentEventHandler(BaseAsyncAgentEventHandler[Tuple[str, StreamEventData, Optional[EventFunctionReturnT]]]):
1591+
def __init__(self) -> None:
1592+
super().__init__()
1593+
self._max_retry = 10
1594+
self.current_retry = 0
1595+
1596+
def set_max_retry(self, max_retry: int) -> None:
1597+
"""
1598+
Set the maximum number of retries for tool output submission.
1599+
1600+
:param int max_retry: The maximum number of retries.
1601+
"""
1602+
self._max_retry = max_retry
15721603

15731604
async def _process_event(self, event_data_str: str) -> Tuple[str, StreamEventData, Optional[EventFunctionReturnT]]:
1605+
15741606
event_type, event_data_obj = _parse_event(event_data_str)
15751607
if (
15761608
isinstance(event_data_obj, ThreadRun)
15771609
and event_data_obj.status == "requires_action"
15781610
and isinstance(event_data_obj.required_action, SubmitToolOutputsAction)
15791611
):
1580-
await cast(Callable[[ThreadRun, "BaseAsyncAgentEventHandler"], Awaitable[None]], self.submit_tool_outputs)(
1581-
event_data_obj, self
1582-
)
1612+
tool_output = await cast(
1613+
Callable[[ThreadRun, "BaseAsyncAgentEventHandler", bool], Awaitable[Any]], self.submit_tool_outputs
1614+
)(event_data_obj, self, self.current_retry < self._max_retry)
1615+
1616+
if _has_errors_in_toolcalls_output(tool_output):
1617+
self.current_retry += 1
15831618

15841619
func_rt: Optional[EventFunctionReturnT] = None
15851620
try:
@@ -1682,6 +1717,18 @@ async def on_unhandled_event(
16821717

16831718

16841719
class AgentEventHandler(BaseAgentEventHandler[Tuple[str, StreamEventData, Optional[EventFunctionReturnT]]]):
1720+
def __init__(self) -> None:
1721+
super().__init__()
1722+
self._max_retry = 10
1723+
self.current_retry = 0
1724+
1725+
def set_max_retry(self, max_retry: int) -> None:
1726+
"""
1727+
Set the maximum number of retries for tool output submission.
1728+
1729+
:param int max_retry: The maximum number of retries.
1730+
"""
1731+
self._max_retry = max_retry
16851732

16861733
def _process_event(self, event_data_str: str) -> Tuple[str, StreamEventData, Optional[EventFunctionReturnT]]:
16871734

@@ -1691,10 +1738,13 @@ def _process_event(self, event_data_str: str) -> Tuple[str, StreamEventData, Opt
16911738
and event_data_obj.status == "requires_action"
16921739
and isinstance(event_data_obj.required_action, SubmitToolOutputsAction)
16931740
):
1694-
cast(Callable[[ThreadRun, "BaseAgentEventHandler"], Awaitable[None]], self.submit_tool_outputs)(
1695-
event_data_obj, self
1741+
tool_output = cast(Callable[[ThreadRun, "BaseAgentEventHandler", bool], Any], self.submit_tool_outputs)(
1742+
event_data_obj, self, self.current_retry < self._max_retry
16961743
)
16971744

1745+
if _has_errors_in_toolcalls_output(tool_output):
1746+
self.current_retry += 1
1747+
16981748
func_rt: Optional[EventFunctionReturnT] = None
16991749
try:
17001750
if isinstance(event_data_obj, MessageDeltaChunk):
@@ -1792,15 +1842,15 @@ class AsyncAgentRunStream(Generic[BaseAsyncAgentEventHandlerT]):
17921842
def __init__(
17931843
self,
17941844
response_iterator: AsyncIterator[bytes],
1795-
submit_tool_outputs: Callable[[ThreadRun, BaseAsyncAgentEventHandlerT], Awaitable[None]],
1845+
submit_tool_outputs: Callable[[ThreadRun, BaseAsyncAgentEventHandlerT, bool], Awaitable[Any]],
17961846
event_handler: BaseAsyncAgentEventHandlerT,
17971847
):
17981848
self.response_iterator = response_iterator
17991849
self.event_handler = event_handler
18001850
self.submit_tool_outputs = submit_tool_outputs
18011851
self.event_handler.initialize(
18021852
self.response_iterator,
1803-
cast(Callable[[ThreadRun, BaseAsyncAgentEventHandler], Awaitable[None]], submit_tool_outputs),
1853+
cast(Callable[[ThreadRun, BaseAsyncAgentEventHandler], Awaitable[Any]], submit_tool_outputs),
18041854
)
18051855

18061856
async def __aenter__(self):
@@ -1818,15 +1868,15 @@ class AgentRunStream(Generic[BaseAgentEventHandlerT]):
18181868
def __init__(
18191869
self,
18201870
response_iterator: Iterator[bytes],
1821-
submit_tool_outputs: Callable[[ThreadRun, BaseAgentEventHandlerT], None],
1871+
submit_tool_outputs: Callable[[ThreadRun, BaseAgentEventHandlerT, bool], Any],
18221872
event_handler: BaseAgentEventHandlerT,
18231873
):
18241874
self.response_iterator = response_iterator
18251875
self.event_handler = event_handler
18261876
self.submit_tool_outputs = submit_tool_outputs
18271877
self.event_handler.initialize(
18281878
self.response_iterator,
1829-
cast(Callable[[ThreadRun, BaseAgentEventHandler], None], submit_tool_outputs),
1879+
cast(Callable[[ThreadRun, BaseAgentEventHandler, bool], Any], submit_tool_outputs),
18301880
)
18311881

18321882
def __enter__(self):

0 commit comments

Comments
 (0)