-
Notifications
You must be signed in to change notification settings - Fork 52
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Remove temp files - Update OpenAI model to use 'developer' role - Keep existing prompt logic
- Loading branch information
1 parent
0b36103
commit 0c3352d
Showing 11 changed files with 364 additions and 523 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,44 +1,104 @@ | ||
from typing import List, Dict, Any | ||
from typing import List, Dict, Any, Optional | ||
import anthropic | ||
from anthropic._types import NOT_GIVEN | ||
|
||
from anthropic import Anthropic | ||
|
||
from .base import ChatModel | ||
from .base import ChatModel, ModelError | ||
|
||
class AnthropicChatModel(ChatModel):
    """Anthropic Claude model implementation.

    Converts the generic ``{"role": ..., "content": ...}`` chat format into
    Anthropic's Messages API format: system prompts go through the dedicated
    ``system`` parameter, and user/assistant turns become text content blocks.
    """

    # Maps exception type -> (error type string, human-readable message, retryable?).
    # NOTE(review): this table is not consulted anywhere in this class's visible
    # code — presumably the ChatModel base class uses it when an exception
    # escapes _generate_response; confirm against base.py.
    ERROR_MAPPING = {
        anthropic.APIError: ("api_error", "Anthropic API error", True),
        anthropic.APIConnectionError: ("connection_error", "Connection to Anthropic failed", True),
        anthropic.APIResponseValidationError: ("validation_error", "Invalid API response", False),
        anthropic.APIStatusError: ("status_error", "API status error", True),
        anthropic.AuthenticationError: ("auth_error", "Invalid API key", False),
        anthropic.BadRequestError: ("invalid_request", "Invalid request parameters", False),
        anthropic.InternalServerError: ("server_error", "Anthropic server error", True),
        anthropic.NotFoundError: ("not_found", "Resource not found", False),
        anthropic.PermissionDeniedError: ("permission_denied", "Permission denied", False),
        anthropic.RateLimitError: ("rate_limit", "Rate limit exceeded", True),
        ValueError: ("invalid_input", "Invalid input parameters", False),
    }

    def __init__(
        self,
        model_name: str = "claude-3-opus-20240229",
        temperature: float = 0.1,
        fallback_model: Optional['ChatModel'] = None,
    ):
        """Create a client-backed Claude model, optionally with a fallback model."""
        super().__init__(model_name, temperature, fallback_model)
        self.client = anthropic.Anthropic()

    def _generate_response(
        self,
        messages: List[Dict[str, str]],
        max_tokens: int = 1000,
    ) -> Dict[str, Any]:
        """Generate a completion for *messages* via the Anthropic Messages API.

        Args:
            messages: Chat history; each item needs ``role`` (one of
                ``system``/``user``/``assistant``) and ``content``.
            max_tokens: Upper bound on generated tokens.

        Returns:
            Dict with ``content``, ``total_tokens``, ``prompt_tokens``,
            ``completion_tokens`` and ``error`` (``None`` on success).
            Validation failures are reported through
            ``self._create_error_response`` rather than raised.
        """
        # Validate input
        if not messages:
            return self._create_error_response(ModelError(
                type="invalid_input",
                message="No messages provided",
                retryable=False
            ))

        # Validate message format
        for msg in messages:
            if not isinstance(msg, dict) or "role" not in msg or "content" not in msg:
                return self._create_error_response(ModelError(
                    type="invalid_input",
                    message="Invalid message format",
                    retryable=False
                ))
            if msg["role"] not in ["system", "user", "assistant"]:
                return self._create_error_response(ModelError(
                    type="invalid_input",
                    message=f"Invalid role: {msg['role']}",
                    retryable=False
                ))

        # Collect ALL system messages, not just the first: the conversion loop
        # below skips every system message, so taking only the first (as the
        # previous `next(...)` did) silently dropped any later system content.
        system_parts = [msg["content"] for msg in messages if msg["role"] == "system"]
        system_msg = "\n\n".join(system_parts) if system_parts else None

        # Convert remaining messages to Anthropic's content-block format
        anthropic_messages = []
        for msg in messages:
            if msg["role"] == "system":
                continue  # Delivered via the `system` parameter instead
            anthropic_messages.append({
                "role": msg["role"],
                "content": [{"type": "text", "text": msg["content"]}]
            })

        # Create message with Anthropic's API
        response = self.client.messages.create(
            model=self.model_name,
            messages=anthropic_messages,
            system=system_msg if system_msg is not None else NOT_GIVEN,
            temperature=self.temperature,
            max_tokens=max_tokens,
        )

        # Guard against an empty content list before indexing
        content = response.content[0].text if response.content else ""

        return {
            "content": content,
            "total_tokens": response.usage.input_tokens + response.usage.output_tokens,
            "prompt_tokens": response.usage.input_tokens,
            "completion_tokens": response.usage.output_tokens,
            "error": None
        }

    @property
    def system_role_key(self) -> str:
        """Return the key used for system role in messages."""
        return "system"  # For compatibility with message format, though handled separately
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,33 +1,95 @@ | ||
from typing import List, Dict, Any | ||
from typing import List, Dict, Any, Optional | ||
from openai import OpenAI, OpenAIError, APIError, RateLimitError, APIConnectionError | ||
|
||
from openai import OpenAI | ||
|
||
from .base import ChatModel | ||
from .base import ChatModel, ModelError | ||
|
||
class OpenAIChatModel(ChatModel):
    """OpenAI chat model implementation.

    Translates the generic chat format into OpenAI's Chat Completions
    format, using the newer ``developer`` role in place of ``system``.
    """

    def __init__(
        self,
        model_name: str = "gpt-4-0125-preview",
        temperature: float = 0.1,
        fallback_model: Optional['ChatModel'] = None,
    ):
        """Create a client-backed OpenAI model, optionally with a fallback model."""
        super().__init__(model_name, temperature, fallback_model)
        self.client = OpenAI()

    def _map_error(self, error: Exception) -> ModelError:
        """Map OpenAI/validation errors to a standardized ModelError.

        Distinguishes the specific exception types this module imports
        (``RateLimitError``, ``APIConnectionError``) so the error type
        strings match AnthropicChatModel.ERROR_MAPPING; previously every
        OpenAI error collapsed into a generic retryable ``api_error``.
        """
        # Order matters: check subclasses before the OpenAIError base class.
        if isinstance(error, RateLimitError):
            return ModelError(
                type="rate_limit",
                message=str(error),
                retryable=True
            )
        elif isinstance(error, APIConnectionError):
            return ModelError(
                type="connection_error",
                message=str(error),
                retryable=True
            )
        elif isinstance(error, OpenAIError):
            return ModelError(
                type="api_error",
                message=str(error),
                retryable=True
            )
        elif isinstance(error, ValueError):
            return ModelError(
                type="invalid_input",
                message=str(error),
                retryable=False
            )
        else:
            return ModelError(
                type="unknown_error",
                message=str(error),
                retryable=True
            )

    def _generate_response(
        self,
        messages: List[Dict[str, str]],
        max_tokens: int = 1000,
    ) -> Dict[str, Any]:
        """Generate a completion for *messages* via OpenAI Chat Completions.

        Args:
            messages: Chat history; each item needs ``role`` (one of
                ``system``/``user``/``assistant``) and ``content``.
            max_tokens: Upper bound on generated tokens.

        Returns:
            Dict with ``content``, ``total_tokens``, ``prompt_tokens``,
            ``completion_tokens`` and ``error`` (``None`` on success).
            Validation failures are reported through
            ``self._create_error_response`` rather than raised.
        """
        # Validate input
        if not messages:
            return self._create_error_response(ModelError(
                type="invalid_input",
                message="No messages provided",
                retryable=False
            ))

        # Validate message format
        for msg in messages:
            if not isinstance(msg, dict) or "role" not in msg or "content" not in msg:
                return self._create_error_response(ModelError(
                    type="invalid_input",
                    message="Invalid message format",
                    retryable=False
                ))
            if msg["role"] not in ["system", "user", "assistant"]:
                return self._create_error_response(ModelError(
                    type="invalid_input",
                    message=f"Invalid role: {msg['role']}",
                    retryable=False
                ))

        # Convert messages to OpenAI format, using "developer" instead of "system"
        openai_messages = []
        for msg in messages:
            role = "developer" if msg["role"] == "system" else msg["role"]
            openai_messages.append({
                "role": role,
                "content": msg["content"]
            })

        # Create completion with OpenAI's API
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=openai_messages,
            temperature=self.temperature,
            max_tokens=max_tokens,
        )

        # NOTE(review): message.content can be None for tool-call responses;
        # no tools are requested here, so a plain string is expected — confirm.
        return {
            "content": response.choices[0].message.content,
            "total_tokens": response.usage.total_tokens,
            "prompt_tokens": response.usage.prompt_tokens,
            "completion_tokens": response.usage.completion_tokens,
            "error": None
        }

    @property
    def system_role_key(self) -> str:
        """Return the key used for system role in messages."""
        return "developer"  # OpenAI now uses "developer" instead of "system"
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.