98
98
StreamEventData = Union ["MessageDeltaChunk" , "ThreadMessage" , ThreadRun , RunStep , str ]
99
99
100
100
101
+ def _has_errors_in_toolcalls_output (tool_outputs : List [Dict ]) -> bool :
102
+ """
103
+ Check if any tool output contains an error.
104
+
105
+ :param List[Dict] tool_outputs: A list of tool outputs to check.
106
+ :return: True if any output contains an error, False otherwise.
107
+ :rtype: bool
108
+ """
109
+ for tool_output in tool_outputs :
110
+ output = tool_output .get ("output" )
111
+ if isinstance (output , str ):
112
+ try :
113
+ output_json = json .loads (output )
114
+ if "error" in output_json :
115
+ return True
116
+ except json .JSONDecodeError :
117
+ continue
118
+ return False
119
+
120
+
101
121
def _filter_parameters (model_class : Type , parameters : Dict [str , Any ]) -> Dict [str , Any ]:
102
122
"""
103
123
Remove the parameters, non present in class public fields; return shallow copy of a dictionary.
@@ -734,7 +754,7 @@ def execute(self, tool_call: RequiredFunctionToolCall) -> Any:
734
754
try :
735
755
function , parsed_arguments = self ._get_func_and_args (tool_call )
736
756
return function (** parsed_arguments ) if parsed_arguments else function ()
737
- except TypeError as e :
757
+ except Exception as e : # pylint: disable=broad-exception-caught
738
758
error_message = f"Error executing function '{ tool_call .function .name } ': { e } "
739
759
logging .error (error_message )
740
760
# Return error message as JSON string back to agent in order to make possible self
@@ -745,13 +765,12 @@ def execute(self, tool_call: RequiredFunctionToolCall) -> Any:
745
765
class AsyncFunctionTool (BaseFunctionTool ):
746
766
747
767
async def execute (self , tool_call : RequiredFunctionToolCall ) -> Any : # pylint: disable=invalid-overridden-method
748
- function , parsed_arguments = self ._get_func_and_args (tool_call )
749
-
750
768
try :
769
+ function , parsed_arguments = self ._get_func_and_args (tool_call )
751
770
if inspect .iscoroutinefunction (function ):
752
771
return await function (** parsed_arguments ) if parsed_arguments else await function ()
753
772
return function (** parsed_arguments ) if parsed_arguments else function ()
754
- except TypeError as e :
773
+ except Exception as e : # pylint: disable=broad-exception-caught
755
774
error_message = f"Error executing function '{ tool_call .function .name } ': { e } "
756
775
logging .error (error_message )
757
776
# Return error message as JSON string back to agent in order to make possible self correction
@@ -1511,13 +1530,13 @@ class BaseAgentEventHandler(Iterator[T]):
1511
1530
1512
1531
def __init__ (self ) -> None :
1513
1532
self .response_iterator : Optional [Iterator [bytes ]] = None
1514
- self .submit_tool_outputs : Optional [Callable [[ThreadRun , "BaseAgentEventHandler[T]" ], None ]] = None
1533
+ self .submit_tool_outputs : Optional [Callable [[ThreadRun , "BaseAgentEventHandler[T]" , bool ], Any ]]
1515
1534
self .buffer : Optional [bytes ] = None
1516
1535
1517
1536
def initialize (
1518
1537
self ,
1519
1538
response_iterator : Iterator [bytes ],
1520
- submit_tool_outputs : Callable [[ThreadRun , "BaseAgentEventHandler[T]" ], None ],
1539
+ submit_tool_outputs : Callable [[ThreadRun , "BaseAgentEventHandler[T]" , bool ], Any ],
1521
1540
) -> None :
1522
1541
self .response_iterator = (
1523
1542
itertools .chain (self .response_iterator , response_iterator ) if self .response_iterator else response_iterator
@@ -1569,17 +1588,33 @@ def until_done(self) -> None:
1569
1588
1570
1589
1571
1590
class AsyncAgentEventHandler (BaseAsyncAgentEventHandler [Tuple [str , StreamEventData , Optional [EventFunctionReturnT ]]]):
1591
+ def __init__ (self ) -> None :
1592
+ super ().__init__ ()
1593
+ self ._max_retry = 10
1594
+ self .current_retry = 0
1595
+
1596
+ def set_max_retry (self , max_retry : int ) -> None :
1597
+ """
1598
+ Set the maximum number of retries for tool output submission.
1599
+
1600
+ :param int max_retry: The maximum number of retries.
1601
+ """
1602
+ self ._max_retry = max_retry
1572
1603
1573
1604
async def _process_event (self , event_data_str : str ) -> Tuple [str , StreamEventData , Optional [EventFunctionReturnT ]]:
1605
+
1574
1606
event_type , event_data_obj = _parse_event (event_data_str )
1575
1607
if (
1576
1608
isinstance (event_data_obj , ThreadRun )
1577
1609
and event_data_obj .status == "requires_action"
1578
1610
and isinstance (event_data_obj .required_action , SubmitToolOutputsAction )
1579
1611
):
1580
- await cast (Callable [[ThreadRun , "BaseAsyncAgentEventHandler" ], Awaitable [None ]], self .submit_tool_outputs )(
1581
- event_data_obj , self
1582
- )
1612
+ tool_output = await cast (
1613
+ Callable [[ThreadRun , "BaseAsyncAgentEventHandler" , bool ], Awaitable [Any ]], self .submit_tool_outputs
1614
+ )(event_data_obj , self , self .current_retry < self ._max_retry )
1615
+
1616
+ if _has_errors_in_toolcalls_output (tool_output ):
1617
+ self .current_retry += 1
1583
1618
1584
1619
func_rt : Optional [EventFunctionReturnT ] = None
1585
1620
try :
@@ -1682,6 +1717,18 @@ async def on_unhandled_event(
1682
1717
1683
1718
1684
1719
class AgentEventHandler (BaseAgentEventHandler [Tuple [str , StreamEventData , Optional [EventFunctionReturnT ]]]):
1720
+ def __init__ (self ) -> None :
1721
+ super ().__init__ ()
1722
+ self ._max_retry = 10
1723
+ self .current_retry = 0
1724
+
1725
+ def set_max_retry (self , max_retry : int ) -> None :
1726
+ """
1727
+ Set the maximum number of retries for tool output submission.
1728
+
1729
+ :param int max_retry: The maximum number of retries.
1730
+ """
1731
+ self ._max_retry = max_retry
1685
1732
1686
1733
def _process_event (self , event_data_str : str ) -> Tuple [str , StreamEventData , Optional [EventFunctionReturnT ]]:
1687
1734
@@ -1691,10 +1738,13 @@ def _process_event(self, event_data_str: str) -> Tuple[str, StreamEventData, Opt
1691
1738
and event_data_obj .status == "requires_action"
1692
1739
and isinstance (event_data_obj .required_action , SubmitToolOutputsAction )
1693
1740
):
1694
- cast (Callable [[ThreadRun , "BaseAgentEventHandler" ], Awaitable [ None ] ], self .submit_tool_outputs )(
1695
- event_data_obj , self
1741
+ tool_output = cast (Callable [[ThreadRun , "BaseAgentEventHandler" , bool ], Any ], self .submit_tool_outputs )(
1742
+ event_data_obj , self , self . current_retry < self . _max_retry
1696
1743
)
1697
1744
1745
+ if _has_errors_in_toolcalls_output (tool_output ):
1746
+ self .current_retry += 1
1747
+
1698
1748
func_rt : Optional [EventFunctionReturnT ] = None
1699
1749
try :
1700
1750
if isinstance (event_data_obj , MessageDeltaChunk ):
@@ -1792,15 +1842,15 @@ class AsyncAgentRunStream(Generic[BaseAsyncAgentEventHandlerT]):
1792
1842
def __init__ (
1793
1843
self ,
1794
1844
response_iterator : AsyncIterator [bytes ],
1795
- submit_tool_outputs : Callable [[ThreadRun , BaseAsyncAgentEventHandlerT ], Awaitable [None ]],
1845
+ submit_tool_outputs : Callable [[ThreadRun , BaseAsyncAgentEventHandlerT , bool ], Awaitable [Any ]],
1796
1846
event_handler : BaseAsyncAgentEventHandlerT ,
1797
1847
):
1798
1848
self .response_iterator = response_iterator
1799
1849
self .event_handler = event_handler
1800
1850
self .submit_tool_outputs = submit_tool_outputs
1801
1851
self .event_handler .initialize (
1802
1852
self .response_iterator ,
1803
- cast (Callable [[ThreadRun , BaseAsyncAgentEventHandler ], Awaitable [None ]], submit_tool_outputs ),
1853
+ cast (Callable [[ThreadRun , BaseAsyncAgentEventHandler ], Awaitable [Any ]], submit_tool_outputs ),
1804
1854
)
1805
1855
1806
1856
async def __aenter__ (self ):
@@ -1818,15 +1868,15 @@ class AgentRunStream(Generic[BaseAgentEventHandlerT]):
1818
1868
def __init__ (
1819
1869
self ,
1820
1870
response_iterator : Iterator [bytes ],
1821
- submit_tool_outputs : Callable [[ThreadRun , BaseAgentEventHandlerT ], None ],
1871
+ submit_tool_outputs : Callable [[ThreadRun , BaseAgentEventHandlerT , bool ], Any ],
1822
1872
event_handler : BaseAgentEventHandlerT ,
1823
1873
):
1824
1874
self .response_iterator = response_iterator
1825
1875
self .event_handler = event_handler
1826
1876
self .submit_tool_outputs = submit_tool_outputs
1827
1877
self .event_handler .initialize (
1828
1878
self .response_iterator ,
1829
- cast (Callable [[ThreadRun , BaseAgentEventHandler ], None ], submit_tool_outputs ),
1879
+ cast (Callable [[ThreadRun , BaseAgentEventHandler , bool ], Any ], submit_tool_outputs ),
1830
1880
)
1831
1881
1832
1882
def __enter__ (self ):
0 commit comments