Skip to content

Commit

Permalink
Reset a failed tool call (#5666)
Browse files Browse the repository at this point in the history
Co-authored-by: openhands <[email protected]>
  • Loading branch information
enyst and openhands-agent authored Dec 31, 2024
1 parent 7ae1f76 commit a2e9e20
Show file tree
Hide file tree
Showing 4 changed files with 176 additions and 14 deletions.
13 changes: 1 addition & 12 deletions openhands/agenthub/codeact_agent/codeact_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,18 +482,7 @@ def _get_messages(self, state: State) -> list[Message]:
if message:
if message.role == 'user':
self.prompt_manager.enhance_message(message)
# handle error if the message is the SAME role as the previous message
# litellm.exceptions.BadRequestError: litellm.BadRequestError: OpenAIException - Error code: 400 - {'detail': 'Only supports u/a/u/a/u...'}
# there shouldn't be two consecutive messages from the same role
# NOTE: we shouldn't combine tool messages because each of them has a different tool_call_id
if (
messages
and messages[-1].role == message.role
and message.role != 'tool'
):
messages[-1].content.extend(message.content)
else:
messages.append(message)
messages.append(message)

if self.llm.is_caching_prompt_active():
# NOTE: this is only needed for anthropic
Expand Down
22 changes: 22 additions & 0 deletions openhands/controller/agent_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,28 @@ async def _handle_message_action(self, action: MessageAction) -> None:
def _reset(self) -> None:
"""Resets the agent controller"""

# make sure there is an Observation with the tool call metadata to be recognized by the agent
# otherwise the pending action is found in history, but it's incomplete without an obs with tool result
if self._pending_action and hasattr(self._pending_action, 'tool_call_metadata'):
# find out if there already is an observation with the same tool call metadata
found_observation = False
for event in self.state.history:
if (
isinstance(event, Observation)
and event.tool_call_metadata
== self._pending_action.tool_call_metadata
):
found_observation = True
break

# make a new ErrorObservation with the tool call metadata
if not found_observation:
obs = ErrorObservation(content='The action has not been executed.')
obs.tool_call_metadata = self._pending_action.tool_call_metadata
obs._cause = self._pending_action.id # type: ignore[attr-defined]
self.event_stream.add_event(obs, EventSource.AGENT)

# reset the pending action, this will be called when the agent is STOPPED or ERROR
self._pending_action = None
self.agent.reset()

Expand Down
6 changes: 4 additions & 2 deletions openhands/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
warnings.simplefilter('ignore')
import litellm

from litellm import ChatCompletionMessageToolCall, ModelInfo, PromptTokensDetails
from litellm import Message as LiteLLMMessage
from litellm import ModelInfo, PromptTokensDetails
from litellm import completion as litellm_completion
from litellm import completion_cost as litellm_completion_cost
from litellm.exceptions import (
Expand Down Expand Up @@ -246,7 +246,9 @@ def wrapper(*args, **kwargs):
resp.choices[0].message = fn_call_response_message

message_back: str = resp['choices'][0]['message']['content'] or ''
tool_calls = resp['choices'][0]['message'].get('tool_calls', [])
tool_calls: list[ChatCompletionMessageToolCall] = resp['choices'][0][
'message'
].get('tool_calls', [])
if tool_calls:
for tool_call in tool_calls:
fn_name = tool_call.function.name
Expand Down
149 changes: 149 additions & 0 deletions tests/unit/test_agent_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,3 +387,152 @@ async def test_step_max_budget_headless(mock_agent, mock_event_stream):
# In headless mode, throttling results in an error
assert controller.state.agent_state == AgentState.ERROR
await controller.close()


@pytest.mark.asyncio
async def test_reset_with_pending_action_no_observation(mock_agent, mock_event_stream):
"""Test reset() when there's a pending action with tool call metadata but no observation."""
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test',
confirmation_mode=False,
headless_mode=True,
)

# Create a pending action with tool call metadata
pending_action = CmdRunAction(command='test')
pending_action.tool_call_metadata = {
'function': 'test_function',
'args': {'arg1': 'value1'},
}
controller._pending_action = pending_action

# Call reset
controller._reset()

# Verify that an ErrorObservation was added to the event stream
mock_event_stream.add_event.assert_called_once()
args, kwargs = mock_event_stream.add_event.call_args
error_obs, source = args
assert isinstance(error_obs, ErrorObservation)
assert error_obs.content == 'The action has not been executed.'
assert error_obs.tool_call_metadata == pending_action.tool_call_metadata
assert error_obs._cause == pending_action.id
assert source == EventSource.AGENT

# Verify that pending action was reset
assert controller._pending_action is None

# Verify that agent.reset() was called
mock_agent.reset.assert_called_once()
await controller.close()


@pytest.mark.asyncio
async def test_reset_with_pending_action_existing_observation(
mock_agent, mock_event_stream
):
"""Test reset() when there's a pending action with tool call metadata and an existing observation."""
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test',
confirmation_mode=False,
headless_mode=True,
)

# Create a pending action with tool call metadata
pending_action = CmdRunAction(command='test')
pending_action.tool_call_metadata = {
'function': 'test_function',
'args': {'arg1': 'value1'},
}
controller._pending_action = pending_action

# Add an existing observation to the history
existing_obs = ErrorObservation(content='Previous error')
existing_obs.tool_call_metadata = pending_action.tool_call_metadata
controller.state.history.append(existing_obs)

# Call reset
controller._reset()

# Verify that no new ErrorObservation was added to the event stream
mock_event_stream.add_event.assert_not_called()

# Verify that pending action was reset
assert controller._pending_action is None

# Verify that agent.reset() was called
mock_agent.reset.assert_called_once()
await controller.close()


@pytest.mark.asyncio
async def test_reset_without_pending_action(mock_agent, mock_event_stream):
"""Test reset() when there's no pending action."""
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test',
confirmation_mode=False,
headless_mode=True,
)

# Call reset
controller._reset()

# Verify that no ErrorObservation was added to the event stream
mock_event_stream.add_event.assert_not_called()

# Verify that pending action is None
assert controller._pending_action is None

# Verify that agent.reset() was called
mock_agent.reset.assert_called_once()
await controller.close()


@pytest.mark.asyncio
async def test_reset_with_pending_action_no_metadata(
mock_agent, mock_event_stream, monkeypatch
):
"""Test reset() when there's a pending action without tool call metadata."""
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test',
confirmation_mode=False,
headless_mode=True,
)

# Create a pending action without tool call metadata
pending_action = CmdRunAction(command='test')
# Mock hasattr to return False for tool_call_metadata
original_hasattr = hasattr

def mock_hasattr(obj, name):
if obj == pending_action and name == 'tool_call_metadata':
return False
return original_hasattr(obj, name)

monkeypatch.setattr('builtins.hasattr', mock_hasattr)
controller._pending_action = pending_action

# Call reset
controller._reset()

# Verify that no ErrorObservation was added to the event stream
mock_event_stream.add_event.assert_not_called()

# Verify that pending action was reset
assert controller._pending_action is None

# Verify that agent.reset() was called
mock_agent.reset.assert_called_once()
await controller.close()

0 comments on commit a2e9e20

Please sign in to comment.