diff --git a/openhands/controller/agent_controller.py b/openhands/controller/agent_controller.py index d4fd9178b390..f2682a3da16a 100644 --- a/openhands/controller/agent_controller.py +++ b/openhands/controller/agent_controller.py @@ -983,10 +983,12 @@ def _is_stuck(self) -> bool: def __repr__(self): return ( - f'AgentController(id={self.id}, agent={self.agent!r}, ' - f'event_stream={self.event_stream!r}, ' - f'state={self.state!r}, ' - f'delegate={self.delegate!r}, _pending_action={self._pending_action!r})' + f'AgentController(id={getattr(self, "id", "")}, ' + f'agent={getattr(self, "agent", "")!r}, ' + f'event_stream={getattr(self, "event_stream", "")!r}, ' + f'state={getattr(self, "state", "")!r}, ' + f'delegate={getattr(self, "delegate", "")!r}, ' + f'_pending_action={getattr(self, "_pending_action", "")!r})' ) def _is_awaiting_observation(self): diff --git a/tests/unit/test_agent_session.py b/tests/unit/test_agent_session.py new file mode 100644 index 000000000000..2619c7949804 --- /dev/null +++ b/tests/unit/test_agent_session.py @@ -0,0 +1,98 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from openhands.controller.agent import Agent +from openhands.controller.agent_controller import AgentController +from openhands.core.config import AppConfig, LLMConfig +from openhands.events import EventStream, EventStreamSubscriber +from openhands.llm import LLM +from openhands.llm.metrics import Metrics +from openhands.runtime.base import Runtime +from openhands.server.session.agent_session import AgentSession +from openhands.storage.memory import InMemoryFileStore + + +@pytest.fixture +def mock_agent(): + """Create a properly configured mock agent with all required nested attributes""" + # Create the base mocks + agent = MagicMock(spec=Agent) + llm = MagicMock(spec=LLM) + metrics = MagicMock(spec=Metrics) + llm_config = MagicMock(spec=LLMConfig) + + # Configure the LLM config + llm_config.model = 'test-model' + llm_config.base_url = 'http://test' + llm_config.draft_editor = None + llm_config.max_message_chars = 1000 + + # Set up the chain of mocks + llm.metrics = metrics + llm.config = llm_config + agent.llm = llm + agent.name = 'test-agent' + agent.sandbox_plugins = [] + + return agent + + +@pytest.mark.asyncio +async def test_agent_session_start_calls_set_initial_state_once(mock_agent): + """Test that AgentSession.start() only calls set_initial_state once during controller initialization""" + + # Setup + file_store = InMemoryFileStore({}) + session = AgentSession(sid='test-session', file_store=file_store) + + # Create a mock runtime and set it up + mock_runtime = MagicMock(spec=Runtime) + + # Mock the runtime creation to set up the runtime attribute + async def mock_create_runtime(*args, **kwargs): + session.runtime = mock_runtime + + session._create_runtime = AsyncMock(side_effect=mock_create_runtime) + + # Create a mock EventStream + mock_event_stream = MagicMock(spec=EventStream) + mock_event_stream.get_events.return_value = [] + mock_event_stream.get_latest_event_id.return_value = 0 + mock_event_stream.subscribe = MagicMock() + + # Inject the mock event stream into the session + session.event_stream = mock_event_stream + + # Create a spy on set_initial_state by subclassing AgentController + class SpyAgentController(AgentController): + set_initial_state_call_count = 0 + + def set_initial_state(self, *args, **kwargs): + self.set_initial_state_call_count += 1 + super().set_initial_state(*args, **kwargs) + + # Patch AgentController with our spy version and inject our mock event stream + with patch( + 'openhands.server.session.agent_session.AgentController', SpyAgentController + ), patch( + 'openhands.server.session.agent_session.EventStream', + return_value=mock_event_stream, + ): + # Start the session + await session.start( + runtime_name='test-runtime', + config=AppConfig(), + agent=mock_agent, + max_iterations=10, + ) + + # Verify set_initial_state was called exactly once + assert session.controller.set_initial_state_call_count == 1 + + # Verify EventStream.subscribe was called with correct parameters + mock_event_stream.subscribe.assert_called_with( + EventStreamSubscriber.AGENT_CONTROLLER, + session.controller.on_event, + session.controller.id, + )