- Use LiteLLM for provider-agnostic interface
- Handle OpenAI 'developer' role conversion
- Simple error handling with retryable errors
- Add comprehensive tests
Commit c42deb6 (1 parent: 0c3352d)
Showing 3 changed files with 307 additions and 0 deletions.
@@ -0,0 +1,94 @@
"""Chat model implementation using LiteLLM."""
from typing import List, Dict, Any
import litellm


class ChatModel:
    """Chat model using LiteLLM for provider-agnostic interface."""

    def __init__(
        self,
        model_name: str,
        temperature: float = 0.1,
    ):
        """Initialize chat model.

        Args:
            model_name: Name of the model to use, in format "provider/model",
                e.g., "openai/gpt-4", "anthropic/claude-3", "gemini/gemini-pro"
            temperature: Sampling temperature between 0 and 1
        """
        self.model_name = model_name
        self.temperature = temperature

        # Configure LiteLLM
        litellm.drop_params = True  # Remove unsupported params
        litellm.set_verbose = False

    def generate_response(
        self,
        messages: List[Dict[str, str]],
        max_tokens: int = 1000,
    ) -> Dict[str, Any]:
        """Generate a response from the model.

        Args:
            messages: List of message dictionaries with 'role' and 'content' keys
            max_tokens: Maximum number of tokens to generate

        Returns:
            Dictionary containing:
                - content: The generated response text
                - total_tokens: Total tokens used
                - prompt_tokens: Tokens used in the prompt
                - completion_tokens: Tokens used in the completion
                - error: Error information if the request failed, else None
                - model_used: Name of the model that generated the response
        """
        # Convert messages for OpenAI-style models, which use the
        # 'developer' role in place of 'system'
        if "openai" in self.model_name:
            messages = [
                {
                    "role": "developer" if msg["role"] == "system" else msg["role"],
                    "content": msg["content"],
                }
                for msg in messages
            ]

        try:
            # Generate response
            response = litellm.completion(
                model=self.model_name,
                messages=messages,
                max_tokens=max_tokens,
                temperature=self.temperature,
            )

            return {
                "content": response.choices[0].message.content,
                "total_tokens": response.usage.total_tokens,
                "prompt_tokens": response.usage.prompt_tokens,
                "completion_tokens": response.usage.completion_tokens,
                "error": None,
                "model_used": response.model,
            }

        except Exception as e:
            # Determine if error is retryable by matching common
            # transient-failure keywords in the message
            error_msg = str(e).lower()
            retryable = any(
                err_type in error_msg
                for err_type in ["timeout", "rate limit", "server", "connection"]
            )

            return {
                "content": "",
                "total_tokens": 0,
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "error": {
                    "type": type(e).__name__,
                    "message": str(e),
                    "retryable": retryable,
                },
                "model_used": self.model_name,
            }
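Errors are returned in-band rather than raised, so callers branch on the error field and its retryable flag. A minimal usage sketch under that contract (the retry loop is illustrative and not part of this commit; the import path wandbot.chat.chat_model is taken from the test file below):

import time

from wandbot.chat.chat_model import ChatModel

model = ChatModel(model_name="openai/gpt-4", temperature=0.1)
messages = [
    {"role": "system", "content": "You are a helpful assistant"},
    {"role": "user", "content": "What is Weights & Biases?"},
]

# Retry only when the error is classified as transient; the backoff
# policy here is illustrative, not part of ChatModel itself.
for attempt in range(3):
    result = model.generate_response(messages, max_tokens=500)
    if result["error"] is None:
        print(result["content"], f"({result['total_tokens']} tokens)")
        break
    if not result["error"]["retryable"]:
        raise RuntimeError(result["error"]["message"])
    time.sleep(2 ** attempt)  # simple exponential backoff: 1s, 2s, 4s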
@@ -0,0 +1,88 @@
import unittest
from unittest.mock import patch, MagicMock

from wandbot.chat.chat_model import ChatModel


class TestChatModel(unittest.TestCase):
    def setUp(self):
        self.model = ChatModel(model_name="openai/gpt-4")
        self.test_messages = [
            {"role": "system", "content": "You are a helpful assistant"},
            {"role": "user", "content": "Hello"}
        ]

    def test_openai_role_conversion(self):
        """Test that system role is converted to developer for OpenAI models."""
        with patch('litellm.completion') as mock_completion:
            mock_response = MagicMock()
            mock_response.choices = [
                MagicMock(message=MagicMock(content="Hi!"))
            ]
            mock_response.usage = MagicMock(
                total_tokens=10,
                prompt_tokens=8,
                completion_tokens=2
            )
            mock_response.model = "openai/gpt-4"
            mock_completion.return_value = mock_response

            self.model.generate_response(self.test_messages)

            # Verify system role was converted to developer
            call_args = mock_completion.call_args[1]
            self.assertEqual(call_args["messages"][0]["role"], "developer")
            self.assertEqual(call_args["messages"][1]["role"], "user")

    def test_error_handling(self):
        """Test error handling."""
        with patch('litellm.completion') as mock_completion:
            # Test retryable error
            mock_completion.side_effect = Exception("Rate limit exceeded")
            response = self.model.generate_response(self.test_messages)
            self.assertTrue(response["error"]["retryable"])

            # Test non-retryable error
            mock_completion.side_effect = Exception("Invalid API key")
            response = self.model.generate_response(self.test_messages)
            self.assertFalse(response["error"]["retryable"])

            # Test server error
            mock_completion.side_effect = Exception("Internal server error")
            response = self.model.generate_response(self.test_messages)
            self.assertTrue(response["error"]["retryable"])

            # Verify error response format
            self.assertEqual(response["content"], "")
            self.assertEqual(response["total_tokens"], 0)
            self.assertEqual(response["prompt_tokens"], 0)
            self.assertEqual(response["completion_tokens"], 0)
            self.assertEqual(response["model_used"], "openai/gpt-4")
            self.assertIsNotNone(response["error"])

    def test_successful_response(self):
        """Test successful response handling."""
        with patch('litellm.completion') as mock_completion:
            mock_response = MagicMock()
            mock_response.choices = [
                MagicMock(message=MagicMock(content="Hello! How can I help?"))
            ]
            mock_response.usage = MagicMock(
                total_tokens=15,
                prompt_tokens=10,
                completion_tokens=5
            )
            mock_response.model = "openai/gpt-4"
            mock_completion.return_value = mock_response

            response = self.model.generate_response(self.test_messages)

            # Verify response format
            self.assertEqual(response["content"], "Hello! How can I help?")
            self.assertEqual(response["total_tokens"], 15)
            self.assertEqual(response["prompt_tokens"], 10)
            self.assertEqual(response["completion_tokens"], 5)
            self.assertEqual(response["model_used"], "openai/gpt-4")
            self.assertIsNone(response["error"])


if __name__ == '__main__':
    unittest.main()
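One detail worth noting in test_openai_role_conversion: mock_completion.call_args[1] indexes the keyword arguments of the most recent call. On Python 3.8+ the same kwargs are available by name, which reads a little clearer:

# Equivalent access to the patched call's keyword arguments (Python 3.8+)
call_kwargs = mock_completion.call_args.kwargs
self.assertEqual(call_kwargs["messages"][0]["role"], "developer")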
@@ -0,0 +1,125 @@
import unittest
from unittest.mock import patch, MagicMock
import litellm

from wandbot.chat.models import GeminiChatModel
from wandbot.chat.models.base import ModelError


class TestChatModel(unittest.TestCase):
    def setUp(self):
        self.model = GeminiChatModel(
            model_name="openai/gpt-4",
            fallback_models=["anthropic/claude-3", "gemini/gemini-pro"]
        )
        self.test_messages = [
            {"role": "system", "content": "You are a helpful assistant"},
            {"role": "user", "content": "Hello"}
        ]

    def test_openai_role_conversion(self):
        """Test that system role is converted to developer for OpenAI models."""
        with patch('litellm.completion') as mock_completion:
            mock_response = MagicMock()
            mock_response.choices = [
                MagicMock(message=MagicMock(content="Hi!"))
            ]
            mock_response.usage = MagicMock(
                total_tokens=10,
                prompt_tokens=8,
                completion_tokens=2
            )
            mock_response.model = "openai/gpt-4"
            mock_completion.return_value = mock_response

            self.model.generate_response(self.test_messages)

            # Verify system role was converted to developer
            call_args = mock_completion.call_args[1]
            self.assertEqual(call_args["messages"][0]["role"], "developer")
            self.assertEqual(call_args["messages"][1]["role"], "user")

    def test_error_handling(self):
        """Test handling of various error types."""
        # NOTE: litellm exception class names and constructor signatures vary
        # across versions; these cases assume the version pinned by the project.
        error_cases = [
            (litellm.exceptions.RateLimitError("Rate limit exceeded"), "rate_limit", True),
            (litellm.exceptions.InvalidRequestError("Invalid request"), "invalid_request", False),
            (litellm.exceptions.AuthenticationError("Invalid key"), "auth_error", False),
            (litellm.exceptions.APIConnectionError("Connection failed"), "connection_error", True),
            (litellm.exceptions.ContextLengthExceededError("Too long"), "context_length", False),
            (litellm.exceptions.ServiceUnavailableError("Service down"), "service_unavailable", True),
        ]

        for error, expected_type, retryable in error_cases:
            with patch('litellm.completion') as mock_completion:
                mock_completion.side_effect = error

                response = self.model.generate_response(self.test_messages)

                # Verify error response format
                self.assertEqual(response["content"], "")
                self.assertEqual(response["total_tokens"], 0)
                self.assertEqual(response["prompt_tokens"], 0)
                self.assertEqual(response["completion_tokens"], 0)
                self.assertEqual(response["model_used"], "openai/gpt-4")

                # Verify error details
                self.assertIsNotNone(response["error"])
                self.assertEqual(response["error"]["type"], expected_type)
                self.assertEqual(response["error"]["retryable"], retryable)

    def test_model_fallback(self):
        """Test fallback to backup model on error."""
        with patch('litellm.completion') as mock_completion:
            # Mock successful response from fallback model
            mock_response = MagicMock()
            mock_response.choices = [
                MagicMock(message=MagicMock(content="Fallback response"))
            ]
            mock_response.usage = MagicMock(
                total_tokens=5,
                prompt_tokens=3,
                completion_tokens=2
            )
            mock_response.model = "anthropic/claude-3"

            # Configure mock to fail the first call and succeed on the second
            mock_completion.side_effect = [
                litellm.exceptions.RateLimitError("Rate limit exceeded"),  # primary model fails
                mock_response  # fallback model succeeds
            ]

            response = self.model.generate_response(self.test_messages)

            # Verify fallback was successful
            self.assertEqual(response["content"], "Fallback response")
            self.assertEqual(response["model_used"], "anthropic/claude-3")
            self.assertIsNone(response["error"])
            self.assertEqual(response["total_tokens"], 5)

    def test_successful_response(self):
        """Test successful response handling."""
        with patch('litellm.completion') as mock_completion:
            mock_response = MagicMock()
            mock_response.choices = [
                MagicMock(message=MagicMock(content="Hello! How can I help?"))
            ]
            mock_response.usage = MagicMock(
                total_tokens=15,
                prompt_tokens=10,
                completion_tokens=5
            )
            mock_response.model = "openai/gpt-4"
            mock_completion.return_value = mock_response

            response = self.model.generate_response(self.test_messages)

            # Verify response format
            self.assertEqual(response["content"], "Hello! How can I help?")
            self.assertEqual(response["total_tokens"], 15)
            self.assertEqual(response["prompt_tokens"], 10)
            self.assertEqual(response["completion_tokens"], 5)
            self.assertEqual(response["model_used"], "openai/gpt-4")
            self.assertIsNone(response["error"])


if __name__ == '__main__':
    unittest.main()
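The GeminiChatModel these tests exercise (with its fallback_models chain) is imported from wandbot.chat.models, but its implementation is not part of this diff. Reconstructing from what test_model_fallback asserts, the fallback loop presumably tries the primary model first and moves down the chain only on retryable errors. A minimal sketch under that assumption (the function name and return shape are hypothetical; only the litellm calls and exception classes come from the tests above):

import litellm

def generate_with_fallback(models, messages, max_tokens=1000, temperature=0.1):
    """Try each model in order; advance only on transient failures.

    Hypothetical reconstruction of the behavior asserted in
    test_model_fallback, not the actual wandbot implementation.
    """
    last_error = None
    for model_name in models:  # primary model first, then each fallback
        try:
            response = litellm.completion(
                model=model_name,
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
            )
            return {
                "content": response.choices[0].message.content,
                "model_used": response.model,
                "error": None,
            }
        except (
            litellm.exceptions.RateLimitError,
            litellm.exceptions.ServiceUnavailableError,
            litellm.exceptions.APIConnectionError,
        ) as e:
            last_error = e  # transient: fall through to the next model
        # any other exception (auth, invalid request, ...) propagates
    raise RuntimeError(f"All models failed; last error: {last_error}")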