From 4fb3b0dcd580ac78fec333cf38d455393028da93 Mon Sep 17 00:00:00 2001
From: openhands
Date: Thu, 3 Oct 2024 11:44:20 +0000
Subject: [PATCH] Fix issue #4184: '[LLM] Support LLM routing through notdiamond'

---
 openhands/core/schema/config.py |  2 +
 openhands/llm/llm.py            | 44 +++++++++++++++++++++-
 openhands/llm/llm_router.py     | 65 +++++++++++++++++++++++++++++++++
 tests/unit/test_llm_router.py   | 54 +++++++++++++++++++++++++++
 4 files changed, 164 insertions(+), 1 deletion(-)
 create mode 100644 openhands/llm/llm_router.py
 create mode 100644 tests/unit/test_llm_router.py

diff --git a/openhands/core/schema/config.py b/openhands/core/schema/config.py
index 1272ebe655a5..d5ca3568f579 100644
--- a/openhands/core/schema/config.py
+++ b/openhands/core/schema/config.py
@@ -47,3 +47,5 @@ class ConfigType(str, Enum):
     WORKSPACE_MOUNT_PATH = 'WORKSPACE_MOUNT_PATH'
     WORKSPACE_MOUNT_PATH_IN_SANDBOX = 'WORKSPACE_MOUNT_PATH_IN_SANDBOX'
     WORKSPACE_MOUNT_REWRITE = 'WORKSPACE_MOUNT_REWRITE'
+    LLM_PROVIDERS = 'LLM_PROVIDERS'
+    LLM_ROUTER_ENABLED = 'LLM_ROUTER_ENABLED'
diff --git a/openhands/llm/llm.py b/openhands/llm/llm.py
index 11a34a51218a..41b90b68aebf 100644
--- a/openhands/llm/llm.py
+++ b/openhands/llm/llm.py
@@ -2,7 +2,7 @@
 import time
 import warnings
 from functools import partial
-from typing import Any
+from typing import Any, List, Tuple
 
 from openhands.core.config import LLMConfig
@@ -26,6 +26,7 @@
 from openhands.core.metrics import Metrics
 from openhands.llm.debug_mixin import DebugMixin
 from openhands.llm.retry_mixin import RetryMixin
+from openhands.core.message import Message
 
 __all__ = ['LLM']
@@ -77,6 +78,15 @@ def __init__(
         # list of LLM completions (for logging purposes). Each completion is a dict with the following keys:
         # - 'messages': list of messages
         # - 'response': response from the LLM
+
+        # Imported here rather than at module level: llm_router imports LLM,
+        # so a top-level import would be circular.
+        from openhands.llm.llm_router import LLMRouter
+
+        if self.config.llm_router_enabled and not isinstance(self, LLMRouter):
+            self.router = LLMRouter(config, metrics)
+        else:
+            self.router = None
         self.llm_completions: list[dict[str, Any]] = []
 
         # litellm actually uses base Exception here for unknown model
@@ -123,6 +133,7 @@ def __init__(
             litellm_completion,
             model=self.config.model,
             api_key=self.config.api_key,
+            base_url=self.config.base_url,
             api_version=self.config.api_version,
             custom_llm_provider=self.config.custom_llm_provider,
@@ -173,6 +184,7 @@ def wrapper(*args, **kwargs):
             if not messages:
                 raise ValueError(
                     'The messages list is empty. At least one message is required.'
                 )
+
             # log the entire LLM prompt
@@ -211,6 +223,47 @@ def wrapper(*args, **kwargs):
         self._completion = wrapper
 
+    def complete(
+        self,
+        messages: List[Message],
+        **kwargs: Any,
+    ) -> Tuple[str, float]:
+        """Complete the given messages with the router-selected model or the default model."""
+        start_time = time.time()
+
+        if self.router:
+            # The router resolves a model via notdiamond and already returns
+            # the response text.
+            content, _ = self.router.complete(messages, **kwargs)
+        else:
+            response = self._completion(
+                messages=[
+                    {'role': msg.role, 'content': msg.content} for msg in messages
+                ],
+                **kwargs,
+            )
+            content = response.choices[0].message.content
+
+        latency = time.time() - start_time
+        return content, latency
+
+    def stream(
+        self,
+        messages: List[Message],
+        **kwargs: Any,
+    ):
+        """Stream the response from the router-selected model or the default model."""
+        if self.router:
+            yield from self.router.stream(messages, **kwargs)
+        else:
+            yield from self._completion(
+                messages=[
+                    {'role': msg.role, 'content': msg.content} for msg in messages
+                ],
+                stream=True,
+                **kwargs,
+            )
+
     @property
     def completion(self):
         """Decorator for the litellm completion function.
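For reference, a minimal sketch of how the new complete() entry point is meant to be called once routing is wired up. It assumes LLMConfig actually grows llm_router_enabled and llm_providers fields (this patch only adds the corresponding ConfigType keys) and that NOTDIAMOND_API_KEY is exported; the provider strings are illustrative.

    # Sketch only: the LLMConfig fields and provider strings below are
    # assumptions, not something this patch adds or pins down.
    from openhands.core.config import LLMConfig
    from openhands.core.message import Message
    from openhands.llm.llm import LLM

    config = LLMConfig(
        model='gpt-4o',  # used directly when routing is disabled
        llm_router_enabled=True,
        llm_providers=['openai/gpt-4o', 'anthropic/claude-3-5-sonnet-20240620'],
    )
    llm = LLM(config)

    content, latency = llm.complete([Message(role='user', content='Hello')])
    print(f'{content!r} (in {latency:.2f}s)')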
diff --git a/openhands/llm/llm_router.py b/openhands/llm/llm_router.py
new file mode 100644
index 000000000000..0b4a0893af7a
--- /dev/null
+++ b/openhands/llm/llm_router.py
@@ -0,0 +1,66 @@
+import copy
+import os
+from typing import Any, List, Tuple
+
+from openhands.core.config import LLMConfig
+from openhands.core.message import Message
+from openhands.core.metrics import Metrics
+from openhands.llm.llm import LLM
+
+
+class LLMRouter(LLM):
+    """An LLM that routes each query to the best available model via notdiamond."""
+
+    def __init__(
+        self,
+        config: LLMConfig,
+        metrics: Metrics | None = None,
+    ):
+        super().__init__(config, metrics)
+        self.llm_providers: List[str] = config.llm_providers
+        self.notdiamond_api_key = os.environ.get('NOTDIAMOND_API_KEY')
+        if not self.notdiamond_api_key:
+            raise ValueError('NOTDIAMOND_API_KEY environment variable is not set')
+
+        # Imported lazily so notdiamond is only required when routing is enabled.
+        from notdiamond import NotDiamond
+
+        self.client = NotDiamond()
+
+    def _select_model(self, messages: List[Message]) -> Tuple[str, Any]:
+        """Select the best model for the given messages."""
+        formatted_messages = [
+            {'role': msg.role, 'content': msg.content} for msg in messages
+        ]
+        session_id, provider = self.client.chat.completions.model_select(
+            messages=formatted_messages,
+            model=self.llm_providers,
+        )
+        return provider.model, session_id
+
+    def _llm_for(self, messages: List[Message]) -> LLM:
+        """Build an LLM instance for the model selected for these messages."""
+        selected_model, _session_id = self._select_model(messages)
+
+        # Copy the current config so credentials and other settings carry
+        # over, and disable routing on the copy to avoid recursing.
+        selected_config = copy.deepcopy(self.config)
+        selected_config.model = selected_model
+        selected_config.llm_router_enabled = False
+        return LLM(config=selected_config, metrics=self.metrics)
+
+    def complete(
+        self,
+        messages: List[Message],
+        **kwargs: Any,
+    ) -> Tuple[str, float]:
+        """Complete the given messages using the best selected model."""
+        return self._llm_for(messages).complete(messages, **kwargs)
+
+    def stream(
+        self,
+        messages: List[Message],
+        **kwargs: Any,
+    ):
+        """Stream the response using the best selected model."""
+        yield from self._llm_for(messages).stream(messages, **kwargs)
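For context, the notdiamond call that LLMRouter._select_model wraps looks roughly like the following when used standalone. The (session_id, provider) return shape and the provider identifiers are assumptions to verify against the installed notdiamond SDK version.

    # Standalone sketch of the model_select call wrapped above; the return
    # shape and provider strings should be checked against the notdiamond docs.
    import os

    from notdiamond import NotDiamond

    assert os.environ.get('NOTDIAMOND_API_KEY'), 'export NOTDIAMOND_API_KEY first'

    client = NotDiamond()
    session_id, provider = client.chat.completions.model_select(
        messages=[{'role': 'user', 'content': 'Hello'}],
        model=['openai/gpt-4o', 'anthropic/claude-3-5-sonnet-20240620'],
    )
    print(provider.model, session_id)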
diff --git a/tests/unit/test_llm_router.py b/tests/unit/test_llm_router.py
new file mode 100644
index 000000000000..47c3e7a4fba9
--- /dev/null
+++ b/tests/unit/test_llm_router.py
@@ -0,0 +1,57 @@
+from unittest.mock import MagicMock, Mock, patch
+
+import pytest
+
+from openhands.core.config import LLMConfig
+from openhands.core.message import Message
+from openhands.llm.llm import LLM
+from openhands.llm.llm_router import LLMRouter
+
+
+@pytest.fixture
+def mock_notdiamond(monkeypatch):
+    monkeypatch.setenv('NOTDIAMOND_API_KEY', 'test-key')
+    # LLMRouter imports NotDiamond lazily, so patch it at its source module.
+    with patch('notdiamond.NotDiamond') as mock:
+        yield mock
+
+
+def test_llm_router_enabled(mock_notdiamond):
+    config = LLMConfig(
+        model='test-model',
+        llm_router_enabled=True,
+        llm_providers=['model1', 'model2'],
+    )
+    llm = LLM(config)
+
+    assert isinstance(llm.router, LLMRouter)
+
+    messages = [Message(role='user', content='Hello')]
+    # LLMRouter.complete returns the response text and a latency.
+    llm.router.complete = Mock(return_value=('Hello, how can I help you?', 0.5))
+
+    response, latency = llm.complete(messages)
+
+    assert response == 'Hello, how can I help you?'
+    assert isinstance(latency, float)
+    llm.router.complete.assert_called_once_with(messages)
+
+
+def test_llm_router_disabled():
+    config = LLMConfig(model='test-model', llm_router_enabled=False)
+    llm = LLM(config)
+
+    assert llm.router is None
+
+    messages = [Message(role='user', content='Hello')]
+    with patch.object(llm, '_completion') as mock_completion:
+        # MagicMock (not Mock) so that choices[0] supports indexing.
+        mock_response = MagicMock()
+        mock_response.choices[0].message.content = 'Hello, how can I help you?'
+        mock_completion.return_value = mock_response
+
+        response, latency = llm.complete(messages)
+
+        assert response == 'Hello, how can I help you?'
+        assert isinstance(latency, float)
+        mock_completion.assert_called_once()
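The tests above only exercise complete(); consuming the new stream() method would look roughly like this. The chunk shape (choices[0].delta.content) follows litellm's streaming responses and is an assumption here, not something the patch guarantees.

    # Sketch: consuming LLM.stream() on the non-routed path. The chunk
    # attribute access assumes litellm's streaming format.
    from openhands.core.config import LLMConfig
    from openhands.core.message import Message
    from openhands.llm.llm import LLM

    llm = LLM(LLMConfig(model='gpt-4o'))  # routing disabled: direct litellm path

    for chunk in llm.stream([Message(role='user', content='Hello')]):
        delta = chunk.choices[0].delta.content
        if delta:
            print(delta, end='', flush=True)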