From a2e9e206e8eaf4527fc6a366979b1c5e728e8844 Mon Sep 17 00:00:00 2001 From: Engel Nyst Date: Tue, 31 Dec 2024 21:21:32 +0100 Subject: [PATCH] Reset a failed tool call (#5666) Co-authored-by: openhands --- .../agenthub/codeact_agent/codeact_agent.py | 13 +- openhands/controller/agent_controller.py | 22 +++ openhands/llm/llm.py | 6 +- tests/unit/test_agent_controller.py | 149 ++++++++++++++++++ 4 files changed, 176 insertions(+), 14 deletions(-) diff --git a/openhands/agenthub/codeact_agent/codeact_agent.py b/openhands/agenthub/codeact_agent/codeact_agent.py index 7a2e0fc62b79..03fa8cc4dd30 100644 --- a/openhands/agenthub/codeact_agent/codeact_agent.py +++ b/openhands/agenthub/codeact_agent/codeact_agent.py @@ -482,18 +482,7 @@ def _get_messages(self, state: State) -> list[Message]: if message: if message.role == 'user': self.prompt_manager.enhance_message(message) - # handle error if the message is the SAME role as the previous message - # litellm.exceptions.BadRequestError: litellm.BadRequestError: OpenAIException - Error code: 400 - {'detail': 'Only supports u/a/u/a/u...'} - # there shouldn't be two consecutive messages from the same role - # NOTE: we shouldn't combine tool messages because each of them has a different tool_call_id - if ( - messages - and messages[-1].role == message.role - and message.role != 'tool' - ): - messages[-1].content.extend(message.content) - else: - messages.append(message) + messages.append(message) if self.llm.is_caching_prompt_active(): # NOTE: this is only needed for anthropic diff --git a/openhands/controller/agent_controller.py b/openhands/controller/agent_controller.py index a6b666f13690..c88f598516f1 100644 --- a/openhands/controller/agent_controller.py +++ b/openhands/controller/agent_controller.py @@ -335,6 +335,28 @@ async def _handle_message_action(self, action: MessageAction) -> None: def _reset(self) -> None: """Resets the agent controller""" + # make sure there is an Observation with the tool call metadata to be recognized by the agent + # otherwise the pending action is found in history, but it's incomplete without an obs with tool result + if self._pending_action and hasattr(self._pending_action, 'tool_call_metadata'): + # find out if there already is an observation with the same tool call metadata + found_observation = False + for event in self.state.history: + if ( + isinstance(event, Observation) + and event.tool_call_metadata + == self._pending_action.tool_call_metadata + ): + found_observation = True + break + + # make a new ErrorObservation with the tool call metadata + if not found_observation: + obs = ErrorObservation(content='The action has not been executed.') + obs.tool_call_metadata = self._pending_action.tool_call_metadata + obs._cause = self._pending_action.id # type: ignore[attr-defined] + self.event_stream.add_event(obs, EventSource.AGENT) + + # reset the pending action, this will be called when the agent is STOPPED or ERROR self._pending_action = None self.agent.reset() diff --git a/openhands/llm/llm.py b/openhands/llm/llm.py index b5e6ac824159..13d4dfc25047 100644 --- a/openhands/llm/llm.py +++ b/openhands/llm/llm.py @@ -13,8 +13,8 @@ warnings.simplefilter('ignore') import litellm +from litellm import ChatCompletionMessageToolCall, ModelInfo, PromptTokensDetails from litellm import Message as LiteLLMMessage -from litellm import ModelInfo, PromptTokensDetails from litellm import completion as litellm_completion from litellm import completion_cost as litellm_completion_cost from litellm.exceptions import ( @@ -246,7 +246,9 @@ def wrapper(*args, **kwargs): resp.choices[0].message = fn_call_response_message message_back: str = resp['choices'][0]['message']['content'] or '' - tool_calls = resp['choices'][0]['message'].get('tool_calls', []) + tool_calls: list[ChatCompletionMessageToolCall] = resp['choices'][0][ + 'message' + ].get('tool_calls', []) if tool_calls: for tool_call in tool_calls: fn_name = tool_call.function.name diff --git a/tests/unit/test_agent_controller.py b/tests/unit/test_agent_controller.py index d6927e3061b8..6d79645c278c 100644 --- a/tests/unit/test_agent_controller.py +++ b/tests/unit/test_agent_controller.py @@ -387,3 +387,152 @@ async def test_step_max_budget_headless(mock_agent, mock_event_stream): # In headless mode, throttling results in an error assert controller.state.agent_state == AgentState.ERROR await controller.close() + + +@pytest.mark.asyncio +async def test_reset_with_pending_action_no_observation(mock_agent, mock_event_stream): + """Test reset() when there's a pending action with tool call metadata but no observation.""" + controller = AgentController( + agent=mock_agent, + event_stream=mock_event_stream, + max_iterations=10, + sid='test', + confirmation_mode=False, + headless_mode=True, + ) + + # Create a pending action with tool call metadata + pending_action = CmdRunAction(command='test') + pending_action.tool_call_metadata = { + 'function': 'test_function', + 'args': {'arg1': 'value1'}, + } + controller._pending_action = pending_action + + # Call reset + controller._reset() + + # Verify that an ErrorObservation was added to the event stream + mock_event_stream.add_event.assert_called_once() + args, kwargs = mock_event_stream.add_event.call_args + error_obs, source = args + assert isinstance(error_obs, ErrorObservation) + assert error_obs.content == 'The action has not been executed.' + assert error_obs.tool_call_metadata == pending_action.tool_call_metadata + assert error_obs._cause == pending_action.id + assert source == EventSource.AGENT + + # Verify that pending action was reset + assert controller._pending_action is None + + # Verify that agent.reset() was called + mock_agent.reset.assert_called_once() + await controller.close() + + +@pytest.mark.asyncio +async def test_reset_with_pending_action_existing_observation( + mock_agent, mock_event_stream +): + """Test reset() when there's a pending action with tool call metadata and an existing observation.""" + controller = AgentController( + agent=mock_agent, + event_stream=mock_event_stream, + max_iterations=10, + sid='test', + confirmation_mode=False, + headless_mode=True, + ) + + # Create a pending action with tool call metadata + pending_action = CmdRunAction(command='test') + pending_action.tool_call_metadata = { + 'function': 'test_function', + 'args': {'arg1': 'value1'}, + } + controller._pending_action = pending_action + + # Add an existing observation to the history + existing_obs = ErrorObservation(content='Previous error') + existing_obs.tool_call_metadata = pending_action.tool_call_metadata + controller.state.history.append(existing_obs) + + # Call reset + controller._reset() + + # Verify that no new ErrorObservation was added to the event stream + mock_event_stream.add_event.assert_not_called() + + # Verify that pending action was reset + assert controller._pending_action is None + + # Verify that agent.reset() was called + mock_agent.reset.assert_called_once() + await controller.close() + + +@pytest.mark.asyncio +async def test_reset_without_pending_action(mock_agent, mock_event_stream): + """Test reset() when there's no pending action.""" + controller = AgentController( + agent=mock_agent, + event_stream=mock_event_stream, + max_iterations=10, + sid='test', + confirmation_mode=False, + headless_mode=True, + ) + + # Call reset + controller._reset() + + # Verify that no ErrorObservation was added to the event stream + mock_event_stream.add_event.assert_not_called() + + # Verify that pending action is None + assert controller._pending_action is None + + # Verify that agent.reset() was called + mock_agent.reset.assert_called_once() + await controller.close() + + +@pytest.mark.asyncio +async def test_reset_with_pending_action_no_metadata( + mock_agent, mock_event_stream, monkeypatch +): + """Test reset() when there's a pending action without tool call metadata.""" + controller = AgentController( + agent=mock_agent, + event_stream=mock_event_stream, + max_iterations=10, + sid='test', + confirmation_mode=False, + headless_mode=True, + ) + + # Create a pending action without tool call metadata + pending_action = CmdRunAction(command='test') + # Mock hasattr to return False for tool_call_metadata + original_hasattr = hasattr + + def mock_hasattr(obj, name): + if obj == pending_action and name == 'tool_call_metadata': + return False + return original_hasattr(obj, name) + + monkeypatch.setattr('builtins.hasattr', mock_hasattr) + controller._pending_action = pending_action + + # Call reset + controller._reset() + + # Verify that no ErrorObservation was added to the event stream + mock_event_stream.add_event.assert_not_called() + + # Verify that pending action was reset + assert controller._pending_action is None + + # Verify that agent.reset() was called + mock_agent.reset.assert_called_once() + await controller.close()