Skip to content

Commit

Permalink
fix: improve stuck detection in UI mode
Browse files Browse the repository at this point in the history
- Add UI mode awareness to stuck detection
- Only consider history after last user message in UI mode
- Keep existing behavior in headless mode
- Add comprehensive tests for both modes

Fix: #5480
  • Loading branch information
openhands-agent committed Dec 14, 2024
1 parent 212787c commit 1e739bd
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 28 deletions.
9 changes: 6 additions & 3 deletions openhands/controller/agent_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -902,17 +902,20 @@ def _apply_conversation_window(self, events: list[Event]) -> list[Event]:

return kept_events

def _is_stuck(self) -> bool:
def _is_stuck(self, ui_mode: bool | None = None) -> bool:
"""Checks if the agent or its delegate is stuck in a loop.
Args:
ui_mode: Optional override for UI mode. If not provided, uses not self.headless_mode.
Returns:
bool: True if the agent is stuck, False otherwise.
"""
# check if delegate stuck
if self.delegate and self.delegate._is_stuck():
if self.delegate and self.delegate._is_stuck(ui_mode):
return True

return self._stuck_detector.is_stuck()
return self._stuck_detector.is_stuck(ui_mode if ui_mode is not None else not self.headless_mode)

def __repr__(self):
return (
Expand Down
22 changes: 16 additions & 6 deletions openhands/controller/stuck.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,26 @@ class StuckDetector:
def __init__(self, state: State):
self.state = state

def is_stuck(self):
# filter out MessageAction with source='user' from history
def is_stuck(self, ui_mode: bool = False):
if ui_mode:
# In UI mode, only look at history after the last user message
last_user_msg_idx = -1
for i, event in enumerate(self.state.history):
if isinstance(event, MessageAction) and event.source == EventSource.USER:
last_user_msg_idx = i

history_to_check = self.state.history[last_user_msg_idx + 1:]
else:
# In headless mode, look at all history
history_to_check = self.state.history

# Filter out user messages and null events
filtered_history = [
event
for event in self.state.history
for event in history_to_check
if not (
(isinstance(event, MessageAction) and event.source == EventSource.USER)
or
# there might be some NullAction or NullObservation in the history at least for now
isinstance(event, (NullAction, NullObservation))
or isinstance(event, (NullAction, NullObservation))
)
]

Expand Down
84 changes: 65 additions & 19 deletions tests/unit/test_is_stuck.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,53 @@ def test_history_too_short(self, stuck_detector: StuckDetector):
# cmd_observation._cause = cmd_action._id
state.history.append(cmd_observation)

assert stuck_detector.is_stuck() is False
assert stuck_detector.is_stuck(ui_mode=False) is False
assert stuck_detector.is_stuck(ui_mode=True) is False

def test_ui_mode_resets_after_user_message(self, stuck_detector: StuckDetector):
    """In UI mode, only events after the latest user message count as a loop."""
    state = stuck_detector.state

    def append_ls_rounds(first_id: int, rounds: int) -> None:
        # Append `rounds` identical `ls` action/observation pairs, using
        # consecutive ids starting at `first_id`.
        for event_id in range(first_id, first_id + rounds):
            ls_action = CmdRunAction(command='ls')
            ls_action._id = event_id
            state.history.append(ls_action)
            ls_output = CmdOutputObservation(content='', command='ls', command_id=event_id)
            ls_output._cause = ls_action._id
            state.history.append(ls_output)

    # Four identical rounds: the headless (non-UI) check reports a loop.
    append_ls_rounds(0, 4)
    assert stuck_detector.is_stuck(ui_mode=False) is True

    # A user message arrives after the repeated rounds.
    user_message = MessageAction(content='Hello', wait_for_response=False)
    user_message._source = EventSource.USER
    state.history.append(user_message)

    # UI mode ignores everything before that message, so nothing is left to flag.
    assert stuck_detector.is_stuck(ui_mode=True) is False

    # Two identical rounds after the message are still too few to count as a loop.
    append_ls_rounds(4, 2)
    assert stuck_detector.is_stuck(ui_mode=True) is False

    # Two more (four since the user message) and the loop is detected again.
    append_ls_rounds(6, 2)
    assert stuck_detector.is_stuck(ui_mode=True) is True

def test_is_stuck_repeating_action_observation(self, stuck_detector: StuckDetector):
state = stuck_detector.state
Expand Down Expand Up @@ -148,7 +194,7 @@ def test_is_stuck_repeating_action_observation(self, stuck_detector: StuckDetect
state.history.append(message_null_observation)
# 8 events

assert stuck_detector.is_stuck() is False
assert stuck_detector.is_stuck(ui_mode=False) is False

cmd_action_3 = CmdRunAction(command='ls')
cmd_action_3._id = 3
Expand All @@ -159,7 +205,7 @@ def test_is_stuck_repeating_action_observation(self, stuck_detector: StuckDetect
# 10 events

assert len(state.history) == 10
assert stuck_detector.is_stuck() is False
assert stuck_detector.is_stuck(ui_mode=False) is False

cmd_action_4 = CmdRunAction(command='ls')
cmd_action_4._id = 4
Expand All @@ -172,7 +218,7 @@ def test_is_stuck_repeating_action_observation(self, stuck_detector: StuckDetect
assert len(state.history) == 12

with patch('logging.Logger.warning') as mock_warning:
assert stuck_detector.is_stuck() is True
assert stuck_detector.is_stuck(ui_mode=False) is True
mock_warning.assert_called_once_with('Action, Observation loop detected')

def test_is_stuck_repeating_action_error(self, stuck_detector: StuckDetector):
Expand Down Expand Up @@ -224,7 +270,7 @@ def test_is_stuck_repeating_action_error(self, stuck_detector: StuckDetector):
# 12 events

with patch('logging.Logger.warning') as mock_warning:
assert stuck_detector.is_stuck() is True
assert stuck_detector.is_stuck(ui_mode=False) is True
mock_warning.assert_called_once_with(
'Action, ErrorObservation loop detected'
)
Expand All @@ -238,7 +284,7 @@ def test_is_stuck_invalid_syntax_error(self, stuck_detector: StuckDetector):
)

with patch('logging.Logger.warning'):
assert stuck_detector.is_stuck() is True
assert stuck_detector.is_stuck(ui_mode=False) is True

def test_is_not_stuck_invalid_syntax_error_random_lines(
self, stuck_detector: StuckDetector
Expand All @@ -251,7 +297,7 @@ def test_is_not_stuck_invalid_syntax_error_random_lines(
)

with patch('logging.Logger.warning'):
assert stuck_detector.is_stuck() is False
assert stuck_detector.is_stuck(ui_mode=False) is False

def test_is_not_stuck_invalid_syntax_error_only_three_incidents(
self, stuck_detector: StuckDetector
Expand All @@ -265,7 +311,7 @@ def test_is_not_stuck_invalid_syntax_error_only_three_incidents(
)

with patch('logging.Logger.warning'):
assert stuck_detector.is_stuck() is False
assert stuck_detector.is_stuck(ui_mode=False) is False

def test_is_stuck_incomplete_input_error(self, stuck_detector: StuckDetector):
state = stuck_detector.state
Expand All @@ -276,7 +322,7 @@ def test_is_stuck_incomplete_input_error(self, stuck_detector: StuckDetector):
)

with patch('logging.Logger.warning'):
assert stuck_detector.is_stuck() is True
assert stuck_detector.is_stuck(ui_mode=False) is True

def test_is_not_stuck_incomplete_input_error(self, stuck_detector: StuckDetector):
state = stuck_detector.state
Expand All @@ -287,7 +333,7 @@ def test_is_not_stuck_incomplete_input_error(self, stuck_detector: StuckDetector
)

with patch('logging.Logger.warning'):
assert stuck_detector.is_stuck() is False
assert stuck_detector.is_stuck(ui_mode=False) is False

def test_is_not_stuck_ipython_unterminated_string_error_random_lines(
self, stuck_detector: StuckDetector
Expand All @@ -296,7 +342,7 @@ def test_is_not_stuck_ipython_unterminated_string_error_random_lines(
self._impl_unterminated_string_error_events(state, random_line=True)

with patch('logging.Logger.warning'):
assert stuck_detector.is_stuck() is False
assert stuck_detector.is_stuck(ui_mode=False) is False

def test_is_not_stuck_ipython_unterminated_string_error_only_three_incidents(
self, stuck_detector: StuckDetector
Expand All @@ -307,7 +353,7 @@ def test_is_not_stuck_ipython_unterminated_string_error_only_three_incidents(
)

with patch('logging.Logger.warning'):
assert stuck_detector.is_stuck() is False
assert stuck_detector.is_stuck(ui_mode=False) is False

def test_is_stuck_ipython_unterminated_string_error(
self, stuck_detector: StuckDetector
Expand All @@ -316,7 +362,7 @@ def test_is_stuck_ipython_unterminated_string_error(
self._impl_unterminated_string_error_events(state, random_line=False)

with patch('logging.Logger.warning'):
assert stuck_detector.is_stuck() is True
assert stuck_detector.is_stuck(ui_mode=False) is True

def test_is_not_stuck_ipython_syntax_error_not_at_end(
self, stuck_detector: StuckDetector
Expand Down Expand Up @@ -361,7 +407,7 @@ def test_is_not_stuck_ipython_syntax_error_not_at_end(
state.history.append(ipython_observation_4)

with patch('logging.Logger.warning') as mock_warning:
assert stuck_detector.is_stuck() is False
assert stuck_detector.is_stuck(ui_mode=False) is False
mock_warning.assert_not_called()

def test_is_stuck_repeating_action_observation_pattern(
Expand Down Expand Up @@ -430,7 +476,7 @@ def test_is_stuck_repeating_action_observation_pattern(
state.history.append(read_observation_3)

with patch('logging.Logger.warning') as mock_warning:
assert stuck_detector.is_stuck() is True
assert stuck_detector.is_stuck(ui_mode=False) is True
mock_warning.assert_called_once_with('Action, Observation pattern detected')

def test_is_stuck_not_stuck(self, stuck_detector: StuckDetector):
Expand Down Expand Up @@ -496,7 +542,7 @@ def test_is_stuck_not_stuck(self, stuck_detector: StuckDetector):
# read_observation_3._cause = read_action_3._id
state.history.append(read_observation_3)

assert stuck_detector.is_stuck() is False
assert stuck_detector.is_stuck(ui_mode=False) is False

def test_is_stuck_monologue(self, stuck_detector):
state = stuck_detector.state
Expand Down Expand Up @@ -526,7 +572,7 @@ def test_is_stuck_monologue(self, stuck_detector):
message_action_6._source = EventSource.AGENT
state.history.append(message_action_6)

assert stuck_detector.is_stuck()
assert stuck_detector.is_stuck(ui_mode=False)

# Add an observation event between the repeated message actions
cmd_output_observation = CmdOutputObservation(
Expand All @@ -546,7 +592,7 @@ def test_is_stuck_monologue(self, stuck_detector):
state.history.append(message_action_8)

with patch('logging.Logger.warning'):
assert not stuck_detector.is_stuck()
assert not stuck_detector.is_stuck(ui_mode=False)


class TestAgentController:
Expand All @@ -563,4 +609,4 @@ def controller(self):
def test_is_stuck_delegate_stuck(self, controller: AgentController):
    """The controller reports stuck whenever its delegate reports stuck."""
    # Stand-in delegate whose _is_stuck() always answers True.
    stuck_delegate = Mock()
    stuck_delegate._is_stuck.return_value = True
    controller.delegate = stuck_delegate

    # Both call forms must surface the delegate's answer: the default
    # (no override) and an explicit ui_mode=False.
    assert controller._is_stuck() is True
    assert controller._is_stuck(ui_mode=False) is True

0 comments on commit 1e739bd

Please sign in to comment.