From fd73f4210ef38f678e328a9976053386e99cb30d Mon Sep 17 00:00:00 2001
From: Ray Myers
Date: Thu, 30 Jan 2025 15:51:47 -0600
Subject: [PATCH] Show LLM retries and allow resume from rate-limit state
 (#6438)

Co-authored-by: Engel Nyst
---
 .../features/chat/chat-interface.tsx     |  3 +-
 frontend/src/i18n/translation.json       |  3 +
 openhands/controller/agent_controller.py | 15 +++-
 openhands/llm/llm.py                     |  6 +-
 openhands/llm/retry_mixin.py             |  8 ++-
 openhands/server/session/session.py      | 18 ++++-
 tests/unit/test_agent_controller.py      | 18 ++++-
 tests/unit/test_session.py               | 69 +++++++++++++++++++
 8 files changed, 129 insertions(+), 11 deletions(-)
 create mode 100644 tests/unit/test_session.py

diff --git a/frontend/src/components/features/chat/chat-interface.tsx b/frontend/src/components/features/chat/chat-interface.tsx
index bb49411b3302..036ead80428f 100644
--- a/frontend/src/components/features/chat/chat-interface.tsx
+++ b/frontend/src/components/features/chat/chat-interface.tsx
@@ -180,8 +180,7 @@ export function ChatInterface() {
         onStop={handleStop}
         isDisabled={
           curAgentState === AgentState.LOADING ||
-          curAgentState === AgentState.AWAITING_USER_CONFIRMATION ||
-          curAgentState === AgentState.RATE_LIMITED
+          curAgentState === AgentState.AWAITING_USER_CONFIRMATION
         }
         mode={curAgentState === AgentState.RUNNING ? "stop" : "submit"}
         value={messageToSend ?? undefined}
diff --git a/frontend/src/i18n/translation.json b/frontend/src/i18n/translation.json
index 615e6f0ef1fc..a2a75b6f607b 100644
--- a/frontend/src/i18n/translation.json
+++ b/frontend/src/i18n/translation.json
@@ -3816,6 +3816,9 @@
     "es": "Hubo un error al conectar con el entorno de ejecución. Por favor, actualice la página.",
     "tr": "Çalışma zamanına bağlanırken bir hata oluştu. Lütfen sayfayı yenileyin."
   },
+  "STATUS$LLM_RETRY": {
+    "en": "Retrying LLM request"
+  },
   "AGENT_ERROR$BAD_ACTION": {
     "en": "Agent tried to execute a malformed action.",
     "zh-CN": "错误的操作",
diff --git a/openhands/controller/agent_controller.py b/openhands/controller/agent_controller.py
index 53bb44900e24..a2b70078fce0 100644
--- a/openhands/controller/agent_controller.py
+++ b/openhands/controller/agent_controller.py
@@ -235,8 +235,10 @@ async def _step_with_exception_handling(self):
                 f'report this error to the developers. Your session ID is {self.id}. '
                 f'Error type: {e.__class__.__name__}'
             )
-            if isinstance(e, litellm.AuthenticationError) or isinstance(
-                e, litellm.BadRequestError
+            if (
+                isinstance(e, litellm.AuthenticationError)
+                or isinstance(e, litellm.BadRequestError)
+                or isinstance(e, RateLimitError)
             ):
                 reported = e
             await self._react_to_exception(reported)
@@ -530,7 +532,7 @@ async def start_delegate(self, action: AgentDelegateAction) -> None:
         agent_cls: Type[Agent] = Agent.get_cls(action.agent)
         agent_config = self.agent_configs.get(action.agent, self.agent.config)
         llm_config = self.agent_to_llm_config.get(action.agent, self.agent.llm.config)
-        llm = LLM(config=llm_config)
+        llm = LLM(config=llm_config, retry_listener=self._notify_on_llm_retry)
         delegate_agent = agent_cls(llm=llm, config=agent_config)
         state = State(
             inputs=action.inputs or {},
@@ -725,6 +727,13 @@ async def _step(self) -> None:
         log_level = 'info' if LOG_ALL_EVENTS else 'debug'
         self.log(log_level, str(action), extra={'msg_type': 'ACTION'})
 
+    def _notify_on_llm_retry(self, retries: int, max: int) -> None:
+        if self.status_callback is not None:
+            msg_id = 'STATUS$LLM_RETRY'
+            self.status_callback(
+                'info', msg_id, f'Retrying LLM request, {retries} / {max}'
+            )
+
     async def _handle_traffic_control(
         self, limit_type: str, current_value: float, max_value: float
     ) -> bool:
diff --git a/openhands/llm/llm.py b/openhands/llm/llm.py
index 98bcf7cb173d..af25baded4c4 100644
--- a/openhands/llm/llm.py
+++ b/openhands/llm/llm.py
@@ -3,7 +3,7 @@
 import time
 import warnings
 from functools import partial
-from typing import Any
+from typing import Any, Callable
 
 import requests
 
@@ -94,6 +94,7 @@ def __init__(
         self,
         config: LLMConfig,
         metrics: Metrics | None = None,
+        retry_listener: Callable[[int, int], None] | None = None,
     ):
         """Initializes the LLM. If LLMConfig is passed, its values will be the fallback.
 
@@ -111,7 +112,7 @@ def __init__(
 
         self.config: LLMConfig = copy.deepcopy(config)
         self.model_info: ModelInfo | None = None
-
+        self.retry_listener = retry_listener
         if self.config.log_completions:
             if self.config.log_completions_folder is None:
                 raise RuntimeError(
@@ -168,6 +169,7 @@ def __init__(
             retry_min_wait=self.config.retry_min_wait,
             retry_max_wait=self.config.retry_max_wait,
             retry_multiplier=self.config.retry_multiplier,
+            retry_listener=self.retry_listener,
         )
         def wrapper(*args, **kwargs):
             """Wrapper for the litellm completion function. Logs the input and output of the completion function."""
diff --git a/openhands/llm/retry_mixin.py b/openhands/llm/retry_mixin.py
index 2942c913a268..714153e4c1a1 100644
--- a/openhands/llm/retry_mixin.py
+++ b/openhands/llm/retry_mixin.py
@@ -28,9 +28,15 @@ def retry_decorator(self, **kwargs):
         retry_min_wait = kwargs.get('retry_min_wait')
         retry_max_wait = kwargs.get('retry_max_wait')
         retry_multiplier = kwargs.get('retry_multiplier')
+        retry_listener = kwargs.get('retry_listener')
+
+        def before_sleep(retry_state):
+            self.log_retry_attempt(retry_state)
+            if retry_listener:
+                retry_listener(retry_state.attempt_number, num_retries)
 
         return retry(
-            before_sleep=self.log_retry_attempt,
+            before_sleep=before_sleep,
             stop=stop_after_attempt(num_retries) | stop_if_should_exit(),
             reraise=True,
             retry=(retry_if_exception_type(retry_exceptions)),
diff --git a/openhands/server/session/session.py b/openhands/server/session/session.py
index a7a16a4fa6ec..dd1ab777aa4e 100644
--- a/openhands/server/session/session.py
+++ b/openhands/server/session/session.py
@@ -104,7 +104,7 @@ async def initialize_agent(
 
         # TODO: override other LLM config & agent config groups (#2075)
 
-        llm = LLM(config=self.config.get_llm_config_from_agent(agent_cls))
+        llm = self._create_llm(agent_cls)
         agent_config = self.config.get_agent_config(agent_cls)
 
         if settings.enable_default_condenser:
@@ -142,6 +142,21 @@ async def initialize_agent(
             )
             return
 
+    def _create_llm(self, agent_cls: str | None) -> LLM:
+        """
+        Initialize LLM, extracted for testing.
+        """
+        return LLM(
+            config=self.config.get_llm_config_from_agent(agent_cls),
+            retry_listener=self._notify_on_llm_retry,
+        )
+
+    def _notify_on_llm_retry(self, retries: int, max: int) -> None:
+        msg_id = 'STATUS$LLM_RETRY'
+        self.queue_status_message(
+            'info', msg_id, f'Retrying LLM request, {retries} / {max}'
+        )
+
     def on_event(self, event: Event):
         asyncio.get_event_loop().run_until_complete(self._on_event(event))
 
@@ -220,7 +235,6 @@ async def _send_status_message(self, msg_type: str, id: str, message: str):
         """Sends a status message to the client."""
         if msg_type == 'error':
             await self.agent_session.stop_agent_loop_for_error()
-
         await self.send(
             {'status_update': True, 'type': msg_type, 'id': id, 'message': message}
         )
diff --git a/tests/unit/test_agent_controller.py b/tests/unit/test_agent_controller.py
index 3579247572bf..3c5ca97df14b 100644
--- a/tests/unit/test_agent_controller.py
+++ b/tests/unit/test_agent_controller.py
@@ -1,5 +1,5 @@
 import asyncio
-from unittest.mock import AsyncMock, MagicMock
+from unittest.mock import ANY, AsyncMock, MagicMock
 from uuid import uuid4
 
 import pytest
@@ -564,6 +564,22 @@ def on_event(event: Event):
     ), f'Expected accumulated cost to be 30.0, but got {state.metrics.accumulated_cost}'
 
 
+@pytest.mark.asyncio
+async def test_notify_on_llm_retry(mock_agent, mock_event_stream, mock_status_callback):
+    controller = AgentController(
+        agent=mock_agent,
+        event_stream=mock_event_stream,
+        status_callback=mock_status_callback,
+        max_iterations=10,
+        sid='test',
+        confirmation_mode=False,
+        headless_mode=True,
+    )
+    controller._notify_on_llm_retry(1, 2)
+    controller.status_callback.assert_called_once_with('info', 'STATUS$LLM_RETRY', ANY)
+    await controller.close()
+
+
 @pytest.mark.asyncio
 async def test_context_window_exceeded_error_handling(mock_agent, mock_event_stream):
     """Test that context window exceeded errors are handled correctly by truncating history."""
diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py
new file mode 100644
index 000000000000..7f61e66f01d0
--- /dev/null
+++ b/tests/unit/test_session.py
@@ -0,0 +1,69 @@
+from unittest.mock import ANY, AsyncMock, patch
+
+import pytest
+from litellm.exceptions import (
+    RateLimitError,
+)
+
+from openhands.core.config.app_config import AppConfig
+from openhands.core.config.llm_config import LLMConfig
+from openhands.server.session.session import Session
+from openhands.storage.memory import InMemoryFileStore
+
+
+@pytest.fixture
+def mock_status_callback():
+    return AsyncMock()
+
+
+@pytest.fixture
+def mock_sio():
+    return AsyncMock()
+
+
+@pytest.fixture
+def default_llm_config():
+    return LLMConfig(
+        model='gpt-4o',
+        api_key='test_key',
+        num_retries=2,
+        retry_min_wait=1,
+        retry_max_wait=2,
+    )
+
+
+@pytest.mark.asyncio
+@patch('openhands.llm.llm.litellm_completion')
+async def test_notify_on_llm_retry(
+    mock_litellm_completion, mock_sio, default_llm_config
+):
+    config = AppConfig()
+    config.set_llm_config(default_llm_config)
+    session = Session(
+        sid='..sid..',
+        file_store=InMemoryFileStore({}),
+        config=config,
+        sio=mock_sio,
+        user_id='..uid..',
+    )
+    session.queue_status_message = AsyncMock()
+
+    with patch('time.sleep') as _mock_sleep:
+        mock_litellm_completion.side_effect = [
+            RateLimitError(
+                'Rate limit exceeded', llm_provider='test_provider', model='test_model'
+            ),
+            {'choices': [{'message': {'content': 'Retry successful'}}]},
+        ]
+        llm = session._create_llm('..cls..')
+
+        llm.completion(
+            messages=[{'role': 'user', 'content': 'Hello!'}],
+            stream=False,
+        )
+
+        assert mock_litellm_completion.call_count == 2
+        session.queue_status_message.assert_called_once_with(
+            'info', 'STATUS$LLM_RETRY', ANY
+        )
+    await session.close()
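
Note on the mechanism: the retry_mixin.py hunk hooks tenacity's before_sleep callback, which runs once per failed attempt just before the backoff sleep, so the listener fires with (attempt_number, num_retries). A minimal standalone sketch of that same pattern, assuming tenacity is installed; the names flaky_call and on_retry are illustrative, not part of the patch:

# Sketch of the tenacity before_sleep pattern used in retry_mixin.py.
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

NUM_RETRIES = 2


def on_retry(retry_state) -> None:
    # Same shape as the patch's call:
    # retry_listener(retry_state.attempt_number, num_retries)
    print(f'Retrying LLM request, {retry_state.attempt_number} / {NUM_RETRIES}')


@retry(
    before_sleep=on_retry,  # runs after each failure, before the backoff sleep
    stop=stop_after_attempt(NUM_RETRIES),
    wait=wait_exponential(multiplier=1, min=1, max=2),
    retry=retry_if_exception_type(ConnectionError),
    reraise=True,
)
def flaky_call() -> str:
    raise ConnectionError('simulated rate limit')


if __name__ == '__main__':
    try:
        flaky_call()  # prints one retry notice, then re-raises
    except ConnectionError:
        print('gave up after retries')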
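
For callers outside AgentController and Session, the same hook is available through the new LLM constructor parameter. A hedged usage sketch, assuming a checkout with this patch applied; the notify function body is illustrative:

# Sketch: attaching a retry listener via LLM(..., retry_listener=...).
from openhands.core.config.llm_config import LLMConfig
from openhands.llm.llm import LLM


def notify(retries: int, max_retries: int) -> None:
    # Invoked from the before_sleep hook with (attempt_number, num_retries).
    print(f'Retrying LLM request, {retries} / {max_retries}')


llm = LLM(
    config=LLMConfig(model='gpt-4o', num_retries=2),
    retry_listener=notify,
)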