Skip to content

Commit

Permalink
Update PR:
Browse files Browse the repository at this point in the history
- Remove temp files
- Update OpenAI model to use 'developer' role
- Keep existing prompt logic
  • Loading branch information
openhands-agent committed Dec 26, 2024
1 parent 0b36103 commit 0c3352d
Show file tree
Hide file tree
Showing 11 changed files with 364 additions and 523 deletions.
88 changes: 74 additions & 14 deletions src/wandbot/chat/models/anthropic_model.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,104 @@
from typing import List, Dict, Any
from typing import List, Dict, Any, Optional
import anthropic
from anthropic._types import NOT_GIVEN

from anthropic import Anthropic

from .base import ChatModel
from .base import ChatModel, ModelError

class AnthropicChatModel(ChatModel):
def __init__(self, model_name: str = "claude-3-opus-20240229", temperature: float = 0.1):
super().__init__(model_name, temperature)
self.client = Anthropic()
"""Anthropic Claude model implementation."""

ERROR_MAPPING = {
anthropic.APIError: ("api_error", "Anthropic API error", True),
anthropic.APIConnectionError: ("connection_error", "Connection to Anthropic failed", True),
anthropic.APIResponseValidationError: ("validation_error", "Invalid API response", False),
anthropic.APIStatusError: ("status_error", "API status error", True),
anthropic.AuthenticationError: ("auth_error", "Invalid API key", False),
anthropic.BadRequestError: ("invalid_request", "Invalid request parameters", False),
anthropic.InternalServerError: ("server_error", "Anthropic server error", True),
anthropic.NotFoundError: ("not_found", "Resource not found", False),
anthropic.PermissionDeniedError: ("permission_denied", "Permission denied", False),
anthropic.RateLimitError: ("rate_limit", "Rate limit exceeded", True),
ValueError: ("invalid_input", "Invalid input parameters", False),
}

def __init__(
self,
model_name: str = "claude-3-opus-20240229",
temperature: float = 0.1,
fallback_model: Optional['ChatModel'] = None,
):
super().__init__(model_name, temperature, fallback_model)
self.client = anthropic.Anthropic()

def generate_response(
def _generate_response(
self,
messages: List[Dict[str, str]],
max_tokens: int = 1000,
) -> Dict[str, Any]:
# Convert messages to Anthropic format
# Validate input
if not messages:
return self._create_error_response(ModelError(
type="invalid_input",
message="No messages provided",
retryable=False
))

# Validate message format
for msg in messages:
if not isinstance(msg, dict) or "role" not in msg or "content" not in msg:
return self._create_error_response(ModelError(
type="invalid_input",
message="Invalid message format",
retryable=False
))
if msg["role"] not in ["system", "user", "assistant"]:
return self._create_error_response(ModelError(
type="invalid_input",
message=f"Invalid role: {msg['role']}",
retryable=False
))

# Extract system message if present
system_msg = next((msg["content"] for msg in messages if msg["role"] == "system"), None)

# Convert remaining messages to Anthropic format
anthropic_messages = []
for msg in messages:
role = msg["role"]
if role == "system":
anthropic_messages.append({"role": "assistant", "content": msg["content"]})
continue # Handle separately
elif role == "user":
anthropic_messages.append({"role": "user", "content": msg["content"]})
anthropic_messages.append({
"role": "user",
"content": [{"type": "text", "text": msg["content"]}]
})
elif role == "assistant":
anthropic_messages.append({"role": "assistant", "content": msg["content"]})
anthropic_messages.append({
"role": "assistant",
"content": [{"type": "text", "text": msg["content"]}]
})

# Create message with Anthropic's API
response = self.client.messages.create(
model=self.model_name,
messages=anthropic_messages,
system=system_msg if system_msg is not None else NOT_GIVEN,
temperature=self.temperature,
max_tokens=max_tokens,
)

# Extract content from response
content = response.content[0].text if response.content else ""

return {
"content": response.content[0].text,
"content": content,
"total_tokens": response.usage.input_tokens + response.usage.output_tokens,
"prompt_tokens": response.usage.input_tokens,
"completion_tokens": response.usage.output_tokens,
"error": None
}

@property
def system_role_key(self) -> str:
return "system" # Will be converted to assistant role in generate_response
"""Return the key used for system role in messages."""
return "system" # For compatibility with message format, though handled separately
82 changes: 72 additions & 10 deletions src/wandbot/chat/models/openai_model.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,95 @@
from typing import List, Dict, Any
from typing import List, Dict, Any, Optional
from openai import OpenAI, OpenAIError, APIError, RateLimitError, APIConnectionError

from openai import OpenAI

from .base import ChatModel
from .base import ChatModel, ModelError

class OpenAIChatModel(ChatModel):
def __init__(self, model_name: str = "gpt-4-0125-preview", temperature: float = 0.1):
super().__init__(model_name, temperature)
"""OpenAI chat model implementation."""

def _map_error(self, error: Exception) -> ModelError:
"""Map OpenAI errors to standardized ModelError."""
if isinstance(error, OpenAIError):
return ModelError(
type="api_error",
message=str(error),
retryable=True
)
elif isinstance(error, ValueError):
return ModelError(
type="invalid_input",
message=str(error),
retryable=False
)
else:
return ModelError(
type="unknown_error",
message=str(error),
retryable=True
)

def __init__(
self,
model_name: str = "gpt-4-0125-preview",
temperature: float = 0.1,
fallback_model: Optional['ChatModel'] = None,
):
super().__init__(model_name, temperature, fallback_model)
self.client = OpenAI()

def generate_response(
def _generate_response(
self,
messages: List[Dict[str, str]],
max_tokens: int = 1000,
) -> Dict[str, Any]:
# Validate input
if not messages:
return self._create_error_response(ModelError(
type="invalid_input",
message="No messages provided",
retryable=False
))

# Validate message format
for msg in messages:
if not isinstance(msg, dict) or "role" not in msg or "content" not in msg:
return self._create_error_response(ModelError(
type="invalid_input",
message="Invalid message format",
retryable=False
))
if msg["role"] not in ["system", "user", "assistant"]:
return self._create_error_response(ModelError(
type="invalid_input",
message=f"Invalid role: {msg['role']}",
retryable=False
))

# Convert messages to OpenAI format, using "developer" instead of "system"
openai_messages = []
for msg in messages:
role = "developer" if msg["role"] == "system" else msg["role"]
openai_messages.append({
"role": role,
"content": msg["content"]
})

# Create completion with OpenAI's API
response = self.client.chat.completions.create(
model=self.model_name,
messages=messages,
messages=openai_messages,
temperature=self.temperature,
max_tokens=max_tokens,
)

return {
"content": response.choices[0].message.content,
"total_tokens": response.usage.total_tokens,
"prompt_tokens": response.usage.prompt_tokens,
"completion_tokens": response.usage.completion_tokens,
"error": None
}

@property
def system_role_key(self) -> str:
return "system"
"""Return the key used for system role in messages."""
return "developer" # OpenAI now uses "developer" instead of "system"
128 changes: 0 additions & 128 deletions temp_gemini_full_payload.py

This file was deleted.

28 changes: 0 additions & 28 deletions temp_gemini_metadata.py

This file was deleted.

Loading

0 comments on commit 0c3352d

Please sign in to comment.