Skip to content

Commit

Permalink
Simplify chat model using LiteLLM
Browse files Browse the repository at this point in the history
- Use LiteLLM for provider-agnostic interface
- Handle OpenAI 'developer' role conversion
- Simple error handling with retryable errors
- Add comprehensive tests
  • Loading branch information
openhands-agent committed Dec 26, 2024
1 parent 0c3352d commit c42deb6
Show file tree
Hide file tree
Showing 3 changed files with 307 additions and 0 deletions.
94 changes: 94 additions & 0 deletions src/wandbot/chat/chat_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""Chat model implementation using LiteLLM."""
from typing import List, Dict, Any
import litellm

class ChatModel:
"""Chat model using LiteLLM for provider-agnostic interface."""

def __init__(
self,
model_name: str,
temperature: float = 0.1,
):
"""Initialize chat model.
Args:
model_name: Name of the model to use, in format "provider/model"
e.g., "openai/gpt-4", "anthropic/claude-3", "gemini/gemini-pro"
temperature: Sampling temperature between 0 and 1
"""
self.model_name = model_name
self.temperature = temperature

# Configure LiteLLM
litellm.drop_params = True # Remove unsupported params
litellm.set_verbose = False

def generate_response(
self,
messages: List[Dict[str, str]],
max_tokens: int = 1000,
) -> Dict[str, Any]:
"""Generate a response from the model.
Args:
messages: List of message dictionaries with 'role' and 'content' keys
max_tokens: Maximum number of tokens to generate
Returns:
Dictionary containing:
- content: The generated response text
- total_tokens: Total tokens used
- prompt_tokens: Tokens used in the prompt
- completion_tokens: Tokens used in the completion
- error: Error information if request failed
- model_used: Name of the model that generated the response
"""
# Convert messages for OpenAI-style models
if "openai" in self.model_name:
messages = [
{
"role": "developer" if msg["role"] == "system" else msg["role"],
"content": msg["content"]
}
for msg in messages
]

try:
# Generate response
response = litellm.completion(
model=self.model_name,
messages=messages,
max_tokens=max_tokens,
temperature=self.temperature,
)

return {
"content": response.choices[0].message.content,
"total_tokens": response.usage.total_tokens,
"prompt_tokens": response.usage.prompt_tokens,
"completion_tokens": response.usage.completion_tokens,
"error": None,
"model_used": response.model
}

except Exception as e:
# Determine if error is retryable
error_msg = str(e).lower()
retryable = any(
err_type in error_msg
for err_type in ["timeout", "rate limit", "server", "connection"]
)

return {
"content": "",
"total_tokens": 0,
"prompt_tokens": 0,
"completion_tokens": 0,
"error": {
"type": type(e).__name__,
"message": str(e),
"retryable": retryable
},
"model_used": self.model_name
}
88 changes: 88 additions & 0 deletions tests/test_chat_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import unittest
from unittest.mock import patch, MagicMock

from wandbot.chat.chat_model import ChatModel

class TestChatModel(unittest.TestCase):
def setUp(self):
self.model = ChatModel(model_name="openai/gpt-4")
self.test_messages = [
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": "Hello"}
]

def test_openai_role_conversion(self):
"""Test that system role is converted to developer for OpenAI models."""
with patch('litellm.completion') as mock_completion:
mock_response = MagicMock()
mock_response.choices = [
MagicMock(message=MagicMock(content="Hi!"))
]
mock_response.usage = MagicMock(
total_tokens=10,
prompt_tokens=8,
completion_tokens=2
)
mock_response.model = "openai/gpt-4"
mock_completion.return_value = mock_response

self.model.generate_response(self.test_messages)

# Verify system role was converted to developer
call_args = mock_completion.call_args[1]
self.assertEqual(call_args["messages"][0]["role"], "developer")
self.assertEqual(call_args["messages"][1]["role"], "user")

def test_error_handling(self):
"""Test error handling."""
with patch('litellm.completion') as mock_completion:
# Test retryable error
mock_completion.side_effect = Exception("Rate limit exceeded")
response = self.model.generate_response(self.test_messages)
self.assertTrue(response["error"]["retryable"])

# Test non-retryable error
mock_completion.side_effect = Exception("Invalid API key")
response = self.model.generate_response(self.test_messages)
self.assertFalse(response["error"]["retryable"])

# Test server error
mock_completion.side_effect = Exception("Internal server error")
response = self.model.generate_response(self.test_messages)
self.assertTrue(response["error"]["retryable"])

# Verify error response format
self.assertEqual(response["content"], "")
self.assertEqual(response["total_tokens"], 0)
self.assertEqual(response["prompt_tokens"], 0)
self.assertEqual(response["completion_tokens"], 0)
self.assertEqual(response["model_used"], "openai/gpt-4")
self.assertIsNotNone(response["error"])

def test_successful_response(self):
"""Test successful response handling."""
with patch('litellm.completion') as mock_completion:
mock_response = MagicMock()
mock_response.choices = [
MagicMock(message=MagicMock(content="Hello! How can I help?"))
]
mock_response.usage = MagicMock(
total_tokens=15,
prompt_tokens=10,
completion_tokens=5
)
mock_response.model = "openai/gpt-4"
mock_completion.return_value = mock_response

response = self.model.generate_response(self.test_messages)

# Verify response format
self.assertEqual(response["content"], "Hello! How can I help?")
self.assertEqual(response["total_tokens"], 15)
self.assertEqual(response["prompt_tokens"], 10)
self.assertEqual(response["completion_tokens"], 5)
self.assertEqual(response["model_used"], "openai/gpt-4")
self.assertIsNone(response["error"])

if __name__ == '__main__':
unittest.main()
125 changes: 125 additions & 0 deletions tests/test_chat_model_new.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import unittest
from unittest.mock import patch, MagicMock
import litellm

from wandbot.chat.models import GeminiChatModel
from wandbot.chat.models.base import ModelError

class TestChatModel(unittest.TestCase):
def setUp(self):
self.model = GeminiChatModel(
model_name="openai/gpt-4",
fallback_models=["anthropic/claude-3", "gemini/gemini-pro"]
)
self.test_messages = [
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": "Hello"}
]

def test_openai_role_conversion(self):
"""Test that system role is converted to developer for OpenAI models."""
with patch('litellm.completion') as mock_completion:
mock_response = MagicMock()
mock_response.choices = [
MagicMock(message=MagicMock(content="Hi!"))
]
mock_response.usage = MagicMock(
total_tokens=10,
prompt_tokens=8,
completion_tokens=2
)
mock_response.model = "openai/gpt-4"
mock_completion.return_value = mock_response

self.model.generate_response(self.test_messages)

# Verify system role was converted to developer
call_args = mock_completion.call_args[1]
self.assertEqual(call_args["messages"][0]["role"], "developer")
self.assertEqual(call_args["messages"][1]["role"], "user")

def test_error_handling(self):
"""Test handling of various error types."""
error_cases = [
(litellm.exceptions.RateLimitError("Rate limit exceeded"), "rate_limit", True),
(litellm.exceptions.InvalidRequestError("Invalid request"), "invalid_request", False),
(litellm.exceptions.AuthenticationError("Invalid key"), "auth_error", False),
(litellm.exceptions.APIConnectionError("Connection failed"), "connection_error", True),
(litellm.exceptions.ContextLengthExceededError("Too long"), "context_length", False),
(litellm.exceptions.ServiceUnavailableError("Service down"), "service_unavailable", True),
]

for error, expected_type, retryable in error_cases:
with patch('litellm.completion') as mock_completion:
mock_completion.side_effect = error

response = self.model.generate_response(self.test_messages)

# Verify error response format
self.assertEqual(response["content"], "")
self.assertEqual(response["total_tokens"], 0)
self.assertEqual(response["prompt_tokens"], 0)
self.assertEqual(response["completion_tokens"], 0)
self.assertEqual(response["model_used"], "openai/gpt-4")

# Verify error details
self.assertIsNotNone(response["error"])
self.assertEqual(response["error"]["type"], expected_type)
self.assertEqual(response["error"]["retryable"], retryable)

def test_model_fallback(self):
"""Test fallback to backup model on error."""
with patch('litellm.completion') as mock_completion:
# Mock successful response from fallback model
mock_response = MagicMock()
mock_response.choices = [
MagicMock(message=MagicMock(content="Fallback response"))
]
mock_response.usage = MagicMock(
total_tokens=5,
prompt_tokens=3,
completion_tokens=2
)
mock_response.model = "anthropic/claude-3"

# Configure mock to fail first call and succeed second call
mock_completion.side_effect = [
litellm.exceptions.RateLimitError("Rate limit exceeded"), # Primary model fails
mock_response # Fallback model succeeds
]

response = self.model.generate_response(self.test_messages)

# Verify fallback was successful
self.assertEqual(response["content"], "Fallback response")
self.assertEqual(response["model_used"], "anthropic/claude-3")
self.assertIsNone(response["error"])
self.assertEqual(response["total_tokens"], 5)

def test_successful_response(self):
"""Test successful response handling."""
with patch('litellm.completion') as mock_completion:
mock_response = MagicMock()
mock_response.choices = [
MagicMock(message=MagicMock(content="Hello! How can I help?"))
]
mock_response.usage = MagicMock(
total_tokens=15,
prompt_tokens=10,
completion_tokens=5
)
mock_response.model = "openai/gpt-4"
mock_completion.return_value = mock_response

response = self.model.generate_response(self.test_messages)

# Verify response format
self.assertEqual(response["content"], "Hello! How can I help?")
self.assertEqual(response["total_tokens"], 15)
self.assertEqual(response["prompt_tokens"], 10)
self.assertEqual(response["completion_tokens"], 5)
self.assertEqual(response["model_used"], "openai/gpt-4")
self.assertIsNone(response["error"])

if __name__ == '__main__':
unittest.main()

0 comments on commit c42deb6

Please sign in to comment.