diff --git a/src/wandbot/chat/chat_model.py b/src/wandbot/chat/chat_model.py new file mode 100644 index 0000000..c88d06d --- /dev/null +++ b/src/wandbot/chat/chat_model.py @@ -0,0 +1,94 @@ +"""Chat model implementation using LiteLLM.""" +from typing import List, Dict, Any +import litellm + +class ChatModel: + """Chat model using LiteLLM for provider-agnostic interface.""" + + def __init__( + self, + model_name: str, + temperature: float = 0.1, + ): + """Initialize chat model. + + Args: + model_name: Name of the model to use, in format "provider/model" + e.g., "openai/gpt-4", "anthropic/claude-3", "gemini/gemini-pro" + temperature: Sampling temperature between 0 and 1 + """ + self.model_name = model_name + self.temperature = temperature + + # Configure LiteLLM + litellm.drop_params = True # Remove unsupported params + litellm.set_verbose = False + + def generate_response( + self, + messages: List[Dict[str, str]], + max_tokens: int = 1000, + ) -> Dict[str, Any]: + """Generate a response from the model. + + Args: + messages: List of message dictionaries with 'role' and 'content' keys + max_tokens: Maximum number of tokens to generate + + Returns: + Dictionary containing: + - content: The generated response text + - total_tokens: Total tokens used + - prompt_tokens: Tokens used in the prompt + - completion_tokens: Tokens used in the completion + - error: Error information if request failed + - model_used: Name of the model that generated the response + """ + # Convert messages for OpenAI-style models + if "openai" in self.model_name: + messages = [ + { + "role": "developer" if msg["role"] == "system" else msg["role"], + "content": msg["content"] + } + for msg in messages + ] + + try: + # Generate response + response = litellm.completion( + model=self.model_name, + messages=messages, + max_tokens=max_tokens, + temperature=self.temperature, + ) + + return { + "content": response.choices[0].message.content, + "total_tokens": response.usage.total_tokens, + "prompt_tokens": response.usage.prompt_tokens, + "completion_tokens": response.usage.completion_tokens, + "error": None, + "model_used": response.model + } + + except Exception as e: + # Determine if error is retryable + error_msg = str(e).lower() + retryable = any( + err_type in error_msg + for err_type in ["timeout", "rate limit", "server", "connection"] + ) + + return { + "content": "", + "total_tokens": 0, + "prompt_tokens": 0, + "completion_tokens": 0, + "error": { + "type": type(e).__name__, + "message": str(e), + "retryable": retryable + }, + "model_used": self.model_name + } \ No newline at end of file diff --git a/tests/test_chat_model.py b/tests/test_chat_model.py new file mode 100644 index 0000000..11830dd --- /dev/null +++ b/tests/test_chat_model.py @@ -0,0 +1,88 @@ +import unittest +from unittest.mock import patch, MagicMock + +from wandbot.chat.chat_model import ChatModel + +class TestChatModel(unittest.TestCase): + def setUp(self): + self.model = ChatModel(model_name="openai/gpt-4") + self.test_messages = [ + {"role": "system", "content": "You are a helpful assistant"}, + {"role": "user", "content": "Hello"} + ] + + def test_openai_role_conversion(self): + """Test that system role is converted to developer for OpenAI models.""" + with patch('litellm.completion') as mock_completion: + mock_response = MagicMock() + mock_response.choices = [ + MagicMock(message=MagicMock(content="Hi!")) + ] + mock_response.usage = MagicMock( + total_tokens=10, + prompt_tokens=8, + completion_tokens=2 + ) + mock_response.model = "openai/gpt-4" + mock_completion.return_value = mock_response + + self.model.generate_response(self.test_messages) + + # Verify system role was converted to developer + call_args = mock_completion.call_args[1] + self.assertEqual(call_args["messages"][0]["role"], "developer") + self.assertEqual(call_args["messages"][1]["role"], "user") + + def test_error_handling(self): + """Test error handling.""" + with patch('litellm.completion') as mock_completion: + # Test retryable error + mock_completion.side_effect = Exception("Rate limit exceeded") + response = self.model.generate_response(self.test_messages) + self.assertTrue(response["error"]["retryable"]) + + # Test non-retryable error + mock_completion.side_effect = Exception("Invalid API key") + response = self.model.generate_response(self.test_messages) + self.assertFalse(response["error"]["retryable"]) + + # Test server error + mock_completion.side_effect = Exception("Internal server error") + response = self.model.generate_response(self.test_messages) + self.assertTrue(response["error"]["retryable"]) + + # Verify error response format + self.assertEqual(response["content"], "") + self.assertEqual(response["total_tokens"], 0) + self.assertEqual(response["prompt_tokens"], 0) + self.assertEqual(response["completion_tokens"], 0) + self.assertEqual(response["model_used"], "openai/gpt-4") + self.assertIsNotNone(response["error"]) + + def test_successful_response(self): + """Test successful response handling.""" + with patch('litellm.completion') as mock_completion: + mock_response = MagicMock() + mock_response.choices = [ + MagicMock(message=MagicMock(content="Hello! How can I help?")) + ] + mock_response.usage = MagicMock( + total_tokens=15, + prompt_tokens=10, + completion_tokens=5 + ) + mock_response.model = "openai/gpt-4" + mock_completion.return_value = mock_response + + response = self.model.generate_response(self.test_messages) + + # Verify response format + self.assertEqual(response["content"], "Hello! How can I help?") + self.assertEqual(response["total_tokens"], 15) + self.assertEqual(response["prompt_tokens"], 10) + self.assertEqual(response["completion_tokens"], 5) + self.assertEqual(response["model_used"], "openai/gpt-4") + self.assertIsNone(response["error"]) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/test_chat_model_new.py b/tests/test_chat_model_new.py new file mode 100644 index 0000000..6b49f23 --- /dev/null +++ b/tests/test_chat_model_new.py @@ -0,0 +1,125 @@ +import unittest +from unittest.mock import patch, MagicMock +import litellm + +from wandbot.chat.models import GeminiChatModel +from wandbot.chat.models.base import ModelError + +class TestChatModel(unittest.TestCase): + def setUp(self): + self.model = GeminiChatModel( + model_name="openai/gpt-4", + fallback_models=["anthropic/claude-3", "gemini/gemini-pro"] + ) + self.test_messages = [ + {"role": "system", "content": "You are a helpful assistant"}, + {"role": "user", "content": "Hello"} + ] + + def test_openai_role_conversion(self): + """Test that system role is converted to developer for OpenAI models.""" + with patch('litellm.completion') as mock_completion: + mock_response = MagicMock() + mock_response.choices = [ + MagicMock(message=MagicMock(content="Hi!")) + ] + mock_response.usage = MagicMock( + total_tokens=10, + prompt_tokens=8, + completion_tokens=2 + ) + mock_response.model = "openai/gpt-4" + mock_completion.return_value = mock_response + + self.model.generate_response(self.test_messages) + + # Verify system role was converted to developer + call_args = mock_completion.call_args[1] + self.assertEqual(call_args["messages"][0]["role"], "developer") + self.assertEqual(call_args["messages"][1]["role"], "user") + + def test_error_handling(self): + """Test handling of various error types.""" + error_cases = [ + (litellm.exceptions.RateLimitError("Rate limit exceeded"), "rate_limit", True), + (litellm.exceptions.InvalidRequestError("Invalid request"), "invalid_request", False), + (litellm.exceptions.AuthenticationError("Invalid key"), "auth_error", False), + (litellm.exceptions.APIConnectionError("Connection failed"), "connection_error", True), + (litellm.exceptions.ContextLengthExceededError("Too long"), "context_length", False), + (litellm.exceptions.ServiceUnavailableError("Service down"), "service_unavailable", True), + ] + + for error, expected_type, retryable in error_cases: + with patch('litellm.completion') as mock_completion: + mock_completion.side_effect = error + + response = self.model.generate_response(self.test_messages) + + # Verify error response format + self.assertEqual(response["content"], "") + self.assertEqual(response["total_tokens"], 0) + self.assertEqual(response["prompt_tokens"], 0) + self.assertEqual(response["completion_tokens"], 0) + self.assertEqual(response["model_used"], "openai/gpt-4") + + # Verify error details + self.assertIsNotNone(response["error"]) + self.assertEqual(response["error"]["type"], expected_type) + self.assertEqual(response["error"]["retryable"], retryable) + + def test_model_fallback(self): + """Test fallback to backup model on error.""" + with patch('litellm.completion') as mock_completion: + # Mock successful response from fallback model + mock_response = MagicMock() + mock_response.choices = [ + MagicMock(message=MagicMock(content="Fallback response")) + ] + mock_response.usage = MagicMock( + total_tokens=5, + prompt_tokens=3, + completion_tokens=2 + ) + mock_response.model = "anthropic/claude-3" + + # Configure mock to fail first call and succeed second call + mock_completion.side_effect = [ + litellm.exceptions.RateLimitError("Rate limit exceeded"), # Primary model fails + mock_response # Fallback model succeeds + ] + + response = self.model.generate_response(self.test_messages) + + # Verify fallback was successful + self.assertEqual(response["content"], "Fallback response") + self.assertEqual(response["model_used"], "anthropic/claude-3") + self.assertIsNone(response["error"]) + self.assertEqual(response["total_tokens"], 5) + + def test_successful_response(self): + """Test successful response handling.""" + with patch('litellm.completion') as mock_completion: + mock_response = MagicMock() + mock_response.choices = [ + MagicMock(message=MagicMock(content="Hello! How can I help?")) + ] + mock_response.usage = MagicMock( + total_tokens=15, + prompt_tokens=10, + completion_tokens=5 + ) + mock_response.model = "openai/gpt-4" + mock_completion.return_value = mock_response + + response = self.model.generate_response(self.test_messages) + + # Verify response format + self.assertEqual(response["content"], "Hello! How can I help?") + self.assertEqual(response["total_tokens"], 15) + self.assertEqual(response["prompt_tokens"], 10) + self.assertEqual(response["completion_tokens"], 5) + self.assertEqual(response["model_used"], "openai/gpt-4") + self.assertIsNone(response["error"]) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file