diff --git a/openhands/core/config/llm_config.py b/openhands/core/config/llm_config.py
index 477b47ccdbe1..1cb9af2f3042 100644
--- a/openhands/core/config/llm_config.py
+++ b/openhands/core/config/llm_config.py
@@ -77,6 +77,7 @@ class LLMConfig:
     log_completions: bool = False
     log_completions_folder: str = os.path.join(LOG_DIR, 'completions')
     draft_editor: Optional['LLMConfig'] = None
+    fallback_llms: list['LLMConfig'] | None = None  # LLM configs to try, in order, when rate limits are hit

     def defaults_to_dict(self) -> dict:
         """Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""
@@ -121,6 +122,8 @@ def to_safe_dict(self):
                 ret[k] = '******' if v else None
             elif isinstance(v, LLMConfig):
                 ret[k] = v.to_safe_dict()
+            elif k == 'fallback_llms' and v is not None:
+                ret[k] = [llm.to_safe_dict() for llm in v]
         return ret

     @classmethod
@@ -128,10 +131,20 @@ def from_dict(cls, llm_config_dict: dict) -> 'LLMConfig':
         """Create an LLMConfig object from a dictionary.

         This function is used to create an LLMConfig object from a dictionary,
-        with the exception of the 'draft_editor' key, which is a nested LLMConfig object.
+        with the exception of 'draft_editor' (a nested LLMConfig) and 'fallback_llms' (a list of nested LLMConfigs).
         """
-        args = {k: v for k, v in llm_config_dict.items() if not isinstance(v, dict)}
+        args = {k: v for k, v in llm_config_dict.items() if not isinstance(v, (dict, list))}
+
+        # Handle draft_editor
         if 'draft_editor' in llm_config_dict:
             draft_editor_config = LLMConfig(**llm_config_dict['draft_editor'])
             args['draft_editor'] = draft_editor_config
+
+        # Handle fallback_llms
+        if 'fallback_llms' in llm_config_dict:
+            fallback_configs = [
+                LLMConfig(**llm_dict) for llm_dict in llm_config_dict['fallback_llms']
+            ]
+            args['fallback_llms'] = fallback_configs
+
         return cls(**args)
diff --git a/openhands/llm/llm.py b/openhands/llm/llm.py
index 85010b3fec73..27ed211337c3 100644
--- a/openhands/llm/llm.py
+++ b/openhands/llm/llm.py
@@ -1,5 +1,6 @@
 import copy
 import os
+import threading
 import time
 import warnings
 from functools import partial
@@ -93,6 +94,10 @@ def __init__(
             config: The LLM configuration.
             metrics: The metrics to use.
         """
+        self._current_llm_index = 0  # 0 = primary LLM, i > 0 = fallback_llms[i - 1]
+        self._fallback_llms: list[LLM] | None = None
+        if config.fallback_llms:
+            self._fallback_llms = [LLM(cfg, metrics) for cfg in config.fallback_llms]
         self._tried_model_info = False
         self.metrics: Metrics = (
             metrics if metrics is not None else Metrics(model_name=config.model)
@@ -200,7 +205,39 @@ def wrapper(*args, **kwargs):

             try:
                 # we don't support streaming here, thus we get a ModelResponse
-                resp: ModelResponse = self._completion_unwrapped(*args, **kwargs)
+                try:
+                    resp: ModelResponse = self._completion_unwrapped(*args, **kwargs)
+                except RateLimitError as e:
+                    # If we have fallback LLMs, try them
+                    if self._fallback_llms:
+                        # Extract the suggested wait time from the error message
+                        wait_time = None
+                        if 'Please try again in' in str(e):
+                            import re
+                            match = re.search(r'try again in (\d+(?:\.\d+)?)s', str(e))
+                            if match:
+                                wait_time = float(match.group(1))
+
+                        # Try the next fallback LLM (index 0 is the primary LLM)
+                        self._current_llm_index += 1
+                        if self._current_llm_index <= len(self._fallback_llms):
+                            fallback_llm = self._fallback_llms[self._current_llm_index - 1]
+
+                            # Log the rate limit and switch
+                            logger.warning(
+                                f'Rate limit hit for {self.config.model}. '
+                                f'Wait time: {wait_time}s. '
+                                f'Switching to fallback LLM {fallback_llm.config.model}'
+                            )
+
+                            # Schedule a reset for when the rate limit expires
+                            if wait_time is not None:
+                                self.schedule_reset_fallback(wait_time)
+
+                            return fallback_llm.completion(*args, **kwargs)
+
+                    # No more fallbacks, re-raise
+                    raise

                 non_fncall_response = copy.deepcopy(resp)
                 if mock_function_calling:
@@ -374,6 +411,33 @@ def vision_is_active(self) -> bool:
             warnings.simplefilter('ignore')
             return not self.config.disable_vision and self._supports_vision()

+    def get_current_model(self) -> str:
+        """Get the name of the currently active LLM model.
+
+        Returns:
+            str: The model name of the currently active LLM.
+        """
+        if self._fallback_llms and 0 < self._current_llm_index <= len(self._fallback_llms):
+            return self._fallback_llms[self._current_llm_index - 1].config.model
+        return self.config.model
+
+    def reset_fallback_index(self):
+        """Reset the fallback LLM index to use the primary LLM again."""
+        self._current_llm_index = 0
+        logger.info(f'Rate limit expired. Switching back to primary LLM {self.config.model}')
+
+    def schedule_reset_fallback(self, wait_time: float):
+        """Schedule a reset of the fallback LLM index after the rate limit expires.
+
+        Args:
+            wait_time: Time to wait in seconds before resetting.
+        """
+        def reset_after_wait():
+            time.sleep(wait_time)
+            self.reset_fallback_index()
+
+        threading.Thread(target=reset_after_wait, daemon=True).start()
+
     def _supports_vision(self) -> bool:
         """Acquire from litellm if model is vision capable.

diff --git a/tests/unit/test_llm_fallback.py b/tests/unit/test_llm_fallback.py
new file mode 100644
index 000000000000..936faebab99f
--- /dev/null
+++ b/tests/unit/test_llm_fallback.py
@@ -0,0 +1,71 @@
+import pytest
+from unittest.mock import MagicMock
+
+from litellm.exceptions import RateLimitError
+from openhands.core.config import LLMConfig
+from openhands.llm.llm import LLM
+
+
+def test_llm_fallback_init():
+    # Test that fallback LLMs are properly initialized
+    primary_config = LLMConfig(model='model1')
+    fallback1 = LLMConfig(model='model2')
+    fallback2 = LLMConfig(model='model3')
+    primary_config.fallback_llms = [fallback1, fallback2]
+
+    llm = LLM(primary_config)
+    assert llm.get_current_model() == 'model1'
+    assert len(llm._fallback_llms) == 2
+    assert llm._fallback_llms[0].config.model == 'model2'
+    assert llm._fallback_llms[1].config.model == 'model3'
+
+
+def test_llm_fallback_on_rate_limit():
+    # Test that LLM switches to a fallback on a rate limit error
+    primary_config = LLMConfig(model='model1')
+    fallback1 = LLMConfig(model='model2')
+    primary_config.fallback_llms = [fallback1]
+
+    llm = LLM(primary_config)
+
+    # Mock the completion functions
+    primary_error = RateLimitError('Please try again in 60.5s', llm_provider='test_provider', model='model1')
+    llm._completion_unwrapped = MagicMock(side_effect=primary_error)
+    llm._fallback_llms[0]._completion_unwrapped = MagicMock(return_value={'choices': [{'message': {'content': 'success'}}]})
+
+    # Call completion and verify the fallback is used
+    result = llm.completion(messages=[{'role': 'user', 'content': 'test'}])
+    assert result['choices'][0]['message']['content'] == 'success'
+    assert llm.get_current_model() == 'model2'
+
+
+def test_llm_fallback_reset():
+    # Test that LLM resets to the primary after the rate limit expires
+    primary_config = LLMConfig(model='model1')
+    fallback1 = LLMConfig(model='model2')
+    primary_config.fallback_llms = [fallback1]
+
+    llm = LLM(primary_config)
+    llm._current_llm_index = 1  # Simulate using the first fallback
+
+    # Reset and verify
+    llm.reset_fallback_index()
+    assert llm.get_current_model() == 'model1'
+
+
+def test_llm_no_more_fallbacks():
+    # Test that the error is re-raised when no more fallbacks are available
+    primary_config = LLMConfig(model='model1', num_retries=1, retry_min_wait=1, retry_max_wait=2)
+    fallback1 = LLMConfig(model='model2', num_retries=1, retry_min_wait=1, retry_max_wait=2)
+    primary_config.fallback_llms = [fallback1]
+
+    llm = LLM(primary_config)
+
+    # Mock both LLMs to fail
+    error = RateLimitError('Rate limit exceeded', llm_provider='test_provider', model='model1')
+    llm._completion_unwrapped = MagicMock(side_effect=error)
+    llm._fallback_llms[0]._completion_unwrapped = MagicMock(side_effect=error)
+
+    # Verify the error is raised when no fallbacks remain
+    with pytest.raises(RateLimitError):
+        llm.completion(messages=[{'role': 'user', 'content': 'test'}])
\ No newline at end of file
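
Usage sketch (illustrative, not part of the patch): a minimal example of how the new fields fit together, assuming LLMConfig.from_dict receives plain dicts as handled above; the model names are placeholders.

    from openhands.core.config import LLMConfig
    from openhands.llm.llm import LLM

    # 'fallback_llms' entries are plain dicts; from_dict converts each one
    # into a nested LLMConfig (see llm_config.py above).
    config = LLMConfig.from_dict(
        {
            'model': 'primary-model',  # placeholder model names
            'fallback_llms': [
                {'model': 'fallback-model-1'},
                {'model': 'fallback-model-2'},
            ],
        }
    )

    llm = LLM(config)
    # Reports the primary model until a RateLimitError triggers a switch;
    # reset_fallback_index() (or the scheduled reset) switches back.
    print(llm.get_current_model())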