feat: Configure fallback LLMs for rate limit handling
- Add fallback_llms field to LLMConfig
- Implement automatic switching to fallback LLMs on rate limits
- Add automatic reset when rate limit expires
- Add unit tests for fallback functionality

Fixes #1263
openhands-agent committed Dec 9, 2024
1 parent 99fa6c6 commit da0d740
Showing 3 changed files with 151 additions and 3 deletions.
17 changes: 15 additions & 2 deletions openhands/core/config/llm_config.py
@@ -77,6 +77,7 @@ class LLMConfig:
log_completions: bool = False
log_completions_folder: str = os.path.join(LOG_DIR, 'completions')
draft_editor: Optional['LLMConfig'] = None
fallback_llms: list['LLMConfig'] | None = None # List of LLM configs to try when rate limits are hit

def defaults_to_dict(self) -> dict:
"""Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""
@@ -121,17 +122,29 @@ def to_safe_dict(self):
ret[k] = '******' if v else None
elif isinstance(v, LLMConfig):
ret[k] = v.to_safe_dict()
elif k == 'fallback_llms' and v is not None:
ret[k] = [llm.to_safe_dict() for llm in v]
return ret

@classmethod
def from_dict(cls, llm_config_dict: dict) -> 'LLMConfig':
"""Create an LLMConfig object from a dictionary.
This function is used to create an LLMConfig object from a dictionary,
with the exception of the 'draft_editor' key, which is a nested LLMConfig object.
with the exception of the 'draft_editor' and 'fallback_llms' keys, which are nested LLMConfig objects.
"""
args = {k: v for k, v in llm_config_dict.items() if not isinstance(v, dict)}
args = {k: v for k, v in llm_config_dict.items() if not isinstance(v, (dict, list))}

# Handle draft_editor
if 'draft_editor' in llm_config_dict:
draft_editor_config = LLMConfig(**llm_config_dict['draft_editor'])
args['draft_editor'] = draft_editor_config

# Handle fallback_llms
if 'fallback_llms' in llm_config_dict:
fallback_configs = [
LLMConfig(**llm_dict) for llm_dict in llm_config_dict['fallback_llms']
]
args['fallback_llms'] = fallback_configs

return cls(**args)
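
For reference, a minimal sketch of how the new nested handling in from_dict could be exercised; the model names below are placeholders, and the dict shape simply mirrors what the method expects (top-level fields plus a list of dicts under 'fallback_llms'):

from openhands.core.config import LLMConfig

# Hypothetical config dict; only the fields shown here are assumed.
config_dict = {
    'model': 'primary-model',
    'fallback_llms': [
        {'model': 'backup-model-1'},
        {'model': 'backup-model-2'},
    ],
}

config = LLMConfig.from_dict(config_dict)
assert config.model == 'primary-model'
assert config.fallback_llms is not None and len(config.fallback_llms) == 2
assert config.fallback_llms[0].model == 'backup-model-1'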
66 changes: 65 additions & 1 deletion openhands/llm/llm.py
@@ -1,5 +1,6 @@
import copy
import os
import threading
import time
import warnings
from functools import partial
@@ -93,6 +94,10 @@ def __init__(
config: The LLM configuration.
metrics: The metrics to use.
"""
self._current_llm_index = 0 # Index of current LLM in fallback list
self._fallback_llms: list[LLM] | None = None
if config.fallback_llms:
self._fallback_llms = [LLM(cfg, metrics) for cfg in config.fallback_llms]
self._tried_model_info = False
self.metrics: Metrics = (
metrics if metrics is not None else Metrics(model_name=config.model)
@@ -200,7 +205,39 @@ def wrapper(*args, **kwargs):

try:
# we don't support streaming here, thus we get a ModelResponse
resp: ModelResponse = self._completion_unwrapped(*args, **kwargs)
try:
resp: ModelResponse = self._completion_unwrapped(*args, **kwargs)
except RateLimitError as e:
# If we have fallback LLMs, try them
if self._fallback_llms:
# Extract wait time from error message
wait_time = None
if 'Please try again in' in str(e):
import re
match = re.search(r'try again in (\d+[.]\d+)s', str(e))
if match:
wait_time = float(match.group(1))

# Try the next fallback LLM (index 0 is the primary, so fallback i lives at _fallback_llms[i - 1])
self._current_llm_index += 1
if self._current_llm_index <= len(self._fallback_llms):
fallback_llm = self._fallback_llms[self._current_llm_index - 1]

# Log the rate limit and switch
logger.warning(
f'Rate limit hit for {self.config.model}. '
f'Wait time: {wait_time}s. '
f'Switching to fallback LLM {fallback_llm.config.model}'
)

# Schedule reset when rate limit expires
if wait_time is not None:
self.schedule_reset_fallback(wait_time)

return fallback_llm.completion(*args, **kwargs)

# No more fallbacks, re-raise
raise

non_fncall_response = copy.deepcopy(resp)
if mock_function_calling:
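
The wait-time parsing above assumes the provider's error text contains a phrase like 'Please try again in 60.5s'. Note that the pattern only matches a number with a decimal point, so a message such as 'try again in 60s' yields no wait time and therefore no scheduled reset. A standalone sketch of that extraction (the helper name and messages are illustrative, not part of the commit):

import re

def extract_wait_seconds(error_message: str) -> float | None:
    """Pull the retry delay out of a rate-limit message, if present."""
    # Same pattern as the handler above: requires fractional seconds, e.g. '60.5s'.
    match = re.search(r'try again in (\d+[.]\d+)s', error_message)
    return float(match.group(1)) if match else None

assert extract_wait_seconds('Rate limit reached. Please try again in 60.5s.') == 60.5
assert extract_wait_seconds('Please try again in 60s') is None  # integer seconds are not matched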
@@ -374,6 +411,33 @@ def vision_is_active(self) -> bool:
warnings.simplefilter('ignore')
return not self.config.disable_vision and self._supports_vision()

def get_current_model(self) -> str:
"""Get the name of the currently active LLM model.
Returns:
str: The model name of the currently active LLM.
"""
if self._fallback_llms and self._current_llm_index > 0:
# _current_llm_index counts the primary as 0, so subtract 1 to index the fallback list
return self._fallback_llms[self._current_llm_index - 1].config.model
return self.config.model

def reset_fallback_index(self):
"""Reset the fallback LLM index to use the primary LLM again."""
self._current_llm_index = 0
logger.info(f'Rate limit expired. Switching back to primary LLM {self.config.model}')

def schedule_reset_fallback(self, wait_time: float):
"""Schedule resetting the fallback LLM index after the rate limit expires.
Args:
wait_time: Time to wait in seconds before resetting.
"""
def reset_after_wait():
time.sleep(wait_time)
self.reset_fallback_index()

threading.Thread(target=reset_after_wait, daemon=True).start()

def _supports_vision(self) -> bool:
"""Acquire from litellm if model is vision capable.
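The reset in schedule_reset_fallback above is implemented with a dedicated daemon thread that sleeps for wait_time and then flips the index back to the primary. The same scheduling could also be expressed with threading.Timer; a minimal equivalent sketch, shown only as a design alternative (the helper name is hypothetical, and this is not what the commit uses):

import threading

def schedule_reset(llm, wait_time: float) -> threading.Timer:
    """Fire llm.reset_fallback_index once wait_time seconds have elapsed."""
    timer = threading.Timer(wait_time, llm.reset_fallback_index)
    timer.daemon = True  # don't keep the process alive just to run the reset
    timer.start()
    return timer

Either way the reset is best-effort: it only resets _current_llm_index to 0, so requests already routed to a fallback are unaffected.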
71 changes: 71 additions & 0 deletions tests/unit/test_llm_fallback.py
@@ -0,0 +1,71 @@
import pytest
from unittest.mock import MagicMock

from litellm.exceptions import RateLimitError
from openhands.core.config import LLMConfig
from openhands.llm.llm import LLM


def test_llm_fallback_init():
# Test that fallback LLMs are properly initialized
primary_config = LLMConfig(model='model1')
fallback1 = LLMConfig(model='model2')
fallback2 = LLMConfig(model='model3')
primary_config.fallback_llms = [fallback1, fallback2]

llm = LLM(primary_config)
assert llm.get_current_model() == 'model1'
assert len(llm._fallback_llms) == 2
assert llm._fallback_llms[0].config.model == 'model2'
assert llm._fallback_llms[1].config.model == 'model3'


def test_llm_fallback_on_rate_limit():
# Test that LLM switches to fallback on rate limit error
primary_config = LLMConfig(model='model1')
fallback1 = LLMConfig(model='model2')
primary_config.fallback_llms = [fallback1]

llm = LLM(primary_config)

# Mock the completion functions
primary_error = RateLimitError('Please try again in 60.5s')
llm._completion_unwrapped = MagicMock(side_effect=primary_error)
llm._fallback_llms[0]._completion_unwrapped = MagicMock(return_value={'choices': [{'message': {'content': 'success'}}]})

# Call completion and verify fallback is used
result = llm.completion(messages=[{'role': 'user', 'content': 'test'}])
assert result['choices'][0]['message']['content'] == 'success'
assert llm.get_current_model() == 'model2'


def test_llm_fallback_reset():
# Test that LLM resets to primary after rate limit expires
primary_config = LLMConfig(model='model1')
fallback1 = LLMConfig(model='model2')
primary_config.fallback_llms = [fallback1]

llm = LLM(primary_config)
llm._current_llm_index = 1 # Simulate using fallback

# Reset and verify
llm.reset_fallback_index()
assert llm.get_current_model() == 'model1'


def test_llm_no_more_fallbacks():
# Test that error is re-raised when no more fallbacks are available
primary_config = LLMConfig(model='model1')
fallback1 = LLMConfig(model='model2')
primary_config.fallback_llms = [fallback1]

llm = LLM(primary_config)

# Mock both LLMs to fail
error = RateLimitError('Rate limit exceeded')
llm._completion_unwrapped = MagicMock(side_effect=error)
llm._fallback_llms[0]._completion_unwrapped = MagicMock(side_effect=error)

# Verify error is raised when no more fallbacks
with pytest.raises(RateLimitError):
llm.completion(messages=[{'role': 'user', 'content': 'test'}])
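
For completeness, a hedged sketch of how calling code might surface which model is actually serving requests; the model names are placeholders and the construction mirrors the tests above:

from openhands.core.config import LLMConfig
from openhands.llm.llm import LLM

config = LLMConfig(model='primary-model')
config.fallback_llms = [LLMConfig(model='backup-model')]

llm = LLM(config)
# Reports 'primary-model' until a rate limit switches to a fallback.
print(f'Serving requests with: {llm.get_current_model()}')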
