Skip to content

Commit

Permalink
Async support for Claude models
Browse files Browse the repository at this point in the history
  • Loading branch information
simonw committed Nov 14, 2024
1 parent 72566a9 commit 041386e
Show file tree
Hide file tree
Showing 4 changed files with 751 additions and 18 deletions.
87 changes: 69 additions & 18 deletions llm_claude_3.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from anthropic import Anthropic
from anthropic import Anthropic, AsyncAnthropic
import llm
from pydantic import Field, field_validator, model_validator
from typing import Optional, List
Expand All @@ -7,19 +7,42 @@
@llm.hookimpl
def register_models(register):
    """Register every supported Claude model with LLM.

    Each model is registered exactly once, as a pair: the synchronous
    implementation first, then its asynchronous counterpart, plus any
    aliases the model should also be reachable under.

    Args:
        register: Callback supplied by LLM; called here as
            ``register(sync_model, async_model, aliases=...)``.
    """
    # https://docs.anthropic.com/claude/docs/models-overview
    register(
        ClaudeMessages("claude-3-opus-20240229"),
        AsyncClaudeMessages("claude-3-opus-20240229"),
    )
    register(
        ClaudeMessages("claude-3-opus-latest"),
        AsyncClaudeMessages("claude-3-opus-latest"),
        aliases=("claude-3-opus",),
    )
    register(
        ClaudeMessages("claude-3-sonnet-20240229"),
        AsyncClaudeMessages("claude-3-sonnet-20240229"),
        aliases=("claude-3-sonnet",),
    )
    register(
        ClaudeMessages("claude-3-haiku-20240307"),
        AsyncClaudeMessages("claude-3-haiku-20240307"),
        aliases=("claude-3-haiku",),
    )
    # 3.5 models
    register(
        ClaudeMessagesLong("claude-3-5-sonnet-20240620"),
        AsyncClaudeMessagesLong("claude-3-5-sonnet-20240620"),
    )
    register(
        ClaudeMessagesLong("claude-3-5-sonnet-20241022", supports_pdf=True),
        AsyncClaudeMessagesLong("claude-3-5-sonnet-20241022", supports_pdf=True),
    )
    register(
        ClaudeMessagesLong("claude-3-5-sonnet-latest", supports_pdf=True),
        AsyncClaudeMessagesLong("claude-3-5-sonnet-latest", supports_pdf=True),
        aliases=("claude-3.5-sonnet", "claude-3.5-sonnet-latest"),
    )
    register(
        ClaudeMessagesLong("claude-3-5-haiku-latest", supports_images=False),
        AsyncClaudeMessagesLong("claude-3-5-haiku-latest", supports_images=False),
        aliases=("claude-3.5-haiku",),
    )

Expand Down Expand Up @@ -86,7 +109,13 @@ def validate_temperature_top_p(self):
return self


class ClaudeMessages(llm.Model):
# Shared pydantic Field for the ``max_tokens`` option of the "Long" model
# variants: it is assigned as the default of ``Options.max_tokens`` in those
# classes so the description and the 8192-token (4_096 * 2) default are
# defined in a single place.
long_field = Field(
description="The maximum number of tokens to generate before stopping",
default=4_096 * 2,
)


class _Shared:
needs_key = "claude"
key_env_var = "ANTHROPIC_API_KEY"
can_stream = True
Expand Down Expand Up @@ -178,9 +207,7 @@ def build_messages(self, prompt, conversation) -> List[dict]:
messages.append({"role": "user", "content": prompt.prompt})
return messages

def execute(self, prompt, stream, response, conversation):
client = Anthropic(api_key=self.get_key())

def build_kwargs(self, prompt, conversation):
kwargs = {
"model": self.claude_model_id,
"messages": self.build_messages(prompt, conversation),
Expand All @@ -202,7 +229,17 @@ def execute(self, prompt, stream, response, conversation):

if self.extra_headers:
kwargs["extra_headers"] = self.extra_headers
return kwargs

def __str__(self):
return "Anthropic Messages: {}".format(self.model_id)


class ClaudeMessages(_Shared, llm.Model):

def execute(self, prompt, stream, response, conversation):
client = Anthropic(api_key=self.get_key())
kwargs = self.build_kwargs(prompt, conversation)
if stream:
with client.messages.stream(**kwargs) as stream:
for text in stream.text_stream:
Expand All @@ -214,13 +251,27 @@ def execute(self, prompt, stream, response, conversation):
yield completion.content[0].text
response.response_json = completion.model_dump()

def __str__(self):
return "Anthropic Messages: {}".format(self.model_id)


class ClaudeMessagesLong(ClaudeMessages):
    """Synchronous Claude model whose ``max_tokens`` option uses ``long_field``."""

    class Options(ClaudeOptions):
        # Single definition of max_tokens: the shared module-level field
        # (description plus the 4_096 * 2 default) replaces the previous
        # inline Field(...), which was immediately overwritten anyway.
        max_tokens: Optional[int] = long_field


class AsyncClaudeMessages(_Shared, llm.AsyncModel):
    """Asynchronous Claude model backed by ``AsyncAnthropic``.

    Request construction is inherited from ``_Shared`` (``build_kwargs``);
    this class only adds the async execution path.
    """

    async def execute(self, prompt, stream, response, conversation):
        """Run the prompt and asynchronously yield the generated text.

        Args:
            prompt: Prompt object forwarded to ``build_kwargs``.
            stream: When truthy, yield text chunks as they arrive;
                otherwise yield the text of the first content block once.
            response: Response object; its ``response_json`` attribute is
                set to the ``model_dump()`` of the final message.
            conversation: Conversation forwarded to ``build_kwargs``.

        Yields:
            str: Pieces of generated text.
        """
        client = AsyncAnthropic(api_key=self.get_key())
        kwargs = self.build_kwargs(prompt, conversation)
        if stream:
            async with client.messages.stream(**kwargs) as stream_obj:
                async for text in stream_obj.text_stream:
                    yield text
                # Read the accumulated final message while the stream
                # context is still open, then record it on the response.
                final_message = await stream_obj.get_final_message()
                response.response_json = final_message.model_dump()
        else:
            completion = await client.messages.create(**kwargs)
            yield completion.content[0].text
            response.response_json = completion.model_dump()


class AsyncClaudeMessagesLong(AsyncClaudeMessages):
    """Asynchronous Claude model whose ``max_tokens`` option uses ``long_field``."""

    class Options(ClaudeOptions):
        # Default and description come from the shared module-level field.
        max_tokens: Optional[int] = long_field
Loading

0 comments on commit 041386e

Please sign in to comment.