Skip to content

Commit

Permalink
Add Claude 3 model (#2432)
Browse files Browse the repository at this point in the history
  • Loading branch information
yifanmai authored Mar 5, 2024
1 parent 3e2c6a1 commit 5b1cfdc
Show file tree
Hide file tree
Showing 5 changed files with 150 additions and 10 deletions.
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ install_requires=

# Basic Scenarios
datasets~=2.15
pyarrow>=11.0.0, # Pinned transitive dependency for datasets; workaround for #1026
pyarrow>=11.0.0 # Pinned transitive dependency for datasets; workaround for #1026
pyarrow-hotfix~=0.6 # Hotfix for CVE-2023-47248

# Basic metrics
Expand Down Expand Up @@ -121,7 +121,7 @@ amazon =
botocore~=1.31.57

anthropic =
anthropic~=0.2.5
anthropic~=0.17
websocket-client~=1.3.2 # For legacy stanford-online-all-v4-s3

mistral =
Expand Down
124 changes: 117 additions & 7 deletions src/helm/clients/anthropic_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, TypedDict, cast
import json
import requests
import time
Expand All @@ -20,16 +20,28 @@
TokenizationRequest,
TokenizationRequestResult,
)
from helm.proxy.retry import NonRetriableException
from helm.tokenizers.tokenizer import Tokenizer
from .client import CachingClient, truncate_sequence
from helm.clients.client import CachingClient, truncate_sequence, truncate_and_tokenize_response_text

try:
import anthropic
from anthropic import Anthropic
from anthropic.types import MessageParam
import websocket
except ModuleNotFoundError as e:
handle_module_not_found_error(e, ["anthropic"])


class AnthropicCompletionRequest(TypedDict):
    """Keyword arguments for the (legacy) Anthropic Text Completions API.

    Unpacked via ``**raw_request`` into ``self.client.completions.create()``
    in ``AnthropicClient._send_request``.
    """

    prompt: str  # full prompt text taken from Request.prompt
    stop_sequences: List[str]  # generation stops when any of these is produced
    model: str  # model engine name, e.g. "claude-2.1"
    max_tokens_to_sample: int  # upper bound on generated tokens
    temperature: float  # sampling temperature
    top_p: float  # nucleus-sampling probability mass
    top_k: int  # top-k sampling cutoff


class AnthropicClient(CachingClient):
"""
Client for the Anthropic models (https://arxiv.org/abs/2204.05862).
Expand Down Expand Up @@ -63,12 +75,12 @@ def __init__(
self.tokenizer = tokenizer
self.tokenizer_name = tokenizer_name
self.api_key: Optional[str] = api_key
self._client = anthropic.Client(api_key) if api_key else None
self.client = Anthropic(api_key=api_key)

def _send_request(self, raw_request: Dict[str, Any]) -> Dict[str, Any]:
def _send_request(self, raw_request: AnthropicCompletionRequest) -> Dict[str, Any]:
if self.api_key is None:
raise Exception("API key is not set. Please set it in the HELM config file.")
result = self._client.completion(**raw_request)
result = self.client.completions.create(**raw_request).model_dump()
assert "error" not in result, f"Request failed with error: {result['error']}"
return result

Expand Down Expand Up @@ -103,7 +115,7 @@ def make_request(self, request: Request) -> RequestResult:
if request.max_tokens == 0 and not request.echo_prompt:
raise ValueError("echo_prompt must be True when max_tokens=0.")

raw_request = {
raw_request: AnthropicCompletionRequest = {
"prompt": request.prompt,
"stop_sequences": request.stop_sequences,
"model": request.model_engine,
Expand Down Expand Up @@ -190,6 +202,104 @@ def do_it():
)


class AnthropicMessagesRequest(TypedDict, total=False):
    """Keyword arguments for the Anthropic Messages API (Claude 3 models).

    Unpacked via ``**raw_request`` into ``self.client.messages.create()``.
    ``total=False`` makes every key optional: ``system`` is only included
    when the conversation started with a system-role message.
    """

    messages: List[MessageParam]  # conversation turns (system turn removed)
    model: str  # model engine name, e.g. "claude-3-opus-20240229"
    stop_sequences: List[str]  # generation stops when any of these is produced
    system: str  # optional system prompt, sent separately from `messages`
    max_tokens: int  # upper bound on generated tokens (required by the API)
    temperature: float  # sampling temperature
    top_k: int  # top-k sampling cutoff
    top_p: float  # nucleus-sampling probability mass


class AnthropicMessagesRequestError(NonRetriableException):
    """Raised for malformed requests; as a NonRetriableException, the HELM
    retry machinery will not retry the request."""

    pass


class AnthropicMessagesResponseError(Exception):
    """Raised when the Messages API response has empty or non-text content.
    Unlike AnthropicMessagesRequestError, this is an ordinary Exception and
    may be retried by the caller's retry machinery."""

    pass


class AnthropicMessagesClient(CachingClient):
    """Client for Anthropic models served through the Messages API (Claude 3).

    Responses are cached per completion index. `num_completions` is emulated
    with one API request per completion, since the Messages API returns a
    single message per call.
    """

    # Source: https://docs.anthropic.com/claude/docs/models-overview
    MAX_OUTPUT_TOKENS = 4096

    def __init__(
        self, tokenizer: Tokenizer, tokenizer_name: str, cache_config: CacheConfig, api_key: Optional[str] = None
    ):
        super().__init__(cache_config=cache_config)
        self.tokenizer = tokenizer
        self.tokenizer_name = tokenizer_name
        self.client = Anthropic(api_key=api_key)
        self.api_key: Optional[str] = api_key

    def make_request(self, request: Request) -> RequestResult:
        """Send `request` to the Anthropic Messages API and return the result.

        Exactly one of `request.messages` and `request.prompt` may be set; a
        bare prompt is wrapped as a single user message. A leading system-role
        message is moved to the API's top-level `system` field.

        Raises:
            AnthropicMessagesRequestError: If `max_tokens` exceeds the model
                limit, both `messages` and `prompt` are set, or
                `num_completions` is less than one.
            AnthropicMessagesResponseError: If the API response has empty or
                non-text content (raised from inside the cached request).
        """
        if request.max_tokens > AnthropicMessagesClient.MAX_OUTPUT_TOKENS:
            raise AnthropicMessagesRequestError(
                f"Request.max_tokens must be <= {AnthropicMessagesClient.MAX_OUTPUT_TOKENS}"
            )
        # Fail fast: with zero completions the loop below would never run, and
        # `response`/`cached` would be unbound at the final `return`.
        if request.num_completions < 1:
            raise AnthropicMessagesRequestError("Request.num_completions must be >= 1")

        messages: List[MessageParam] = []
        system_message: Optional[MessageParam] = None
        if request.messages and request.prompt:
            raise AnthropicMessagesRequestError("Exactly one of Request.messages and Request.prompt should be set")
        if request.messages:
            messages = cast(List[MessageParam], request.messages)
            # The Messages API rejects a system-role turn inside `messages`;
            # it must be sent through the top-level `system` parameter.
            if messages[0]["role"] == "system":
                system_message = messages[0]
                messages = messages[1:]
        else:
            messages = [{"role": "user", "content": request.prompt}]

        raw_request: AnthropicMessagesRequest = {
            "messages": messages,
            "model": request.model_engine,
            "stop_sequences": request.stop_sequences,
            "max_tokens": request.max_tokens,
            "temperature": request.temperature,
            "top_p": request.top_p,
            "top_k": request.top_k_per_token,
        }
        if system_message is not None:
            # NOTE(review): assumes the system message content is a plain
            # string; MessageParam also permits a list of content blocks,
            # which this cast would silently mistype — confirm with callers.
            raw_request["system"] = cast(str, system_message["content"])
        completions: List[Sequence] = []

        # `num_completions` is not supported, so instead make `num_completions` separate requests.
        for completion_index in range(request.num_completions):

            def do_it() -> Dict[str, Any]:
                # Validate the response shape before caching/using it.
                result = self.client.messages.create(**raw_request).model_dump()
                if "content" not in result or not result["content"]:
                    raise AnthropicMessagesResponseError(f"Anthropic response has empty content: {result}")
                elif "text" not in result["content"][0]:
                    raise AnthropicMessagesResponseError(f"Anthropic response has non-text content: {result}")
                return result

            # Include the completion index so each completion caches separately.
            cache_key = CachingClient.make_cache_key(
                {
                    "completion_index": completion_index,
                    **raw_request,
                },
                request,
            )

            response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
            completion = truncate_and_tokenize_response_text(
                response["content"][0]["text"], request, self.tokenizer, self.tokenizer_name, original_finish_reason=""
            )
            completions.append(completion)

        # `response`/`cached` reflect the last completion's request.
        return RequestResult(
            success=True,
            cached=cached,
            request_time=response["request_time"],
            request_datetime=response["request_datetime"],
            completions=completions,
            embedding=[],
        )


class AnthropicRequestError(Exception):
    """Raised for failed requests against the Anthropic API."""

    pass

Expand Down
14 changes: 14 additions & 0 deletions src/helm/config/model_deployments.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,20 @@ model_deployments:
client_spec:
class_name: "helm.clients.anthropic_client.AnthropicClient"

- name: anthropic/claude-3-sonnet-20240229
model_name: anthropic/claude-3-sonnet-20240229
tokenizer_name: anthropic/claude
max_sequence_length: 200000
client_spec:
class_name: "helm.clients.anthropic_client.AnthropicMessagesClient"

- name: anthropic/claude-3-opus-20240229
model_name: anthropic/claude-3-opus-20240229
tokenizer_name: anthropic/claude
max_sequence_length: 200000
client_spec:
class_name: "helm.clients.anthropic_client.AnthropicMessagesClient"

- name: anthropic/stanford-online-all-v4-s3
deprecated: true # Closed model, not accessible via API
model_name: anthropic/stanford-online-all-v4-s3
Expand Down
16 changes: 16 additions & 0 deletions src/helm/config/model_metadata.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,22 @@ models:
release_date: 2023-11-21
tags: [ANTHROPIC_CLAUDE_2_MODEL_TAG, TEXT_MODEL_TAG, LIMITED_FUNCTIONALITY_TEXT_MODEL_TAG, ABLATION_MODEL_TAG, INSTRUCTION_FOLLOWING_MODEL_TAG]

- name: anthropic/claude-3-sonnet-20240229
display_name: Claude 3 Sonnet (20240229)
description: Claude 3 Sonnet is a model in Anthropic's Claude 3 family that balances intelligence and speed ([blog](https://www.anthropic.com/news/claude-3-family)).
creator_organization_name: Anthropic
access: limited
release_date: 2024-03-04
tags: [TEXT_MODEL_TAG, LIMITED_FUNCTIONALITY_TEXT_MODEL_TAG, INSTRUCTION_FOLLOWING_MODEL_TAG]

- name: anthropic/claude-3-opus-20240229
display_name: Claude 3 Opus (20240229)
description: Claude 3 Opus is the most capable model in Anthropic's Claude 3 family ([blog](https://www.anthropic.com/news/claude-3-family)).
creator_organization_name: Anthropic
access: limited
release_date: 2024-03-04
tags: [TEXT_MODEL_TAG, LIMITED_FUNCTIONALITY_TEXT_MODEL_TAG, INSTRUCTION_FOLLOWING_MODEL_TAG]

# DEPRECATED: Please do not use.
- name: anthropic/stanford-online-all-v4-s3
display_name: Anthropic-LM v4-s3 (52B)
Expand Down
2 changes: 1 addition & 1 deletion src/helm/tokenizers/anthropic_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(self, cache_config: CacheConfig) -> None:
super().__init__(cache_config)
with AnthropicTokenizer.LOCK:
self._tokenizer: PreTrainedTokenizerBase = PreTrainedTokenizerFast(
tokenizer_object=anthropic.get_tokenizer()
tokenizer_object=anthropic.Anthropic().get_tokenizer()
)

def _tokenize_do_it(self, request: Dict[str, Any]) -> Dict[str, Any]:
Expand Down

0 comments on commit 5b1cfdc

Please sign in to comment.