From a89c2c02243f7431063651010f3b283307a8e59a Mon Sep 17 00:00:00 2001
From: Ray Myers
Date: Thu, 30 Jan 2025 14:43:07 -0600
Subject: [PATCH] Add retry notification for main agent LLM

---
 openhands/server/session/session.py | 17 +++++++++++++++-
 tests/unit/test_session.py          | 69 ++++++++++++++++++++++++++++++++++++
 2 files changed, 85 insertions(+), 1 deletion(-)
 create mode 100644 tests/unit/test_session.py

diff --git a/openhands/server/session/session.py b/openhands/server/session/session.py
index a7a16a4fa6ec..456429c40da8 100644
--- a/openhands/server/session/session.py
+++ b/openhands/server/session/session.py
@@ -104,7 +104,7 @@ async def initialize_agent(
 
         # TODO: override other LLM config & agent config groups (#2075)
-        llm = LLM(config=self.config.get_llm_config_from_agent(agent_cls))
+        llm = self._create_llm(agent_cls)
         agent_config = self.config.get_agent_config(agent_cls)
 
         if settings.enable_default_condenser:
@@ -142,6 +142,21 @@ async def initialize_agent(
             )
             return
 
+    def _create_llm(self, agent_cls: str | None) -> LLM:
+        """
+        Initialize LLM, extracted for testing.
+        """
+        return LLM(
+            config=self.config.get_llm_config_from_agent(agent_cls),
+            retry_listener=self._notify_on_llm_retry,
+        )
+
+    def _notify_on_llm_retry(self, retries: int, max: int) -> None:
+        msg_id = 'STATUS$LLM_RETRY'
+        self.queue_status_message(
+            'info', msg_id, f'Retrying LLM request, {retries} / {max}'
+        )
+
     def on_event(self, event: Event):
         asyncio.get_event_loop().run_until_complete(self._on_event(event))
 
diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py
new file mode 100644
index 000000000000..7f61e66f01d0
--- /dev/null
+++ b/tests/unit/test_session.py
@@ -0,0 +1,69 @@
+from unittest.mock import ANY, AsyncMock, patch
+
+import pytest
+from litellm.exceptions import (
+    RateLimitError,
+)
+
+from openhands.core.config.app_config import AppConfig
+from openhands.core.config.llm_config import LLMConfig
+from openhands.server.session.session import Session
+from openhands.storage.memory import InMemoryFileStore
+
+
+@pytest.fixture
+def mock_status_callback():
+    return AsyncMock()
+
+
+@pytest.fixture
+def mock_sio():
+    return AsyncMock()
+
+
+@pytest.fixture
+def default_llm_config():
+    return LLMConfig(
+        model='gpt-4o',
+        api_key='test_key',
+        num_retries=2,
+        retry_min_wait=1,
+        retry_max_wait=2,
+    )
+
+
+@pytest.mark.asyncio
+@patch('openhands.llm.llm.litellm_completion')
+async def test_notify_on_llm_retry(
+    mock_litellm_completion, mock_sio, default_llm_config
+):
+    config = AppConfig()
+    config.set_llm_config(default_llm_config)
+    session = Session(
+        sid='..sid..',
+        file_store=InMemoryFileStore({}),
+        config=config,
+        sio=mock_sio,
+        user_id='..uid..',
+    )
+    session.queue_status_message = AsyncMock()
+
+    with patch('time.sleep') as _mock_sleep:
+        mock_litellm_completion.side_effect = [
+            RateLimitError(
+                'Rate limit exceeded', llm_provider='test_provider', model='test_model'
+            ),
+            {'choices': [{'message': {'content': 'Retry successful'}}]},
+        ]
+        llm = session._create_llm('..cls..')
+
+        llm.completion(
+            messages=[{'role': 'user', 'content': 'Hello!'}],
+            stream=False,
+        )
+
+        assert mock_litellm_completion.call_count == 2
+    session.queue_status_message.assert_called_once_with(
+        'info', 'STATUS$LLM_RETRY', ANY
+    )
+    await session.close()
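
The patch hinges on one contract: LLM(config=..., retry_listener=...) stores a
callback that the retry machinery invokes on each retry, passing the current
attempt number and the configured maximum (the (retries, max) shape seen in
Session._notify_on_llm_retry). A minimal sketch of that contract for
orientation; completion_with_retries and its parameters below are illustrative
assumptions, not the actual internals of openhands/llm:

    import time
    from typing import Any, Callable, Optional

    def completion_with_retries(
        call: Callable[[], Any],
        num_retries: int,
        retry_listener: Optional[Callable[[int, int], None]] = None,
    ) -> Any:
        """Run `call`, retrying on failure and notifying a listener per retry."""
        for attempt in range(1, num_retries + 1):
            try:
                return call()
            except Exception:
                if attempt == num_retries:
                    raise  # retries exhausted; propagate the last error
                if retry_listener is not None:
                    # Same (retries, max) shape as Session._notify_on_llm_retry
                    retry_listener(attempt, num_retries)
                time.sleep(1)  # placeholder; the real wrapper applies backoff

Under this contract the test reads cleanly: litellm_completion fails once with
RateLimitError, so the listener fires exactly once before the successful second
call, which is what queue_status_message.assert_called_once_with('info',
'STATUS$LLM_RETRY', ANY) verifies.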