feat(llm): add Anthropic LLM client #31
Closed
New file (42 lines): the `_exceptions` module defining the db-ally exception hierarchy.

```python
class DbAllyError(Exception):
    """
    Base class for all exceptions raised by db-ally.
    """


class LLMError(DbAllyError):
    """
    Base class for all exceptions raised by the LLMClient.
    """

    def __init__(self, message: str) -> None:
        super().__init__(message)
        self.message = message


class LLMConnectionError(LLMError):
    """
    Raised when there is an error connecting to the LLM API.
    """

    def __init__(self, message: str = "Connection error.") -> None:
        super().__init__(message)


class LLMStatusError(LLMError):
    """
    Raised when an API response has a status code of 4xx or 5xx.
    """

    def __init__(self, message: str, status_code: int) -> None:
        super().__init__(message)
        self.status_code = status_code


class LLMResponseError(LLMError):
    """
    Raised when an API response has an invalid schema.
    """

    def __init__(self, message: str = "Data returned by API invalid for expected schema.") -> None:
        super().__init__(message)
```
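For context, a minimal sketch of how a caller might handle this hierarchy. The import path is a guess based on the client's relative import (`from .._exceptions import ...`), and the `generate_sql` helper is illustrative only; neither is part of this PR.

```python
# Hypothetical usage sketch; import path and helper name are assumptions.
from dbally._exceptions import LLMConnectionError, LLMError, LLMStatusError


async def generate_sql(client, prompt, options, event) -> str:
    try:
        return await client.call(prompt, response_format=None, options=options, event=event)
    except LLMConnectionError:
        # Transient network failure - a caller could retry here.
        raise
    except LLMStatusError as exc:
        # 4xx/5xx from the provider; the status code is preserved on the exception.
        print(f"LLM API returned status {exc.status_code}: {exc.message}")
        raise
    except LLMError as exc:
        # Any other LLM-related failure, e.g. a response with an invalid schema.
        print(f"LLM call failed: {exc.message}")
        raise
```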
New file (122 lines): the Anthropic LLM client and its options dataclass.

```python
from dataclasses import dataclass
from typing import Any, ClassVar, Dict, List, Optional, Union

from anthropic import NOT_GIVEN as ANTHROPIC_NOT_GIVEN
from anthropic import APIConnectionError, APIResponseValidationError, APIStatusError
from anthropic import NotGiven as AnthropicNotGiven

from dbally.data_models.audit import LLMEvent
from dbally.llm_client.base import LLMClient, LLMOptions
from dbally.prompts.common_validation_utils import extract_system_prompt
from dbally.prompts.prompt_builder import ChatFormat

from .._exceptions import LLMConnectionError, LLMResponseError, LLMStatusError
from .._types import NOT_GIVEN, NotGiven


@dataclass
class AnthropicOptions(LLMOptions):
    """
    Dataclass that represents all available LLM call options for the Anthropic API. Each of them is
    described in the [Anthropic API documentation](https://docs.anthropic.com/en/api/messages).
    """

    _not_given: ClassVar[Optional[AnthropicNotGiven]] = ANTHROPIC_NOT_GIVEN

    max_tokens: Union[int, NotGiven] = NOT_GIVEN
    stop_sequences: Union[Optional[List[str]], NotGiven] = NOT_GIVEN
    temperature: Union[Optional[float], NotGiven] = NOT_GIVEN
    top_k: Union[Optional[int], NotGiven] = NOT_GIVEN
    top_p: Union[Optional[float], NotGiven] = NOT_GIVEN

    def dict(self) -> Dict[str, Any]:
        """
        Returns a dictionary representation of the LLMOptions instance.
        If a value is None, it will be replaced with a provider-specific not-given sentinel,
        except for max_tokens, which is set to 4096 if not provided.

        Returns:
            A dictionary representation of the LLMOptions instance.
        """
        options = super().dict()

        # Anthropic API requires max_tokens to be set
        if isinstance(options["max_tokens"], AnthropicNotGiven) or options["max_tokens"] is None:
            options["max_tokens"] = 4096

        return options


class AnthropicClient(LLMClient[AnthropicOptions]):
    """
    `AnthropicClient` is a class designed to interact with Anthropic's language model (LLM) endpoints,
    particularly for the Claude models.

    Args:
        model_name: Name of the [Anthropic's model](https://docs.anthropic.com/claude/docs/models-overview) to be used,\
            default is "claude-3-opus-20240229".
        api_key: Anthropic's API key. If None ANTHROPIC_API_KEY environment variable will be used.
        default_options: Default options to be used in the LLM calls.
    """

    _options_cls = AnthropicOptions

    def __init__(
        self,
        model_name: str = "claude-3-opus-20240229",
        api_key: Optional[str] = None,
        default_options: Optional[AnthropicOptions] = None,
    ) -> None:
        try:
            from anthropic import AsyncAnthropic  # pylint: disable=import-outside-toplevel
        except ImportError as exc:
            raise ImportError("You need to install anthropic package to use Claude models") from exc

        super().__init__(model_name=model_name, default_options=default_options)
        self._client = AsyncAnthropic(api_key=api_key)

    async def call(
        self,
        prompt: Union[str, ChatFormat],
        response_format: Optional[Dict[str, str]],
        options: AnthropicOptions,
        event: LLMEvent,
    ) -> str:
        """
        Calls the Anthropic API endpoint.

        Args:
            prompt: Prompt as an Anthropic client style list.
            response_format: Optional argument used in the OpenAI API - used to force the json output.
            options: Additional settings used by the LLM.
            event: Container with the prompt, LLM response and call metrics.

        Returns:
            Response string from LLM.

        Raises:
            LLMConnectionError: If there was an error connecting to the LLM API.
            LLMStatusError: If the LLM API returned an error status.
            LLMResponseError: If the LLM API returned an invalid response.
        """
        prompt, system = extract_system_prompt(prompt)

        try:
            response = await self._client.messages.create(
                messages=prompt,
                model=self.model_name,
                system=system,
                **options.dict(),  # type: ignore
            )
        except APIConnectionError as exc:
            raise LLMConnectionError() from exc
        except APIStatusError as exc:
            raise LLMStatusError(exc.message, exc.status_code) from exc
        except APIResponseValidationError as exc:
            raise LLMResponseError() from exc

        event.completion_tokens = response.usage.output_tokens
        event.prompt_tokens = response.usage.input_tokens
        event.total_tokens = response.usage.output_tokens + response.usage.input_tokens

        return response.content[0].text  # type: ignore
```
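For orientation, a minimal construction sketch. The import path below is an assumption (the diff does not show where the file lives); everything else follows the constructor signature and `AnthropicOptions` docstring above.

```python
# Illustrative only; the module path is assumed, not shown in this PR.
from dbally.llm_client.anthropic import AnthropicClient, AnthropicOptions

# api_key defaults to None, in which case the ANTHROPIC_API_KEY environment
# variable is used by the underlying AsyncAnthropic client.
llm = AnthropicClient(
    model_name="claude-3-opus-20240229",
    default_options=AnthropicOptions(
        temperature=0.0,
        # max_tokens is left NOT_GIVEN here; AnthropicOptions.dict() falls back
        # to 4096 because the Anthropic Messages API requires max_tokens.
    ),
)

# dbally invokes llm.call(...) internally, passing the prompt, the resolved
# options and an LLMEvent that records prompt/completion token counts.
```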
Review discussion:

Not really a comment on this PR, but having an argument in `AnthropicClient` that is specific to OpenAI clearly illustrates that we should refactor this option away (both here and in the prompt object itself), in a separate PR/ticket of course. Maybe, for example, we could have `expect_json` as a general boolean flag that each LLM client could implement in a way that makes sense for the particular model/API (by passing the appropriate option for the particular API, or even by using post-processing).
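A rough sketch of what that idea could look like, purely illustrative: the `_json_kwargs` method name and the exact OpenAI mapping are assumptions, not part of this PR.

```python
from typing import Any, Dict


class OpenAIJsonSupport:
    """Hypothetical: OpenAI can enforce JSON natively via response_format."""

    def _json_kwargs(self, expect_json: bool) -> Dict[str, Any]:
        # Passed through to chat.completions.create(...) when JSON output is expected.
        return {"response_format": {"type": "json_object"}} if expect_json else {}


class AnthropicJsonSupport:
    """Hypothetical: Anthropic's Messages API has no response_format parameter,
    so the same flag would be honoured by prompting and/or post-processing."""

    def _json_kwargs(self, expect_json: bool) -> Dict[str, Any]:
        return {}  # nothing extra to pass; validate or repair the JSON afterwards
```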
You're right. I was also thinking of moving the `response_format` argument to `LLMOptions`, which would make the interface cleaner. What do you think? cc @mhordynski
I think `response_format` (or its non-OpenAI-specific equivalent) should remain coupled with the prompt (i.e., be part of the `PromptTemplate` class), because whether we expect JSON fully depends on the prompt itself.
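For illustration, a minimal sketch of that coupling; the class shape and field name are assumptions and do not reflect dbally's actual `PromptTemplate`.

```python
from dataclasses import dataclass
from typing import Dict, Tuple


@dataclass(frozen=True)
class PromptTemplateSketch:
    """Hypothetical: the prompt itself declares whether JSON output is expected."""

    chat: Tuple[Dict[str, str], ...]
    expect_json: bool = False


sql_prompt = PromptTemplateSketch(
    chat=(
        {"role": "system", "content": "Generate a SQL query and return it as JSON."},
        {"role": "user", "content": "{question}"},
    ),
    expect_json=True,  # travels with the prompt rather than with the client or call options
)
```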