Fix issue #4184: '[LLM] Support LLM routing through notdiamond'
openhands-agent committed Oct 3, 2024
1 parent 5c31fd9 commit 4fb3b0d
Showing 4 changed files with 164 additions and 1 deletion.
2 changes: 2 additions & 0 deletions openhands/core/schema/config.py
@@ -47,3 +47,5 @@ class ConfigType(str, Enum):
    WORKSPACE_MOUNT_PATH = 'WORKSPACE_MOUNT_PATH'
    WORKSPACE_MOUNT_PATH_IN_SANDBOX = 'WORKSPACE_MOUNT_PATH_IN_SANDBOX'
    WORKSPACE_MOUNT_REWRITE = 'WORKSPACE_MOUNT_REWRITE'
    LLM_PROVIDERS = 'LLM_PROVIDERS'
    LLM_ROUTER_ENABLED = 'LLM_ROUTER_ENABLED'
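
The two new ConfigType entries correspond to router settings read from LLMConfig in the code below. A minimal sketch of switching the feature on, assuming LLMConfig gained matching llm_router_enabled and llm_providers fields (the 'provider/model' strings follow NotDiamond's naming convention and are illustrative only):

from openhands.core.config import LLMConfig

# Hypothetical configuration enabling NotDiamond routing.
config = LLMConfig(
    model='gpt-4o',  # default model, used when routing is disabled
    llm_router_enabled=True,
    llm_providers=['openai/gpt-4o', 'anthropic/claude-3-5-sonnet-20240620'],
)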
44 changes: 43 additions & 1 deletion openhands/llm/llm.py
@@ -2,7 +2,7 @@
import time
import warnings
from functools import partial
from typing import Any
from typing import Any, List, Tuple

from openhands.core.config import LLMConfig

@@ -26,6 +26,7 @@
from openhands.core.metrics import Metrics
from openhands.llm.debug_mixin import DebugMixin
from openhands.llm.retry_mixin import RetryMixin
from openhands.core.message import Message

# NOTE: LLMRouter is imported lazily inside LLM.__init__ to avoid a circular
# import (llm_router imports LLM for subclassing).

__all__ = ['LLM']

@@ -77,6 +78,11 @@ def __init__(
        # list of LLM completions (for logging purposes). Each completion is a dict with the following keys:
        # - 'messages': list of messages
        # - 'response': response from the LLM

        if self.config.llm_router_enabled:
            # Imported here rather than at module level to avoid a circular
            # import: llm_router imports LLM for subclassing.
            from openhands.llm.llm_router import LLMRouter

            # LLMRouter subclasses LLM, so skip router creation while the
            # router itself is being constructed to avoid infinite recursion.
            self.router = None if isinstance(self, LLMRouter) else LLMRouter(config, metrics)
        else:
            self.router = None
        self.llm_completions: list[dict[str, Any]] = []

        # litellm actually uses base Exception here for unknown model
@@ -123,6 +129,7 @@ def __init__(
            litellm_completion,
            model=self.config.model,
            api_key=self.config.api_key,
            base_url=self.config.base_url,
            api_version=self.config.api_version,
            custom_llm_provider=self.config.custom_llm_provider,
@@ -173,6 +180,7 @@ def wrapper(*args, **kwargs):
            if not messages:
                raise ValueError(
                    'The messages list is empty. At least one message is required.'
                )

            # log the entire LLM prompt
@@ -211,6 +219,40 @@ def wrapper(*args, **kwargs):

        self._completion = wrapper


    def complete(
        self,
        messages: List[Message],
        **kwargs: Any,
    ) -> Tuple[str, float]:
        """Complete the given messages using the best selected model or the default model."""
        start_time = time.time()

        if self.router:
            # The router returns the completion text directly, along with the
            # latency of the routed call.
            content, _ = self.router.complete(messages, **kwargs)
        else:
            response = self._completion(
                messages=[{"role": msg.role, "content": msg.content} for msg in messages],
                **kwargs,
            )
            content = response.choices[0].message.content

        latency = time.time() - start_time
        return content, latency

    def stream(
        self,
        messages: List[Message],
        **kwargs: Any,
    ):
        """Stream the response using the best selected model or the default model."""
        if self.router:
            yield from self.router.stream(messages, **kwargs)
        else:
            yield from self._completion(
                messages=[{"role": msg.role, "content": msg.content} for msg in messages],
                stream=True,
                **kwargs,
            )

    @property
    def completion(self):
        """Decorator for the litellm completion function.
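Taken together, complete() and stream() give callers one entry point that routes transparently when the router is configured. A usage sketch, assuming the hypothetical config fields above and that Message is constructed as in the unit tests below:

from openhands.core.config import LLMConfig
from openhands.core.message import Message
from openhands.llm.llm import LLM

config = LLMConfig(model='gpt-4o', llm_router_enabled=True, llm_providers=['openai/gpt-4o'])
llm = LLM(config)  # the router is created automatically from the config

content, latency = llm.complete([Message(role='user', content='Summarize this repo.')])
print(f'{latency:.2f}s: {content}')

# Streaming follows the same shape; the non-routed path yields raw litellm chunks.
for chunk in llm.stream([Message(role='user', content='Hello')]):
    print(chunk)
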
65 changes: 65 additions & 0 deletions openhands/llm/llm_router.py
@@ -0,0 +1,65 @@

import os
from typing import Any, List, Tuple

# NotDiamond is imported at module level so tests can patch
# openhands.llm.llm_router.NotDiamond.
from notdiamond import NotDiamond

from openhands.core.config import LLMConfig
from openhands.core.message import Message
from openhands.core.metrics import Metrics
from openhands.llm.llm import LLM


class LLMRouter(LLM):
    """LLMRouter class that selects the best LLM for a given query."""

    def __init__(
        self,
        config: LLMConfig,
        metrics: Metrics | None = None,
    ):
        super().__init__(config, metrics)
        self.llm_providers: List[str] = config.llm_providers
        self.notdiamond_api_key = os.environ.get("NOTDIAMOND_API_KEY")
        if not self.notdiamond_api_key:
            raise ValueError("NOTDIAMOND_API_KEY environment variable is not set")

        self.client = NotDiamond()

    def _select_model(self, messages: List[Message]) -> Tuple[str, Any]:
        """Select the best model for the given messages."""
        formatted_messages = [{"role": msg.role, "content": msg.content} for msg in messages]
        session_id, provider = self.client.chat.completions.model_select(
            messages=formatted_messages,
            model=self.llm_providers,
        )
        return provider.model, session_id

    def complete(
        self,
        messages: List[Message],
        **kwargs: Any,
    ) -> Tuple[str, float]:
        """Complete the given messages using the best selected model."""
        selected_model, session_id = self._select_model(messages)

        # Create a new LLM instance with the selected model
        selected_config = LLMConfig(model=selected_model)
        selected_llm = LLM(config=selected_config, metrics=self.metrics)

        # Use the selected LLM to complete the messages
        response, latency = selected_llm.complete(messages, **kwargs)

        return response, latency

    def stream(
        self,
        messages: List[Message],
        **kwargs: Any,
    ):
        """Stream the response using the best selected model."""
        selected_model, session_id = self._select_model(messages)

        # Create a new LLM instance with the selected model
        selected_config = LLMConfig(model=selected_model)
        selected_llm = LLM(config=selected_config, metrics=self.metrics)

        # Use the selected LLM to stream the response
        yield from selected_llm.stream(messages, **kwargs)
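
For reference, the NotDiamond selection step in _select_model can be exercised on its own. A standalone sketch, assuming NOTDIAMOND_API_KEY is exported and that model_select returns a (session_id, provider) pair whose provider exposes a .model attribute, as the code above relies on:

import os

from notdiamond import NotDiamond

assert os.environ.get('NOTDIAMOND_API_KEY'), 'export NOTDIAMOND_API_KEY first'

client = NotDiamond()
session_id, provider = client.chat.completions.model_select(
    messages=[{'role': 'user', 'content': 'Write a haiku about routing.'}],
    model=['openai/gpt-4o', 'anthropic/claude-3-5-sonnet-20240620'],
)
print(session_id, provider.model)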
54 changes: 54 additions & 0 deletions tests/unit/test_llm_router.py
@@ -0,0 +1,54 @@

import pytest
from unittest.mock import MagicMock, Mock, patch

from openhands.core.config import LLMConfig
from openhands.core.message import Message
from openhands.llm.llm import LLM
from openhands.llm.llm_router import LLMRouter


@pytest.fixture
def mock_notdiamond():
    with patch('openhands.llm.llm_router.NotDiamond') as mock:
        yield mock

def test_llm_router_enabled(mock_notdiamond, monkeypatch):
    # LLMRouter.__init__ requires this env var, so fake it for the test.
    monkeypatch.setenv("NOTDIAMOND_API_KEY", "test-key")
    config = LLMConfig(
        model="test-model",
        llm_router_enabled=True,
        llm_providers=["model1", "model2"],
    )
    llm = LLM(config)

    assert isinstance(llm.router, LLMRouter)

    messages = [Message(role="user", content="Hello")]
    # The router returns the completion text directly, plus its own latency.
    llm.router.complete = Mock(return_value=("Hello, how can I help you?", 0.5))

    response, latency = llm.complete(messages)

    assert response == "Hello, how can I help you?"
    assert isinstance(latency, float)
    llm.router.complete.assert_called_once_with(messages)

def test_llm_router_disabled():
    config = LLMConfig(
        model="test-model",
        llm_router_enabled=False,
    )
    llm = LLM(config)

    assert llm.router is None

    messages = [Message(role="user", content="Hello")]
    with patch.object(llm, '_completion') as mock_completion:
        # MagicMock (rather than Mock) is needed so that the
        # response.choices[0] indexing in LLM.complete works.
        mock_response = MagicMock()
        mock_response.choices[0].message.content = "Hello, how can I help you?"
        mock_completion.return_value = mock_response

        response, latency = llm.complete(messages)

        assert response == "Hello, how can I help you?"
        assert isinstance(latency, float)
        mock_completion.assert_called_once()
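
To run just these tests from the repository root, an equivalent of `pytest tests/unit/test_llm_router.py -v` can be invoked programmatically:

import pytest

raise SystemExit(pytest.main(['tests/unit/test_llm_router.py', '-v']))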
