diff --git a/setup.cfg b/setup.cfg
index c52d5583..8fe32f4c 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -44,6 +44,8 @@ install_requires =
 openai =
     openai>=1.10.0
     tiktoken>=0.6.0
+anthropic =
+    anthropic>=0.25.8
 faiss =
     faiss-cpu>=1.8.0
 examples =
diff --git a/src/dbally/__init__.py b/src/dbally/__init__.py
index bb3b6486..b1a0e28d 100644
--- a/src/dbally/__init__.py
+++ b/src/dbally/__init__.py
@@ -8,6 +8,7 @@
 from dbally.views.structured import BaseStructuredView
 
 from .__version__ import __version__
+from ._exceptions import DbAllyError, LLMConnectionError, LLMError, LLMResponseError, LLMStatusError
 from ._main import create_collection
 from ._types import NOT_GIVEN, NotGiven
 from .collection import Collection
@@ -22,6 +23,11 @@
     "BaseStructuredView",
     "DataFrameBaseView",
     "ExecutionResult",
+    "DbAllyError",
+    "LLMError",
+    "LLMConnectionError",
+    "LLMResponseError",
+    "LLMStatusError",
     "NotGiven",
     "NOT_GIVEN",
 ]
diff --git a/src/dbally/_exceptions.py b/src/dbally/_exceptions.py
new file mode 100644
index 00000000..adeaeb51
--- /dev/null
+++ b/src/dbally/_exceptions.py
@@ -0,0 +1,42 @@
+class DbAllyError(Exception):
+    """
+    Base class for all exceptions raised by db-ally.
+    """
+
+
+class LLMError(DbAllyError):
+    """
+    Base class for all exceptions raised by the LLMClient.
+    """
+
+    def __init__(self, message: str) -> None:
+        super().__init__(message)
+        self.message = message
+
+
+class LLMConnectionError(LLMError):
+    """
+    Raised when there is an error connecting to the LLM API.
+    """
+
+    def __init__(self, message: str = "Connection error.") -> None:
+        super().__init__(message)
+
+
+class LLMStatusError(LLMError):
+    """
+    Raised when an API response has a status code of 4xx or 5xx.
+    """
+
+    def __init__(self, message: str, status_code: int) -> None:
+        super().__init__(message)
+        self.status_code = status_code
+
+
+class LLMResponseError(LLMError):
+    """
+    Raised when an API response has an invalid schema.
+    """
+
+    def __init__(self, message: str = "Data returned by API invalid for expected schema.") -> None:
+        super().__init__(message)
diff --git a/src/dbally/llm_client/anthropic.py b/src/dbally/llm_client/anthropic.py
new file mode 100644
index 00000000..2574e84a
--- /dev/null
+++ b/src/dbally/llm_client/anthropic.py
@@ -0,0 +1,122 @@
+from dataclasses import dataclass
+from typing import Any, ClassVar, Dict, List, Optional, Union
+
+from anthropic import NOT_GIVEN as ANTHROPIC_NOT_GIVEN
+from anthropic import APIConnectionError, APIResponseValidationError, APIStatusError
+from anthropic import NotGiven as AnthropicNotGiven
+
+from dbally.data_models.audit import LLMEvent
+from dbally.llm_client.base import LLMClient, LLMOptions
+from dbally.prompts.common_validation_utils import extract_system_prompt
+from dbally.prompts.prompt_builder import ChatFormat
+
+from .._exceptions import LLMConnectionError, LLMResponseError, LLMStatusError
+from .._types import NOT_GIVEN, NotGiven
+
+
+@dataclass
+class AnthropicOptions(LLMOptions):
+    """
+    Dataclass that represents all available LLM call options for the Anthropic API. Each of them is
+    described in the [Anthropic API documentation](https://docs.anthropic.com/en/api/messages).
+    """
+
+    _not_given: ClassVar[Optional[AnthropicNotGiven]] = ANTHROPIC_NOT_GIVEN
+
+    max_tokens: Union[int, NotGiven] = NOT_GIVEN
+    stop_sequences: Union[Optional[List[str]], NotGiven] = NOT_GIVEN
+    temperature: Union[Optional[float], NotGiven] = NOT_GIVEN
+    top_k: Union[Optional[int], NotGiven] = NOT_GIVEN
+    top_p: Union[Optional[float], NotGiven] = NOT_GIVEN
+
+    def dict(self) -> Dict[str, Any]:
+        """
+        Returns a dictionary representation of the LLMOptions instance.
+        If a value is None, it will be replaced with a provider-specific not-given sentinel,
+        except for max_tokens, which is set to 4096 if not provided.
+
+        Returns:
+            A dictionary representation of the LLMOptions instance.
+        """
+        options = super().dict()
+
+        # Anthropic API requires max_tokens to be set
+        if isinstance(options["max_tokens"], AnthropicNotGiven) or options["max_tokens"] is None:
+            options["max_tokens"] = 4096
+
+        return options
+
+
+class AnthropicClient(LLMClient[AnthropicOptions]):
+    """
+    `AnthropicClient` is a class designed to interact with Anthropic's language model (LLM) endpoints,
+    particularly for the Claude models.
+
+    Args:
+        model_name: Name of the [Anthropic's model](https://docs.anthropic.com/claude/docs/models-overview) to be used,\
+            default is "claude-3-opus-20240229".
+        api_key: Anthropic's API key. If None, the ANTHROPIC_API_KEY environment variable will be used.
+        default_options: Default options to be used in the LLM calls.
+    """
+
+    _options_cls = AnthropicOptions
+
+    def __init__(
+        self,
+        model_name: str = "claude-3-opus-20240229",
+        api_key: Optional[str] = None,
+        default_options: Optional[AnthropicOptions] = None,
+    ) -> None:
+        try:
+            from anthropic import AsyncAnthropic  # pylint: disable=import-outside-toplevel
+        except ImportError as exc:
+            raise ImportError("You need to install the anthropic package to use Claude models") from exc
+
+        super().__init__(model_name=model_name, default_options=default_options)
+        self._client = AsyncAnthropic(api_key=api_key)
+
+    async def call(
+        self,
+        prompt: Union[str, ChatFormat],
+        response_format: Optional[Dict[str, str]],
+        options: AnthropicOptions,
+        event: LLMEvent,
+    ) -> str:
+        """
+        Calls the Anthropic API endpoint.
+
+        Args:
+            prompt: Prompt as an Anthropic client style list.
+            response_format: Optional argument used in the OpenAI API to force JSON output; ignored by the Anthropic API.
+            options: Additional settings used by the LLM.
+            event: Container with the prompt, LLM response and call metrics.
+
+        Returns:
+            Response string from LLM.
+
+        Raises:
+            LLMConnectionError: If there was an error connecting to the LLM API.
+            LLMStatusError: If the LLM API returned an error status.
+            LLMResponseError: If the LLM API returned an invalid response.
+        """
+        prompt, system = extract_system_prompt(prompt)
+
+        try:
+            response = await self._client.messages.create(
+                messages=prompt,
+                model=self.model_name,
+                system=system,
+                **options.dict(),  # type: ignore
+            )
+        except APIConnectionError as exc:
+            raise LLMConnectionError() from exc
+        except APIStatusError as exc:
+            raise LLMStatusError(exc.message, exc.status_code) from exc
+        except APIResponseValidationError as exc:
+            raise LLMResponseError() from exc
+
+        event.completion_tokens = response.usage.output_tokens
+        event.prompt_tokens = response.usage.input_tokens
+        event.total_tokens = response.usage.output_tokens + response.usage.input_tokens
+
+        return response.content[0].text  # type: ignore
diff --git a/src/dbally/llm_client/openai_client.py b/src/dbally/llm_client/openai_client.py
index 4458698b..48fb89eb 100644
--- a/src/dbally/llm_client/openai_client.py
+++ b/src/dbally/llm_client/openai_client.py
@@ -2,21 +2,22 @@
 from typing import ClassVar, Dict, List, Optional, Union
 
 from openai import NOT_GIVEN as OPENAI_NOT_GIVEN
+from openai import APIConnectionError, APIResponseValidationError, APIStatusError
 from openai import NotGiven as OpenAINotGiven
 
 from dbally.data_models.audit import LLMEvent
-from dbally.llm_client.base import LLMClient
+from dbally.llm_client.base import LLMClient, LLMOptions
 from dbally.prompts import ChatFormat
 
+from .._exceptions import LLMConnectionError, LLMResponseError, LLMStatusError
 from .._types import NOT_GIVEN, NotGiven
-from .base import LLMOptions
 
 
 @dataclass
 class OpenAIOptions(LLMOptions):
     """
     Dataclass that represents all available LLM call options for the OpenAI API. Each of them is
-    described in the [OpenAI API documentation](https://platform.openai.com/docs/api-reference/chat/create.)
+    described in the [OpenAI API documentation](https://platform.openai.com/docs/api-reference/chat/create).
     """
 
     _not_given: ClassVar[Optional[OpenAINotGiven]] = OPENAI_NOT_GIVEN
@@ -39,7 +40,7 @@ class OpenAIClient(LLMClient[OpenAIOptions]):
     Args:
         model_name: Name of the [OpenAI's model](https://platform.openai.com/docs/models) to be used,\
             default is "gpt-3.5-turbo".
-        api_key: OpenAI's API key. If None OPENAI_API_KEY environment variable will be used
+        api_key: OpenAI's API key. If None, the OPENAI_API_KEY environment variable will be used.
         default_options: Default options to be used in the LLM calls.
     """
 
@@ -71,12 +72,17 @@ async def call(
 
         Args:
             prompt: Prompt as an OpenAI client style list.
-            response_format: Optional argument used in the OpenAI API - used to force the json output
+            response_format: Optional argument used in the OpenAI API to force JSON output.
             options: Additional settings used by the LLM.
-            event: container with the prompt, LLM response and call metrics.
+            event: Container with the prompt, LLM response and call metrics.
 
         Returns:
             Response string from LLM.
+
+        Raises:
+            LLMConnectionError: If there is an issue with the connection to the LLM API.
+            LLMStatusError: If the LLM API returns an error status.
+            LLMResponseError: If there is an issue with the response from the LLM API.
         """
 
         # only "turbo" models support response_format argument
@@ -84,12 +90,19 @@
         if "turbo" not in self.model_name:
             response_format = None
 
-        response = await self._client.chat.completions.create(
-            messages=prompt,
-            model=self.model_name,
-            response_format=response_format,
-            **options.dict(),  # type: ignore
-        )
+        try:
+            response = await self._client.chat.completions.create(
+                messages=prompt,
+                model=self.model_name,
+                response_format=response_format,
+                **options.dict(),  # type: ignore
+            )
+        except APIConnectionError as exc:
+            raise LLMConnectionError() from exc
+        except APIStatusError as exc:
+            raise LLMStatusError(exc.message, exc.status_code) from exc
+        except APIResponseValidationError as exc:
+            raise LLMResponseError() from exc
 
         event.completion_tokens = response.usage.completion_tokens
         event.prompt_tokens = response.usage.prompt_tokens
diff --git a/src/dbally/nl_responder/nl_responder.py b/src/dbally/nl_responder/nl_responder.py
index a0d093b4..a925b28a 100644
--- a/src/dbally/nl_responder/nl_responder.py
+++ b/src/dbally/nl_responder/nl_responder.py
@@ -11,7 +11,11 @@
     QueryExplainerPromptTemplate,
     default_query_explainer_template,
 )
-from dbally.nl_responder.token_counters import count_tokens_for_huggingface, count_tokens_for_openai
+from dbally.nl_responder.token_counters import (
+    count_tokens_for_anthropic,
+    count_tokens_for_huggingface,
+    count_tokens_for_openai,
+)
 
 
 class NLResponder:
@@ -74,7 +78,11 @@ async def generate_response(
                 fmt={"rows": rows, "question": question},
                 model=self._llm_client.model_name,
             )
-
+        elif "claude" in self._llm_client.model_name:
+            tokens_count = await count_tokens_for_anthropic(
+                messages=self._nl_responder_prompt_template.chat,
+                fmt={"rows": rows, "question": question},
+            )
         else:
             tokens_count = count_tokens_for_huggingface(
                 messages=self._nl_responder_prompt_template.chat,
diff --git a/src/dbally/nl_responder/token_counters.py b/src/dbally/nl_responder/token_counters.py
index 2b242eba..041a11e0 100644
--- a/src/dbally/nl_responder/token_counters.py
+++ b/src/dbally/nl_responder/token_counters.py
@@ -5,7 +5,7 @@
 
 def count_tokens_for_openai(messages: ChatFormat, fmt: Dict[str, str], model: str) -> int:
     """
-    Counts the number of tokens in the messages for OpenAIs' models.
+    Counts the number of tokens in the messages for OpenAI's models.
 
     Args:
         messages: Messages to count tokens for.
@@ -34,6 +34,33 @@ def count_tokens_for_openai(messages: ChatFormat, fmt: Dict[str, str], model: st
     return num_tokens
 
 
+async def count_tokens_for_anthropic(messages: ChatFormat, fmt: Dict[str, str]) -> int:
+    """
+    Counts the number of tokens in the messages for Anthropic's models.
+
+    Args:
+        messages: Messages to count tokens for.
+        fmt: Arguments to be used with prompt.
+
+    Returns:
+        Number of tokens in the messages.
+
+    Raises:
+        ImportError: If anthropic package is not installed.
+    """
+
+    try:
+        from anthropic._tokenizers import async_get_tokenizer  # pylint: disable=import-outside-toplevel
+    except ImportError as exc:
+        raise ImportError("You need to install the anthropic package to use Claude models") from exc
+
+    tokenizer = await async_get_tokenizer()
+    num_tokens = 0
+    for message in messages:
+        num_tokens += len(tokenizer.encode(message["content"].format(**fmt)))
+    return num_tokens
+
+
 def count_tokens_for_huggingface(messages: ChatFormat, fmt: Dict[str, str], model: str) -> int:
     """
     Counts the number of tokens in the messages for models available on HuggingFace.
diff --git a/src/dbally/prompts/common_validation_utils.py b/src/dbally/prompts/common_validation_utils.py
index f62d72b1..ef0cabc9 100644
--- a/src/dbally/prompts/common_validation_utils.py
+++ b/src/dbally/prompts/common_validation_utils.py
@@ -47,3 +47,23 @@ def check_prompt_variables(chat: ChatFormat, variables_to_check: Set[str]) -> Ch
             "You need to format the following variables: {variables_to_check}"
         )
     return chat
+
+
+def extract_system_prompt(chat: ChatFormat) -> Tuple[ChatFormat, str]:
+    """
+    Extracts the system prompt from the chat.
+
+    Args:
+        chat: Chat to extract the system prompt from.
+
+    Returns:
+        A tuple with the chat without system messages and the extracted system prompt.
+    """
+    system_prompt = ""
+    chat_without_system = []
+    for message in chat:
+        if message["role"] == "system":
+            system_prompt = message["content"]
+        else:
+            chat_without_system.append(message)
+    return tuple(chat_without_system), system_prompt
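As a usage reference, here is a minimal sketch of what this patch enables (the snippet is not part of the diff). It assumes the existing dbally.create_collection(name, llm_client=...) and Collection.ask flow with a view already registered, an ANTHROPIC_API_KEY in the environment, and made-up collection, view, and question names:

import asyncio

import dbally
from dbally import DbAllyError, LLMConnectionError, LLMStatusError
from dbally.llm_client.anthropic import AnthropicClient, AnthropicOptions
from dbally.nl_responder.token_counters import count_tokens_for_anthropic


async def main() -> None:
    # AnthropicOptions mirrors the Messages API parameters; dict() falls back to max_tokens=4096.
    llm_client = AnthropicClient(
        model_name="claude-3-opus-20240229",
        default_options=AnthropicOptions(temperature=0.0),
    )
    collection = dbally.create_collection("candidates", llm_client=llm_client)
    # collection.add(CandidateView)  # register your views as usual (hypothetical view name)

    # The new async token counter can be exercised directly on a ChatFormat-style prompt.
    question = "How many candidates speak French?"
    tokens = await count_tokens_for_anthropic(
        messages=[{"role": "user", "content": "{question}"}],
        fmt={"question": question},
    )
    print(f"Prompt is roughly {tokens} tokens long.")

    try:
        result = await collection.ask(question)
        print(result.results)
    except LLMConnectionError:
        print("Could not reach the Anthropic API.")
    except LLMStatusError as exc:
        print(f"Anthropic API returned {exc.status_code}: {exc.message}")
    except DbAllyError as exc:
        # LLMResponseError and every other db-ally exception share this base class.
        print(f"db-ally error: {exc}")


if __name__ == "__main__":
    asyncio.run(main())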