Show LLM retries and allow resume from rate-limit state (All-Hands-AI#6438)

Co-authored-by: Engel Nyst <[email protected]>
raymyers and enyst authored Jan 30, 2025
1 parent 1bccfb3 commit fd73f42
Showing 8 changed files with 129 additions and 11 deletions.
3 changes: 1 addition & 2 deletions frontend/src/components/features/chat/chat-interface.tsx
@@ -180,8 +180,7 @@ export function ChatInterface() {
         onStop={handleStop}
         isDisabled={
           curAgentState === AgentState.LOADING ||
-          curAgentState === AgentState.AWAITING_USER_CONFIRMATION ||
-          curAgentState === AgentState.RATE_LIMITED
+          curAgentState === AgentState.AWAITING_USER_CONFIRMATION
         }
         mode={curAgentState === AgentState.RUNNING ? "stop" : "submit"}
         value={messageToSend ?? undefined}
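
Note: dropping AgentState.RATE_LIMITED from this disabled-state check keeps the chat input usable while the agent is rate-limited, which is what allows resuming from the rate-limit state named in the commit title.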
3 changes: 3 additions & 0 deletions frontend/src/i18n/translation.json
@@ -3816,6 +3816,9 @@
     "es": "Hubo un error al conectar con el entorno de ejecución. Por favor, actualice la página.",
     "tr": "Çalışma zamanına bağlanırken bir hata oluştu. Lütfen sayfayı yenileyin."
   },
+  "STATUS$LLM_RETRY": {
+    "en": "Retrying LLM request"
+  },
   "AGENT_ERROR$BAD_ACTION": {
     "en": "Agent tried to execute a malformed action.",
     "zh-CN": "错误的操作",
15 changes: 12 additions & 3 deletions openhands/controller/agent_controller.py
@@ -235,8 +235,10 @@ async def _step_with_exception_handling(self):
                 f'report this error to the developers. Your session ID is {self.id}. '
                 f'Error type: {e.__class__.__name__}'
             )
-            if isinstance(e, litellm.AuthenticationError) or isinstance(
-                e, litellm.BadRequestError
+            if (
+                isinstance(e, litellm.AuthenticationError)
+                or isinstance(e, litellm.BadRequestError)
+                or isinstance(e, RateLimitError)
             ):
                 reported = e
             await self._react_to_exception(reported)
@@ -530,7 +532,7 @@ async def start_delegate(self, action: AgentDelegateAction) -> None:
         agent_cls: Type[Agent] = Agent.get_cls(action.agent)
         agent_config = self.agent_configs.get(action.agent, self.agent.config)
         llm_config = self.agent_to_llm_config.get(action.agent, self.agent.llm.config)
-        llm = LLM(config=llm_config)
+        llm = LLM(config=llm_config, retry_listener=self._notify_on_llm_retry)
         delegate_agent = agent_cls(llm=llm, config=agent_config)
         state = State(
             inputs=action.inputs or {},
@@ -725,6 +727,13 @@ async def _step(self) -> None:
         log_level = 'info' if LOG_ALL_EVENTS else 'debug'
         self.log(log_level, str(action), extra={'msg_type': 'ACTION'})
 
+    def _notify_on_llm_retry(self, retries: int, max: int) -> None:
+        if self.status_callback is not None:
+            msg_id = 'STATUS$LLM_RETRY'
+            self.status_callback(
+                'info', msg_id, f'Retrying LLM request, {retries} / {max}'
+            )
+
     async def _handle_traffic_control(
         self, limit_type: str, current_value: float, max_value: float
     ) -> bool:
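
Taken together, these hunks surface RateLimitError to the user like the other reportable LLM errors and hand every delegate LLM the same retry listener as the parent controller. A minimal sketch of the notification path follows; MiniController and the print callback are illustrative stand-ins, and only the _notify_on_llm_retry shape and the STATUS$LLM_RETRY id come from the diff.

from typing import Callable

# Illustrative stand-in for AgentController, modeling only the retry path.
class MiniController:
    def __init__(self, status_callback: Callable[[str, str, str], None] | None = None):
        self.status_callback = status_callback

    def _notify_on_llm_retry(self, retries: int, max: int) -> None:
        # Same shape as the method added above: one 'info' status update
        # per retry attempt, so the UI can show progress instead of silence.
        if self.status_callback is not None:
            self.status_callback(
                'info', 'STATUS$LLM_RETRY', f'Retrying LLM request, {retries} / {max}'
            )

# Usage: a callback that prints what the UI would render.
controller = MiniController(lambda typ, mid, msg: print(f'[{typ}] {mid}: {msg}'))
controller._notify_on_llm_retry(1, 3)  # [info] STATUS$LLM_RETRY: Retrying LLM request, 1 / 3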
6 changes: 4 additions & 2 deletions openhands/llm/llm.py
@@ -3,7 +3,7 @@
 import time
 import warnings
 from functools import partial
-from typing import Any
+from typing import Any, Callable
 
 import requests
 
@@ -94,6 +94,7 @@ def __init__(
         self,
         config: LLMConfig,
         metrics: Metrics | None = None,
+        retry_listener: Callable[[int, int], None] | None = None,
     ):
         """Initializes the LLM. If LLMConfig is passed, its values will be the fallback.
@@ -111,7 +112,7 @@
         self.config: LLMConfig = copy.deepcopy(config)
 
         self.model_info: ModelInfo | None = None
-
+        self.retry_listener = retry_listener
         if self.config.log_completions:
             if self.config.log_completions_folder is None:
                 raise RuntimeError(
@@ -168,6 +169,7 @@ def __init__(
             retry_min_wait=self.config.retry_min_wait,
             retry_max_wait=self.config.retry_max_wait,
             retry_multiplier=self.config.retry_multiplier,
+            retry_listener=self.retry_listener,
         )
         def wrapper(*args, **kwargs):
             """Wrapper for the litellm completion function. Logs the input and output of the completion function."""
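
The new retry_listener parameter is stored on the instance and forwarded to the retry decorator, so a caller can observe retries without subclassing LLM. A minimal usage sketch, assuming the constructor shape shown above (the openhands.llm.llm import path matches the patch target used in test_session.py below):

from openhands.core.config.llm_config import LLMConfig
from openhands.llm.llm import LLM

def on_retry(attempt: int, max_attempts: int) -> None:
    # Invoked once per failed attempt, before the backoff wait.
    print(f'LLM retry {attempt}/{max_attempts}')

# num_retries and the retry_* waits mirror the config fields read above.
llm = LLM(
    config=LLMConfig(model='gpt-4o', num_retries=2, retry_min_wait=1, retry_max_wait=2),
    retry_listener=on_retry,
)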
8 changes: 7 additions & 1 deletion openhands/llm/retry_mixin.py
@@ -28,9 +28,15 @@ def retry_decorator(self, **kwargs):
         retry_min_wait = kwargs.get('retry_min_wait')
         retry_max_wait = kwargs.get('retry_max_wait')
         retry_multiplier = kwargs.get('retry_multiplier')
+        retry_listener = kwargs.get('retry_listener')
+
+        def before_sleep(retry_state):
+            self.log_retry_attempt(retry_state)
+            if retry_listener:
+                retry_listener(retry_state.attempt_number, num_retries)
 
         return retry(
-            before_sleep=self.log_retry_attempt,
+            before_sleep=before_sleep,
             stop=stop_after_attempt(num_retries) | stop_if_should_exit(),
             reraise=True,
             retry=(retry_if_exception_type(retry_exceptions)),
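
The hook builds on tenacity's before_sleep callback, which fires after each failed attempt but before the backoff wait, so it never fires for the final exhausted attempt. A self-contained sketch of the pattern, independent of the OpenHands classes:

from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

attempts_seen = []

def listener(attempt: int, max_attempts: int) -> None:
    attempts_seen.append((attempt, max_attempts))

def before_sleep(retry_state):
    # retry_state.attempt_number is the attempt that just failed.
    listener(retry_state.attempt_number, 3)

@retry(
    before_sleep=before_sleep,
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=0.01),
    retry=retry_if_exception_type(ValueError),
    reraise=True,
)
def flaky() -> str:
    if len(attempts_seen) < 2:  # fail on the first two attempts
        raise ValueError('transient')
    return 'ok'

assert flaky() == 'ok'
assert attempts_seen == [(1, 3), (2, 3)]  # no callback for the succeeding attempt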
18 changes: 16 additions & 2 deletions openhands/server/session/session.py
@@ -104,7 +104,7 @@ async def initialize_agent(
 
         # TODO: override other LLM config & agent config groups (#2075)
 
-        llm = LLM(config=self.config.get_llm_config_from_agent(agent_cls))
+        llm = self._create_llm(agent_cls)
         agent_config = self.config.get_agent_config(agent_cls)
 
         if settings.enable_default_condenser:
@@ -142,6 +142,21 @@
             )
             return
 
+    def _create_llm(self, agent_cls: str | None) -> LLM:
+        """
+        Initialize LLM, extracted for testing.
+        """
+        return LLM(
+            config=self.config.get_llm_config_from_agent(agent_cls),
+            retry_listener=self._notify_on_llm_retry,
+        )
+
+    def _notify_on_llm_retry(self, retries: int, max: int) -> None:
+        msg_id = 'STATUS$LLM_RETRY'
+        self.queue_status_message(
+            'info', msg_id, f'Retrying LLM request, {retries} / {max}'
+        )
+
     def on_event(self, event: Event):
         asyncio.get_event_loop().run_until_complete(self._on_event(event))
 
@@ -220,7 +235,6 @@ async def _send_status_message(self, msg_type: str, id: str, message: str):
         """Sends a status message to the client."""
         if msg_type == 'error':
             await self.agent_session.stop_agent_loop_for_error()
-
         await self.send(
             {'status_update': True, 'type': msg_type, 'id': id, 'message': message}
         )
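
Two details worth noting here: _create_llm exists as a seam so tests can build the session's LLM directly (test_session.py below uses it), and the retry notification is sent as an 'info' status update, so it flows through _send_status_message without hitting the 'error' branch that stops the agent loop.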
18 changes: 17 additions & 1 deletion tests/unit/test_agent_controller.py
@@ -1,5 +1,5 @@
 import asyncio
-from unittest.mock import AsyncMock, MagicMock
+from unittest.mock import ANY, AsyncMock, MagicMock
 from uuid import uuid4
 
 import pytest
@@ -564,6 +564,22 @@ def on_event(event: Event):
     ), f'Expected accumulated cost to be 30.0, but got {state.metrics.accumulated_cost}'
 
 
+@pytest.mark.asyncio
+async def test_notify_on_llm_retry(mock_agent, mock_event_stream, mock_status_callback):
+    controller = AgentController(
+        agent=mock_agent,
+        event_stream=mock_event_stream,
+        status_callback=mock_status_callback,
+        max_iterations=10,
+        sid='test',
+        confirmation_mode=False,
+        headless_mode=True,
+    )
+    controller._notify_on_llm_retry(1, 2)
+    controller.status_callback.assert_called_once_with('info', 'STATUS$LLM_RETRY', ANY)
+    await controller.close()
+
+
 @pytest.mark.asyncio
 async def test_context_window_exceeded_error_handling(mock_agent, mock_event_stream):
     """Test that context window exceeded errors are handled correctly by truncating history."""
69 changes: 69 additions & 0 deletions tests/unit/test_session.py
@@ -0,0 +1,69 @@
+from unittest.mock import ANY, AsyncMock, patch
+
+import pytest
+from litellm.exceptions import (
+    RateLimitError,
+)
+
+from openhands.core.config.app_config import AppConfig
+from openhands.core.config.llm_config import LLMConfig
+from openhands.server.session.session import Session
+from openhands.storage.memory import InMemoryFileStore
+
+
+@pytest.fixture
+def mock_status_callback():
+    return AsyncMock()
+
+
+@pytest.fixture
+def mock_sio():
+    return AsyncMock()
+
+
+@pytest.fixture
+def default_llm_config():
+    return LLMConfig(
+        model='gpt-4o',
+        api_key='test_key',
+        num_retries=2,
+        retry_min_wait=1,
+        retry_max_wait=2,
+    )
+
+
+@pytest.mark.asyncio
+@patch('openhands.llm.llm.litellm_completion')
+async def test_notify_on_llm_retry(
+    mock_litellm_completion, mock_sio, default_llm_config
+):
+    config = AppConfig()
+    config.set_llm_config(default_llm_config)
+    session = Session(
+        sid='..sid..',
+        file_store=InMemoryFileStore({}),
+        config=config,
+        sio=mock_sio,
+        user_id='..uid..',
+    )
+    session.queue_status_message = AsyncMock()
+
+    with patch('time.sleep') as _mock_sleep:
+        mock_litellm_completion.side_effect = [
+            RateLimitError(
+                'Rate limit exceeded', llm_provider='test_provider', model='test_model'
+            ),
+            {'choices': [{'message': {'content': 'Retry successful'}}]},
+        ]
+        llm = session._create_llm('..cls..')
+
+        llm.completion(
+            messages=[{'role': 'user', 'content': 'Hello!'}],
+            stream=False,
+        )
+
+    assert mock_litellm_completion.call_count == 2
+    session.queue_status_message.assert_called_once_with(
+        'info', 'STATUS$LLM_RETRY', ANY
+    )
+    await session.close()
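
Patching time.sleep keeps the retry backoff from actually waiting (tenacity's default wait sleeps via time.sleep), so the test runs instantly; it then checks that the single RateLimitError produced exactly one queued STATUS$LLM_RETRY update before the second, successful completion.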
