diff --git a/llm_claude_3.py b/llm_claude_3.py index 40ba992..1b132ae 100644 --- a/llm_claude_3.py +++ b/llm_claude_3.py @@ -10,7 +10,17 @@ def register_models(register): register(ClaudeMessages("claude-3-opus-20240229"), aliases=("claude-3-opus",)) register(ClaudeMessages("claude-3-sonnet-20240229"), aliases=("claude-3-sonnet",)) register(ClaudeMessages("claude-3-haiku-20240307"), aliases=("claude-3-haiku",)) - register(ClaudeMessages("claude-3-5-sonnet-20240620"), aliases=("claude-3.5-sonnet",)) + register( + ClaudeMessages("claude-3-5-sonnet-20240620"), aliases=("claude-3.5-sonnet",) + ) + register( + ClaudeMessages( + "claude-3-5-sonnet-20240620-long", + claude_model_id="claude-3-5-sonnet-20240620", + extra_headers={"anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15"}, + ), + aliases=("claude-3.5-sonnet-long",), + ) class ClaudeOptions(llm.Options): @@ -81,8 +91,10 @@ class ClaudeMessages(llm.Model): class Options(ClaudeOptions): ... - def __init__(self, model_id): + def __init__(self, model_id, claude_model_id=None, extra_headers=None): self.model_id = model_id + self.claude_model_id = claude_model_id or model_id + self.extra_headers = extra_headers def build_messages(self, prompt, conversation) -> List[dict]: messages = [] @@ -104,7 +116,7 @@ def execute(self, prompt, stream, response, conversation): client = Anthropic(api_key=self.get_key()) kwargs = { - "model": self.model_id, + "model": self.claude_model_id, "messages": self.build_messages(prompt, conversation), "max_tokens": prompt.options.max_tokens, } @@ -122,7 +134,9 @@ def execute(self, prompt, stream, response, conversation): if prompt.system: kwargs["system"] = prompt.system - usage = None + if self.extra_headers: + kwargs["extra_headers"] = self.extra_headers + if stream: with client.messages.stream(**kwargs) as stream: for text in stream.text_stream: