Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix duplicate state initialization #6089

Merged
merged 3 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions openhands/controller/agent_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -983,10 +983,12 @@ def _is_stuck(self) -> bool:

def __repr__(self):
    """Return a debug representation that is safe on partially built instances.

    Every attribute is read through ``getattr`` with a ``'<uninitialized>'``
    fallback, so ``repr()`` does not raise ``AttributeError`` when it is
    called before the attribute has been assigned.

    Returns:
        A single ``AgentController(...)`` string listing id, agent,
        event_stream, state, delegate and _pending_action.
    """
    # Only the getattr-based fragments are kept. Fragments that read
    # ``self.<attr>`` directly would raise on an instance that is missing the
    # attribute and would repeat every field in the output.
    return (
        f'AgentController(id={getattr(self, "id", "<uninitialized>")}, '
        f'agent={getattr(self, "agent", "<uninitialized>")!r}, '
        f'event_stream={getattr(self, "event_stream", "<uninitialized>")!r}, '
        f'state={getattr(self, "state", "<uninitialized>")!r}, '
        f'delegate={getattr(self, "delegate", "<uninitialized>")!r}, '
        f'_pending_action={getattr(self, "_pending_action", "<uninitialized>")!r})'
    )

def _is_awaiting_observation(self):
Expand Down
26 changes: 12 additions & 14 deletions openhands/server/session/agent_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,27 +268,25 @@ def _create_controller(
confirmation_mode=confirmation_mode,
headless_mode=False,
status_callback=self._status_callback,
initial_state=self._maybe_restore_state(),
)

# Note: We now attempt to restore the state from session here,
# but if it fails, we fall back to None and still initialize the controller
# with a fresh state. That way, the controller will always load events from the event stream
# even if the state file was corrupt.
return controller

def _maybe_restore_state(self) -> State | None:
    """Try to load a previously saved state for this session.

    Returns:
        The object returned by ``State.restore_from_session`` for
        ``self.sid`` and ``self.file_store``, or ``None`` when that call
        raised an exception.
    """
    restored_state = None

    # Heuristic: if the event stream already contains events we would expect
    # a saved state to exist, so a failed restore is only worth a warning in
    # that case. The broad ``except`` is deliberate best-effort handling: the
    # caller falls back to ``None`` rather than failing.
    try:
        restored_state = State.restore_from_session(self.sid, self.file_store)
        logger.debug(f'Restored state from session, sid: {self.sid}')
    except Exception as e:
        if self.event_stream.get_latest_event_id() > 0:
            # if we have events, we should have a state
            logger.warning(f'State could not be restored: {e}')
        else:
            logger.debug('No events found, no state to restore')
    return restored_state
186 changes: 186 additions & 0 deletions tests/unit/test_agent_session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from openhands.controller.agent import Agent
from openhands.controller.agent_controller import AgentController
from openhands.controller.state.state import State
from openhands.core.config import AppConfig, LLMConfig
from openhands.events import EventStream, EventStreamSubscriber
from openhands.llm import LLM
from openhands.llm.metrics import Metrics
from openhands.runtime.base import Runtime
from openhands.server.session.agent_session import AgentSession
from openhands.storage.memory import InMemoryFileStore


@pytest.fixture
def mock_agent():
    """Build a mock Agent whose nested llm, config and metrics mocks are wired up."""
    # LLM configuration values read through agent.llm.config
    llm_config = MagicMock(spec=LLMConfig)
    llm_config.configure_mock(
        model='test-model',
        base_url='http://test',
        draft_editor=None,
        max_message_chars=1000,
    )

    # The LLM mock exposes the config above plus a metrics mock
    llm = MagicMock(spec=LLM)
    llm.configure_mock(metrics=MagicMock(spec=Metrics), config=llm_config)

    # Finally the agent itself; ``name`` must be assigned directly because the
    # Mock constructor reserves the ``name`` keyword for the mock's own name.
    agent = MagicMock(spec=Agent)
    agent.llm = llm
    agent.name = 'test-agent'
    agent.sandbox_plugins = []
    return agent


@pytest.mark.asyncio
async def test_agent_session_start_with_no_state(mock_agent):
    """AgentSession.start() should initialise a fresh state when restoring fails."""

    # Spy subclass: records how set_initial_state is invoked, then delegates.
    class SpyAgentController(AgentController):
        set_initial_state_call_count = 0
        test_initial_state = None

        def set_initial_state(self, *args, state=None, **kwargs):
            self.test_initial_state = state
            self.set_initial_state_call_count = self.set_initial_state_call_count + 1
            super().set_initial_state(*args, state=state, **kwargs)

    # Session backed by an empty in-memory store
    session = AgentSession(sid='test-session', file_store=InMemoryFileStore({}))

    # Runtime creation is replaced by a coroutine that just attaches a mock
    runtime_stub = MagicMock(spec=Runtime)

    async def _attach_runtime(*args, **kwargs):
        session.runtime = runtime_stub

    session._create_runtime = AsyncMock(side_effect=_attach_runtime)

    # Event stream that reports no events at all
    stream = MagicMock(spec=EventStream)
    stream.get_events.return_value = []
    stream.subscribe = MagicMock()
    stream.get_latest_event_id.return_value = 0
    session.event_stream = stream

    controller_target = 'openhands.server.session.agent_session.AgentController'
    stream_target = 'openhands.server.session.agent_session.EventStream'
    restore_target = 'openhands.controller.state.state.State.restore_from_session'

    # Swap in the spy controller and mock stream, and make restoring raise
    with patch(controller_target, SpyAgentController):
        with patch(stream_target, return_value=stream):
            with patch(restore_target, side_effect=Exception('No state found')):
                await session.start(
                    runtime_name='test-runtime',
                    config=AppConfig(),
                    agent=mock_agent,
                    max_iterations=10,
                )

                controller = session.controller

                # The controller must have subscribed itself to the stream
                stream.subscribe.assert_called_with(
                    EventStreamSubscriber.AGENT_CONTROLLER,
                    controller.on_event,
                    controller.id,
                )

                # Exactly one initialisation, with no restored state
                assert controller.set_initial_state_call_count == 1
                assert controller.test_initial_state is None
                assert controller.state.max_iterations == 10
                assert controller.agent.name == 'test-agent'
                assert controller.state.start_id == 0
                assert controller.state.end_id == -1
                assert controller.state.truncation_id == -1


@pytest.mark.asyncio
async def test_agent_session_start_with_restored_state(mock_agent):
    """AgentSession.start() should hand a successfully restored state to the controller."""

    # Spy subclass: records how set_initial_state is invoked, then delegates.
    class SpyAgentController(AgentController):
        set_initial_state_call_count = 0
        test_initial_state = None

        def set_initial_state(self, *args, state=None, **kwargs):
            self.test_initial_state = state
            self.set_initial_state_call_count = self.set_initial_state_call_count + 1
            super().set_initial_state(*args, state=state, **kwargs)

    # Session backed by an empty in-memory store
    session = AgentSession(sid='test-session', file_store=InMemoryFileStore({}))

    # Runtime creation is replaced by a coroutine that just attaches a mock
    runtime_stub = MagicMock(spec=Runtime)

    async def _attach_runtime(*args, **kwargs):
        session.runtime = runtime_stub

    session._create_runtime = AsyncMock(side_effect=_attach_runtime)

    # Event stream whose latest event id signals that events already exist
    stream = MagicMock(spec=EventStream)
    stream.get_events.return_value = []
    stream.subscribe = MagicMock()
    stream.get_latest_event_id.return_value = 5
    session.event_stream = stream

    # The state object that restore_from_session will return
    saved_state = MagicMock(spec=State)
    saved_state.start_id = -1
    saved_state.end_id = -1
    saved_state.truncation_id = -1
    saved_state.max_iterations = 5

    controller_target = 'openhands.server.session.agent_session.AgentController'
    stream_target = 'openhands.server.session.agent_session.EventStream'
    restore_target = 'openhands.controller.state.state.State.restore_from_session'

    # Swap in the spy controller and mock stream, and make restoring succeed
    with patch(controller_target, SpyAgentController):
        with patch(stream_target, return_value=stream):
            with patch(restore_target, return_value=saved_state):
                await session.start(
                    runtime_name='test-runtime',
                    config=AppConfig(),
                    agent=mock_agent,
                    max_iterations=10,
                )

                controller = session.controller

                # Exactly one initialisation call
                assert controller.set_initial_state_call_count == 1

                # The controller must have subscribed itself to the stream
                stream.subscribe.assert_called_with(
                    EventStreamSubscriber.AGENT_CONTROLLER,
                    controller.on_event,
                    controller.id,
                )

                # The restored state object is the one the controller ends up using
                assert controller.test_initial_state is saved_state
                assert controller.state is saved_state
                assert controller.state.max_iterations == 5
                assert controller.state.start_id == 0
                assert controller.state.end_id == -1
                assert controller.state.truncation_id == -1
Loading