Skip to content

Commit

Permalink
add unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
enyst committed Jan 7, 2025
1 parent 1b07220 commit 87b1956
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 4 deletions.
10 changes: 6 additions & 4 deletions openhands/controller/agent_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -983,10 +983,12 @@ def _is_stuck(self) -> bool:

def __repr__(self):
return (
f'AgentController(id={self.id}, agent={self.agent!r}, '
f'event_stream={self.event_stream!r}, '
f'state={self.state!r}, '
f'delegate={self.delegate!r}, _pending_action={self._pending_action!r})'
f'AgentController(id={getattr(self, "id", "<uninitialized>")}, '
f'agent={getattr(self, "agent", "<uninitialized>")!r}, '
f'event_stream={getattr(self, "event_stream", "<uninitialized>")!r}, '
f'state={getattr(self, "state", "<uninitialized>")!r}, '
f'delegate={getattr(self, "delegate", "<uninitialized>")!r}, '
f'_pending_action={getattr(self, "_pending_action", "<uninitialized>")!r})'
)

def _is_awaiting_observation(self):
Expand Down
98 changes: 98 additions & 0 deletions tests/unit/test_agent_session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from openhands.controller.agent import Agent
from openhands.controller.agent_controller import AgentController
from openhands.core.config import AppConfig, LLMConfig
from openhands.events import EventStream, EventStreamSubscriber
from openhands.llm import LLM
from openhands.llm.metrics import Metrics
from openhands.runtime.base import Runtime
from openhands.server.session.agent_session import AgentSession
from openhands.storage.memory import InMemoryFileStore


@pytest.fixture
def mock_agent():
"""Create a properly configured mock agent with all required nested attributes"""
# Create the base mocks
agent = MagicMock(spec=Agent)
llm = MagicMock(spec=LLM)
metrics = MagicMock(spec=Metrics)
llm_config = MagicMock(spec=LLMConfig)

# Configure the LLM config
llm_config.model = 'test-model'
llm_config.base_url = 'http://test'
llm_config.draft_editor = None
llm_config.max_message_chars = 1000

# Set up the chain of mocks
llm.metrics = metrics
llm.config = llm_config
agent.llm = llm
agent.name = 'test-agent'
agent.sandbox_plugins = []

return agent


@pytest.mark.asyncio
async def test_agent_session_start_calls_set_initial_state_once(mock_agent):
"""Test that AgentSession.start() only calls set_initial_state once during controller initialization"""

# Setup
file_store = InMemoryFileStore({})
session = AgentSession(sid='test-session', file_store=file_store)

# Create a mock runtime and set it up
mock_runtime = MagicMock(spec=Runtime)

# Mock the runtime creation to set up the runtime attribute
async def mock_create_runtime(*args, **kwargs):
session.runtime = mock_runtime

session._create_runtime = AsyncMock(side_effect=mock_create_runtime)

# Create a mock EventStream
mock_event_stream = MagicMock(spec=EventStream)
mock_event_stream.get_events.return_value = []
mock_event_stream.get_latest_event_id.return_value = 0
mock_event_stream.subscribe = MagicMock()

# Inject the mock event stream into the session
session.event_stream = mock_event_stream

# Create a spy on set_initial_state by subclassing AgentController
class SpyAgentController(AgentController):
set_initial_state_call_count = 0

def set_initial_state(self, *args, **kwargs):
self.set_initial_state_call_count += 1
super().set_initial_state(*args, **kwargs)

# Patch AgentController with our spy version and inject our mock event stream
with patch(
'openhands.server.session.agent_session.AgentController', SpyAgentController
), patch(
'openhands.server.session.agent_session.EventStream',
return_value=mock_event_stream,
):
# Start the session
await session.start(
runtime_name='test-runtime',
config=AppConfig(),
agent=mock_agent,
max_iterations=10,
)

# Verify set_initial_state was called exactly once
assert session.controller.set_initial_state_call_count == 1

# Verify EventStream.subscribe was called with correct parameters
mock_event_stream.subscribe.assert_called_with(
EventStreamSubscriber.AGENT_CONTROLLER,
session.controller.on_event,
session.controller.id,
)

0 comments on commit 87b1956

Please sign in to comment.