Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: improve stuck detection in UI mode #5595

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions openhands/controller/agent_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ async def _handle_message_action(self, action: MessageAction) -> None:

def _reset(self) -> None:
"""Resets the agent controller"""
self.almost_stuck = 0

self._pending_action = None
self.agent.reset()

Expand Down Expand Up @@ -902,17 +902,20 @@ def _apply_conversation_window(self, events: list[Event]) -> list[Event]:

return kept_events

def _is_stuck(self) -> bool:
def _is_stuck(self, ui_mode: bool | None = None) -> bool:
"""Checks if the agent or its delegate is stuck in a loop.

Args:
ui_mode: Optional override for UI mode. When None (the default), it falls back to `not self.headless_mode`.

Returns:
bool: True if the agent is stuck, False otherwise.
"""
# check if delegate stuck
if self.delegate and self.delegate._is_stuck():
if self.delegate and self.delegate._is_stuck(ui_mode):
return True

return self._stuck_detector.is_stuck()
return self._stuck_detector.is_stuck(ui_mode if ui_mode is not None else not self.headless_mode)

def __repr__(self):
return (
Expand Down
2 changes: 1 addition & 1 deletion openhands/controller/state/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class State:
end_id: int = -1
# truncation_id tracks where to load history after context window truncation
truncation_id: int = -1
almost_stuck: int = 0

delegates: dict[tuple[int, int], tuple[str, str]] = field(default_factory=dict)
# NOTE: This will never be used by the controller, but it can be used by different
# evaluation tasks to store extra data needed to track the progress/state of the task.
Expand Down
58 changes: 22 additions & 36 deletions openhands/controller/stuck.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,26 @@ class StuckDetector:
def __init__(self, state: State):
self.state = state

def is_stuck(self):
# filter out MessageAction with source='user' from history
def is_stuck(self, ui_mode: bool = False):
if ui_mode:
# In UI mode, only look at history after the last user message
last_user_msg_idx = -1
for i, event in enumerate(self.state.history):
if isinstance(event, MessageAction) and event.source == EventSource.USER:
last_user_msg_idx = i

history_to_check = self.state.history[last_user_msg_idx + 1:]
else:
# In headless mode, look at all history
history_to_check = self.state.history

# Filter out user messages and null events
filtered_history = [
event
for event in self.state.history
for event in history_to_check
if not (
(isinstance(event, MessageAction) and event.source == EventSource.USER)
or
# there might be some NullAction or NullObservation in the history at least for now
isinstance(event, (NullAction, NullObservation))
or isinstance(event, (NullAction, NullObservation))
)
]

Expand Down Expand Up @@ -81,43 +91,19 @@ def _is_stuck_repeating_action_observation(self, last_actions, last_observations
# it takes 4 actions and 4 observations to detect a loop
# assert len(last_actions) == 4 and len(last_observations) == 4

# reset almost_stuck reminder
self.state.almost_stuck = 0

# almost stuck? if two actions, obs are the same, we're almost stuck
if len(last_actions) >= 2 and len(last_observations) >= 2:
# Check for a loop of 4 identical action-observation pairs
if len(last_actions) == 4 and len(last_observations) == 4:
actions_equal = all(
self._eq_no_pid(last_actions[0], action) for action in last_actions[:2]
self._eq_no_pid(last_actions[0], action) for action in last_actions
)
observations_equal = all(
self._eq_no_pid(last_observations[0], observation)
for observation in last_observations[:2]
for observation in last_observations
)

# the last two actions and obs are the same?
if actions_equal and observations_equal:
self.state.almost_stuck = 2

# the last three actions and observations are the same?
if len(last_actions) >= 3 and len(last_observations) >= 3:
if (
actions_equal
and observations_equal
and self._eq_no_pid(last_actions[0], last_actions[2])
and self._eq_no_pid(last_observations[0], last_observations[2])
):
self.state.almost_stuck = 1

if len(last_actions) == 4 and len(last_observations) == 4:
if (
actions_equal
and observations_equal
and self._eq_no_pid(last_actions[0], last_actions[3])
and self._eq_no_pid(last_observations[0], last_observations[3])
):
logger.warning('Action, Observation loop detected')
self.state.almost_stuck = 0
return True
logger.warning('Action, Observation loop detected')
return True

return False

Expand Down
105 changes: 65 additions & 40 deletions tests/unit/test_is_stuck.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,53 @@ def test_history_too_short(self, stuck_detector: StuckDetector):
# cmd_observation._cause = cmd_action._id
state.history.append(cmd_observation)

assert stuck_detector.is_stuck() is False
assert stuck_detector.is_stuck(ui_mode=False) is False
assert stuck_detector.is_stuck(ui_mode=True) is False

def test_ui_mode_resets_after_user_message(self, stuck_detector: StuckDetector):
    """Check that UI mode only considers events after the latest user message.

    Four identical `ls` action/observation pairs are enough to trip the
    detector when the whole history is inspected (ui_mode=False). Once a
    user message is appended, UI mode must start counting from scratch and
    only report a loop after four fresh identical pairs have accumulated.
    """
    state = stuck_detector.state

    def record_ls_pairs(first_id: int, how_many: int) -> None:
        # Append `how_many` identical `ls` action/observation pairs whose ids
        # run consecutively from `first_id`.
        for event_id in range(first_id, first_id + how_many):
            action = CmdRunAction(command='ls')
            action._id = event_id
            state.history.append(action)
            observation = CmdOutputObservation(
                content='', command='ls', command_id=event_id
            )
            observation._cause = action._id
            state.history.append(observation)

    # Four identical pairs: a loop when the full history is examined.
    record_ls_pairs(0, 4)
    assert stuck_detector.is_stuck(ui_mode=False) is True

    # A user message arrives after the repeated commands.
    user_message = MessageAction(content='Hello', wait_for_response=False)
    user_message._source = EventSource.USER
    state.history.append(user_message)

    # UI mode ignores everything before the user message, so nothing is left
    # to form a loop.
    assert stuck_detector.is_stuck(ui_mode=True) is False

    # Only two identical pairs since the user message: a loop needs four.
    record_ls_pairs(4, 2)
    assert stuck_detector.is_stuck(ui_mode=True) is False

    # Two more bring the post-message total to four identical pairs.
    record_ls_pairs(6, 2)
    assert stuck_detector.is_stuck(ui_mode=True) is True

def test_is_stuck_repeating_action_observation(self, stuck_detector: StuckDetector):
state = stuck_detector.state
Expand Down Expand Up @@ -148,8 +194,7 @@ def test_is_stuck_repeating_action_observation(self, stuck_detector: StuckDetect
state.history.append(message_null_observation)
# 8 events

assert stuck_detector.is_stuck() is False
assert stuck_detector.state.almost_stuck == 2
assert stuck_detector.is_stuck(ui_mode=False) is False

cmd_action_3 = CmdRunAction(command='ls')
cmd_action_3._id = 3
Expand All @@ -160,20 +205,7 @@ def test_is_stuck_repeating_action_observation(self, stuck_detector: StuckDetect
# 10 events

assert len(state.history) == 10
assert (
len(state.history) == 10
) # Adjusted since history is a list and the controller is not running

# FIXME are we still testing this without this test?
# assert (
# len(
# get_pairs_from_events(state.history)
# )
# == 5
# )

assert stuck_detector.is_stuck() is False
assert stuck_detector.state.almost_stuck == 1
assert stuck_detector.is_stuck(ui_mode=False) is False

cmd_action_4 = CmdRunAction(command='ls')
cmd_action_4._id = 4
Expand All @@ -184,16 +216,9 @@ def test_is_stuck_repeating_action_observation(self, stuck_detector: StuckDetect
# 12 events

assert len(state.history) == 12
# assert (
# len(
# get_pairs_from_events(state.history)
# )
# == 6
# )

with patch('logging.Logger.warning') as mock_warning:
assert stuck_detector.is_stuck() is True
assert stuck_detector.state.almost_stuck == 0
assert stuck_detector.is_stuck(ui_mode=False) is True
mock_warning.assert_called_once_with('Action, Observation loop detected')

def test_is_stuck_repeating_action_error(self, stuck_detector: StuckDetector):
Expand Down Expand Up @@ -245,7 +270,7 @@ def test_is_stuck_repeating_action_error(self, stuck_detector: StuckDetector):
# 12 events

with patch('logging.Logger.warning') as mock_warning:
assert stuck_detector.is_stuck() is True
assert stuck_detector.is_stuck(ui_mode=False) is True
mock_warning.assert_called_once_with(
'Action, ErrorObservation loop detected'
)
Expand All @@ -259,7 +284,7 @@ def test_is_stuck_invalid_syntax_error(self, stuck_detector: StuckDetector):
)

with patch('logging.Logger.warning'):
assert stuck_detector.is_stuck() is True
assert stuck_detector.is_stuck(ui_mode=False) is True

def test_is_not_stuck_invalid_syntax_error_random_lines(
self, stuck_detector: StuckDetector
Expand All @@ -272,7 +297,7 @@ def test_is_not_stuck_invalid_syntax_error_random_lines(
)

with patch('logging.Logger.warning'):
assert stuck_detector.is_stuck() is False
assert stuck_detector.is_stuck(ui_mode=False) is False

def test_is_not_stuck_invalid_syntax_error_only_three_incidents(
self, stuck_detector: StuckDetector
Expand All @@ -286,7 +311,7 @@ def test_is_not_stuck_invalid_syntax_error_only_three_incidents(
)

with patch('logging.Logger.warning'):
assert stuck_detector.is_stuck() is False
assert stuck_detector.is_stuck(ui_mode=False) is False

def test_is_stuck_incomplete_input_error(self, stuck_detector: StuckDetector):
state = stuck_detector.state
Expand All @@ -297,7 +322,7 @@ def test_is_stuck_incomplete_input_error(self, stuck_detector: StuckDetector):
)

with patch('logging.Logger.warning'):
assert stuck_detector.is_stuck() is True
assert stuck_detector.is_stuck(ui_mode=False) is True

def test_is_not_stuck_incomplete_input_error(self, stuck_detector: StuckDetector):
state = stuck_detector.state
Expand All @@ -308,7 +333,7 @@ def test_is_not_stuck_incomplete_input_error(self, stuck_detector: StuckDetector
)

with patch('logging.Logger.warning'):
assert stuck_detector.is_stuck() is False
assert stuck_detector.is_stuck(ui_mode=False) is False

def test_is_not_stuck_ipython_unterminated_string_error_random_lines(
self, stuck_detector: StuckDetector
Expand All @@ -317,7 +342,7 @@ def test_is_not_stuck_ipython_unterminated_string_error_random_lines(
self._impl_unterminated_string_error_events(state, random_line=True)

with patch('logging.Logger.warning'):
assert stuck_detector.is_stuck() is False
assert stuck_detector.is_stuck(ui_mode=False) is False

def test_is_not_stuck_ipython_unterminated_string_error_only_three_incidents(
self, stuck_detector: StuckDetector
Expand All @@ -328,7 +353,7 @@ def test_is_not_stuck_ipython_unterminated_string_error_only_three_incidents(
)

with patch('logging.Logger.warning'):
assert stuck_detector.is_stuck() is False
assert stuck_detector.is_stuck(ui_mode=False) is False

def test_is_stuck_ipython_unterminated_string_error(
self, stuck_detector: StuckDetector
Expand All @@ -337,7 +362,7 @@ def test_is_stuck_ipython_unterminated_string_error(
self._impl_unterminated_string_error_events(state, random_line=False)

with patch('logging.Logger.warning'):
assert stuck_detector.is_stuck() is True
assert stuck_detector.is_stuck(ui_mode=False) is True

def test_is_not_stuck_ipython_syntax_error_not_at_end(
self, stuck_detector: StuckDetector
Expand Down Expand Up @@ -382,7 +407,7 @@ def test_is_not_stuck_ipython_syntax_error_not_at_end(
state.history.append(ipython_observation_4)

with patch('logging.Logger.warning') as mock_warning:
assert stuck_detector.is_stuck() is False
assert stuck_detector.is_stuck(ui_mode=False) is False
mock_warning.assert_not_called()

def test_is_stuck_repeating_action_observation_pattern(
Expand Down Expand Up @@ -451,7 +476,7 @@ def test_is_stuck_repeating_action_observation_pattern(
state.history.append(read_observation_3)

with patch('logging.Logger.warning') as mock_warning:
assert stuck_detector.is_stuck() is True
assert stuck_detector.is_stuck(ui_mode=False) is True
mock_warning.assert_called_once_with('Action, Observation pattern detected')

def test_is_stuck_not_stuck(self, stuck_detector: StuckDetector):
Expand Down Expand Up @@ -517,7 +542,7 @@ def test_is_stuck_not_stuck(self, stuck_detector: StuckDetector):
# read_observation_3._cause = read_action_3._id
state.history.append(read_observation_3)

assert stuck_detector.is_stuck() is False
assert stuck_detector.is_stuck(ui_mode=False) is False

def test_is_stuck_monologue(self, stuck_detector):
state = stuck_detector.state
Expand Down Expand Up @@ -547,7 +572,7 @@ def test_is_stuck_monologue(self, stuck_detector):
message_action_6._source = EventSource.AGENT
state.history.append(message_action_6)

assert stuck_detector.is_stuck()
assert stuck_detector.is_stuck(ui_mode=False)

# Add an observation event between the repeated message actions
cmd_output_observation = CmdOutputObservation(
Expand All @@ -567,7 +592,7 @@ def test_is_stuck_monologue(self, stuck_detector):
state.history.append(message_action_8)

with patch('logging.Logger.warning'):
assert not stuck_detector.is_stuck()
assert not stuck_detector.is_stuck(ui_mode=False)


class TestAgentController:
Expand All @@ -584,4 +609,4 @@ def controller(self):
def test_is_stuck_delegate_stuck(self, controller: AgentController):
controller.delegate = Mock()
controller.delegate._is_stuck.return_value = True
assert controller._is_stuck() is True
assert controller._is_stuck(ui_mode=False) is True
Loading