Fix issue #4184: '[LLM] Support LLM routing through notdiamond'
Commit 4fb3b0d · 1 parent 5c31fd9
Showing 4 changed files with 164 additions and 1 deletion.
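For context, here is how the new router is meant to be used end to end. This is a hedged sketch: the LLMConfig fields (llm_router_enabled, llm_providers) and the NOTDIAMOND_API_KEY requirement come from the diff below; the model identifiers are purely illustrative.

    import os

    from openhands.core.config import LLMConfig
    from openhands.core.message import Message
    from openhands.llm.llm_router import LLMRouter

    # LLMRouter raises at construction time if no NotDiamond key is set.
    os.environ.setdefault("NOTDIAMOND_API_KEY", "nd-your-key")

    config = LLMConfig(
        model="gpt-4o",  # illustrative default model
        llm_router_enabled=True,
        llm_providers=["openai/gpt-4o", "anthropic/claude-3-5-sonnet-20240620"],
    )
    router = LLMRouter(config)

    # NotDiamond picks the best provider for the prompt; the call is then
    # delegated to a plain LLM instance built around the chosen model.
    response, latency = router.complete([Message(role="user", content="Hello")])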
openhands/llm/llm_router.py (new file)
@@ -0,0 +1,65 @@
import os
from typing import Any, List, Tuple

# Imported at module level so tests can patch
# 'openhands.llm.llm_router.NotDiamond'.
from notdiamond import NotDiamond

from openhands.core.config import LLMConfig
from openhands.core.message import Message
from openhands.core.metrics import Metrics
from openhands.llm.llm import LLM


class LLMRouter(LLM):
    """LLMRouter class that selects the best LLM for a given query."""

    def __init__(
        self,
        config: LLMConfig,
        metrics: Metrics | None = None,
    ):
        super().__init__(config, metrics)
        self.llm_providers: List[str] = config.llm_providers
        self.notdiamond_api_key = os.environ.get("NOTDIAMOND_API_KEY")
        if not self.notdiamond_api_key:
            raise ValueError("NOTDIAMOND_API_KEY environment variable is not set")

        self.client = NotDiamond()

    def _select_model(self, messages: List[Message]) -> Tuple[str, Any]:
        """Select the best model for the given messages."""
        formatted_messages = [
            {"role": msg.role, "content": msg.content} for msg in messages
        ]
        session_id, provider = self.client.chat.completions.model_select(
            messages=formatted_messages,
            model=self.llm_providers,
        )
        return provider.model, session_id

    def complete(
        self,
        messages: List[Message],
        **kwargs: Any,
    ) -> Tuple[str, float]:
        """Complete the given messages using the best selected model."""
        selected_model, _session_id = self._select_model(messages)

        # Create a new LLM instance with the selected model
        selected_config = LLMConfig(model=selected_model)
        selected_llm = LLM(config=selected_config, metrics=self.metrics)

        # Use the selected LLM to complete the messages
        response, latency = selected_llm.complete(messages, **kwargs)

        return response, latency

    def stream(
        self,
        messages: List[Message],
        **kwargs: Any,
    ):
        """Stream the response using the best selected model."""
        selected_model, _session_id = self._select_model(messages)

        # Create a new LLM instance with the selected model
        selected_config = LLMConfig(model=selected_model)
        selected_llm = LLM(config=selected_config, metrics=self.metrics)

        # Use the selected LLM to stream the response
        yield from selected_llm.stream(messages, **kwargs)
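For reference, the model_select call used in _select_model follows the NotDiamond Python SDK shape seen above: it takes OpenAI-style message dicts plus a list of candidate models and returns a session id together with a provider object whose .model attribute names the chosen model. A minimal standalone sketch (model identifiers are illustrative; the SDK is assumed to read NOTDIAMOND_API_KEY from the environment):

    from notdiamond import NotDiamond

    client = NotDiamond()
    session_id, provider = client.chat.completions.model_select(
        messages=[{"role": "user", "content": "Write a haiku about routing."}],
        model=["openai/gpt-4o", "anthropic/claude-3-5-sonnet-20240620"],
    )
    # provider.model is the winning model; session_id identifies the routing call.
    print(provider.model, session_id)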
Unit tests for the router (new file)
@@ -0,0 +1,54 @@
import pytest
from unittest.mock import Mock, patch

from openhands.core.config import LLMConfig
from openhands.core.message import Message
from openhands.llm.llm import LLM
from openhands.llm.llm_router import LLMRouter


@pytest.fixture
def mock_notdiamond():
    # Patch the name where it is looked up: the router module, not the SDK.
    with patch('openhands.llm.llm_router.NotDiamond') as mock:
        yield mock


def test_llm_router_enabled(mock_notdiamond):
    config = LLMConfig(
        model="test-model",
        llm_router_enabled=True,
        llm_providers=["model1", "model2"],
    )
    llm = LLM(config)

    assert isinstance(llm.router, LLMRouter)

    messages = [Message(role="user", content="Hello")]
    # A plain Mock does not support subscripting, so build the choices list
    # explicitly instead of assigning through mock_response.choices[0].
    mock_response = Mock()
    mock_response.choices = [Mock(message=Mock(content="Hello, how can I help you?"))]
    llm.router.complete = Mock(return_value=(mock_response, 0.5))

    response, latency = llm.complete(messages)

    assert response == "Hello, how can I help you?"
    assert isinstance(latency, float)
    llm.router.complete.assert_called_once_with(messages)


def test_llm_router_disabled():
    config = LLMConfig(
        model="test-model",
        llm_router_enabled=False,
    )
    llm = LLM(config)

    assert llm.router is None

    messages = [Message(role="user", content="Hello")]
    with patch.object(llm, '_completion') as mock_completion:
        mock_response = Mock()
        mock_response.choices = [Mock(message=Mock(content="Hello, how can I help you?"))]
        mock_completion.return_value = mock_response

        response, latency = llm.complete(messages)

    assert response == "Hello, how can I help you?"
    assert isinstance(latency, float)
    mock_completion.assert_called_once()
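One detail worth calling out in the fixture above: unittest.mock.patch replaces a name where it is looked up, not where it is defined. The target 'openhands.llm.llm_router.NotDiamond' therefore only takes effect because the router module imports NotDiamond at module level; a lazy import inside __init__ would bind a fresh local name the patch never touches. A minimal illustration of configuring the patched client (the return values here are hypothetical):

    from unittest.mock import Mock, patch

    # Patch the importing module, not 'notdiamond.NotDiamond' itself.
    with patch('openhands.llm.llm_router.NotDiamond') as mock_nd:
        mock_nd.return_value.chat.completions.model_select.return_value = (
            "session-123",         # hypothetical session id
            Mock(model="model1"),  # hypothetical provider object
        )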