-
Notifications
You must be signed in to change notification settings - Fork 5.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
104 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
from unittest.mock import AsyncMock, MagicMock, patch | ||
|
||
import pytest | ||
|
||
from openhands.controller.agent import Agent | ||
from openhands.controller.agent_controller import AgentController | ||
from openhands.core.config import AppConfig, LLMConfig | ||
from openhands.events import EventStream, EventStreamSubscriber | ||
from openhands.llm import LLM | ||
from openhands.llm.metrics import Metrics | ||
from openhands.runtime.base import Runtime | ||
from openhands.server.session.agent_session import AgentSession | ||
from openhands.storage.memory import InMemoryFileStore | ||
|
||
|
||
@pytest.fixture | ||
def mock_agent(): | ||
"""Create a properly configured mock agent with all required nested attributes""" | ||
# Create the base mocks | ||
agent = MagicMock(spec=Agent) | ||
llm = MagicMock(spec=LLM) | ||
metrics = MagicMock(spec=Metrics) | ||
llm_config = MagicMock(spec=LLMConfig) | ||
|
||
# Configure the LLM config | ||
llm_config.model = 'test-model' | ||
llm_config.base_url = 'http://test' | ||
llm_config.draft_editor = None | ||
llm_config.max_message_chars = 1000 | ||
|
||
# Set up the chain of mocks | ||
llm.metrics = metrics | ||
llm.config = llm_config | ||
agent.llm = llm | ||
agent.name = 'test-agent' | ||
agent.sandbox_plugins = [] | ||
|
||
return agent | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_agent_session_start_calls_set_initial_state_once(mock_agent): | ||
"""Test that AgentSession.start() only calls set_initial_state once during controller initialization""" | ||
|
||
# Setup | ||
file_store = InMemoryFileStore({}) | ||
session = AgentSession(sid='test-session', file_store=file_store) | ||
|
||
# Create a mock runtime and set it up | ||
mock_runtime = MagicMock(spec=Runtime) | ||
|
||
# Mock the runtime creation to set up the runtime attribute | ||
async def mock_create_runtime(*args, **kwargs): | ||
session.runtime = mock_runtime | ||
|
||
session._create_runtime = AsyncMock(side_effect=mock_create_runtime) | ||
|
||
# Create a mock EventStream | ||
mock_event_stream = MagicMock(spec=EventStream) | ||
mock_event_stream.get_events.return_value = [] | ||
mock_event_stream.get_latest_event_id.return_value = 0 | ||
mock_event_stream.subscribe = MagicMock() | ||
|
||
# Inject the mock event stream into the session | ||
session.event_stream = mock_event_stream | ||
|
||
# Create a spy on set_initial_state by subclassing AgentController | ||
class SpyAgentController(AgentController): | ||
set_initial_state_call_count = 0 | ||
|
||
def set_initial_state(self, *args, **kwargs): | ||
self.set_initial_state_call_count += 1 | ||
super().set_initial_state(*args, **kwargs) | ||
|
||
# Patch AgentController with our spy version and inject our mock event stream | ||
with patch( | ||
'openhands.server.session.agent_session.AgentController', SpyAgentController | ||
), patch( | ||
'openhands.server.session.agent_session.EventStream', | ||
return_value=mock_event_stream, | ||
): | ||
# Start the session | ||
await session.start( | ||
runtime_name='test-runtime', | ||
config=AppConfig(), | ||
agent=mock_agent, | ||
max_iterations=10, | ||
) | ||
|
||
# Verify set_initial_state was called exactly once | ||
assert session.controller.set_initial_state_call_count == 1 | ||
|
||
# Verify EventStream.subscribe was called with correct parameters | ||
mock_event_stream.subscribe.assert_called_with( | ||
EventStreamSubscriber.AGENT_CONTROLLER, | ||
session.controller.on_event, | ||
session.controller.id, | ||
) |