From 29d0239784ebc1fd1122f98c6f28f2d5a04a5028 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= Date: Tue, 14 May 2024 12:35:54 +0200 Subject: [PATCH 1/5] add anthropic llm client --- setup.cfg | 2 + src/dbally/llm_client/anthropic.py | 107 ++++++++++++++++++ src/dbally/nl_responder/nl_responder.py | 12 +- src/dbally/nl_responder/token_counters.py | 29 ++++- src/dbally/prompts/common_validation_utils.py | 20 ++++ 5 files changed, 167 insertions(+), 3 deletions(-) create mode 100644 src/dbally/llm_client/anthropic.py diff --git a/setup.cfg b/setup.cfg index c52d5583..8fe32f4c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -44,6 +44,8 @@ install_requires = openai = openai>=1.10.0 tiktoken>=0.6.0 +anthropic = + anthropic>=0.25.8 faiss = faiss-cpu>=1.8.0 examples = diff --git a/src/dbally/llm_client/anthropic.py b/src/dbally/llm_client/anthropic.py new file mode 100644 index 00000000..f3dea548 --- /dev/null +++ b/src/dbally/llm_client/anthropic.py @@ -0,0 +1,107 @@ +from dataclasses import dataclass +from typing import Any, ClassVar, Dict, List, Optional, Union + +from anthropic import NOT_GIVEN as ANTHROPIC_NOT_GIVEN +from anthropic import NotGiven as AnthropicNotGiven + +from dbally.data_models.audit import LLMEvent +from dbally.llm_client.base import LLMClient, LLMOptions +from dbally.prompts.common_validation_utils import extract_system_prompt +from dbally.prompts.prompt_builder import ChatFormat + +from .._types import NOT_GIVEN, NotGiven + + +@dataclass +class AnthropicOptions(LLMOptions): + """ + Dataclass that represents all available LLM call options for the Anthropic API. Each of them is + described in the [Anthropic API documentation](https://docs.anthropic.com/en/api/messages) + """ + + _not_given: ClassVar[Optional[AnthropicNotGiven]] = ANTHROPIC_NOT_GIVEN + + max_tokens: Union[int, NotGiven] = NOT_GIVEN + stop_sequences: Union[Optional[List[str]], NotGiven] = NOT_GIVEN + temperature: Union[Optional[float], NotGiven] = NOT_GIVEN + top_k: Union[Optional[int], NotGiven] = NOT_GIVEN + top_p: Union[Optional[float], NotGiven] = NOT_GIVEN + + def dict(self) -> Dict[str, Any]: + """ + Returns a dictionary representation of the LLMOptions instance. + If a value is None, it will be replaced with the _not_given value. + + Returns: + A dictionary representation of the LLMOptions instance. + """ + options = super().dict() + + # Anthropic API requires max_tokens to be set + if isinstance(options["max_tokens"], AnthropicNotGiven) or options["max_tokens"] is None: + options["max_tokens"] = 256 + + return options + + +class AnthropicClient(LLMClient[AnthropicOptions]): + """ + `AnthropicClient` is a class designed to interact with Anthropic's language model (LLM) endpoints, + particularly for the Claude models. + + Args: + model_name: Name of the [Anthropic's model](https://docs.anthropic.com/claude/docs/models-overview) to be used, + default is "claude-3-opus-20240229". + api_key: Anthropic's API key. If None ANTHROPIC_API_KEY environment variable will be used + default_options: Default options to be used in the LLM calls. 
+ """ + + _options_cls = AnthropicOptions + + def __init__( + self, + model_name: str = "claude-3-opus-20240229", + api_key: Optional[str] = None, + default_options: Optional[AnthropicOptions] = None, + ) -> None: + try: + from anthropic import AsyncAnthropic # pylint: disable=import-outside-toplevel + except ImportError as exc: + raise ImportError("You need to install anthropic package to use Claude models") from exc + + super().__init__(model_name=model_name, default_options=default_options) + self._client = AsyncAnthropic(api_key=api_key) + + async def call( + self, + prompt: Union[str, ChatFormat], + response_format: Optional[Dict[str, str]], + options: AnthropicOptions, + event: LLMEvent, + ) -> str: + """ + Calls the Anthropic API endpoint. + + Args: + prompt: Prompt as an Anthropic client style list. + response_format: Optional argument used in the OpenAI API - used to force the json output + options: Additional settings used by the LLM. + event: container with the prompt, LLM response and call metrics. + + Returns: + Response string from LLM. + """ + prompt, system = extract_system_prompt(prompt) + + response = await self._client.messages.create( + messages=prompt, + model=self.model_name, + system=system, + **options.dict(), # type: ignore + ) + + event.completion_tokens = response.usage.output_tokens + event.prompt_tokens = response.usage.input_tokens + event.total_tokens = response.usage.output_tokens + response.usage.input_tokens + + return response.content[0].text # type: ignore diff --git a/src/dbally/nl_responder/nl_responder.py b/src/dbally/nl_responder/nl_responder.py index a0d093b4..490e98d6 100644 --- a/src/dbally/nl_responder/nl_responder.py +++ b/src/dbally/nl_responder/nl_responder.py @@ -11,7 +11,11 @@ QueryExplainerPromptTemplate, default_query_explainer_template, ) -from dbally.nl_responder.token_counters import count_tokens_for_huggingface, count_tokens_for_openai +from dbally.nl_responder.token_counters import ( + count_tokens_for_anthropic, + count_tokens_for_huggingface, + count_tokens_for_openai, +) class NLResponder: @@ -74,7 +78,11 @@ async def generate_response( fmt={"rows": rows, "question": question}, model=self._llm_client.model_name, ) - + elif "claude" in self._llm_client.model_name: + tokens_count = count_tokens_for_anthropic( + messages=self._nl_responder_prompt_template.chat, + fmt={"rows": rows, "question": question}, + ) else: tokens_count = count_tokens_for_huggingface( messages=self._nl_responder_prompt_template.chat, diff --git a/src/dbally/nl_responder/token_counters.py b/src/dbally/nl_responder/token_counters.py index 2b242eba..6656ad73 100644 --- a/src/dbally/nl_responder/token_counters.py +++ b/src/dbally/nl_responder/token_counters.py @@ -5,7 +5,7 @@ def count_tokens_for_openai(messages: ChatFormat, fmt: Dict[str, str], model: str) -> int: """ - Counts the number of tokens in the messages for OpenAIs' models. + Counts the number of tokens in the messages for OpenAI's models. Args: messages: Messages to count tokens for. @@ -34,6 +34,33 @@ def count_tokens_for_openai(messages: ChatFormat, fmt: Dict[str, str], model: st return num_tokens +def count_tokens_for_anthropic(messages: ChatFormat, fmt: Dict[str, str]) -> int: + """ + Counts the number of tokens in the messages for Anthropic's models. + + Args: + messages: Messages to count tokens for. + fmt: Arguments to be used with prompt. + + Returns: + Number of tokens in the messages. + + Raises: + ImportError: If anthropic package is not installed. 
+ """ + + try: + from anthropic._tokenizers import sync_get_tokenizer # pylint: disable=import-outside-toplevel + except ImportError as exc: + raise ImportError("You need to install anthropic package to use Claude models") from exc + + tokenizer = sync_get_tokenizer() + num_tokens = 0 + for message in messages: + num_tokens += len(tokenizer.encode(message["content"].format(**fmt))) + return num_tokens + + def count_tokens_for_huggingface(messages: ChatFormat, fmt: Dict[str, str], model: str) -> int: """ Counts the number of tokens in the messages for models available on HuggingFace. diff --git a/src/dbally/prompts/common_validation_utils.py b/src/dbally/prompts/common_validation_utils.py index f62d72b1..8912ec6c 100644 --- a/src/dbally/prompts/common_validation_utils.py +++ b/src/dbally/prompts/common_validation_utils.py @@ -47,3 +47,23 @@ def check_prompt_variables(chat: ChatFormat, variables_to_check: Set[str]) -> Ch "You need to format the following variables: {variables_to_check}" ) return chat + + +def extract_system_prompt(chat: ChatFormat) -> Tuple[ChatFormat, str]: + """ + Extracts the system prompt from the chat. + + Args: + chat: chat to extract the system prompt from + + Returns: + chat without the system prompt and system prompt + """ + system_prompt = "" + chat_without_system = [] + for message in chat: + if message["role"] == "system": + system_prompt = message["content"] + else: + chat_without_system.append(message) + return tuple(chat_without_system), system_prompt From 483490aff57979b618ab40ec0b0d9ea54328e100 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= Date: Wed, 15 May 2024 10:19:57 +0200 Subject: [PATCH 2/5] add exceptions handling --- src/dbally/__init__.py | 6 ++++ src/dbally/_exceptions.py | 42 ++++++++++++++++++++++++++ src/dbally/llm_client/anthropic.py | 41 +++++++++++++++++-------- src/dbally/llm_client/openai_client.py | 37 +++++++++++++++-------- 4 files changed, 101 insertions(+), 25 deletions(-) create mode 100644 src/dbally/_exceptions.py diff --git a/src/dbally/__init__.py b/src/dbally/__init__.py index bb3b6486..56951849 100644 --- a/src/dbally/__init__.py +++ b/src/dbally/__init__.py @@ -8,6 +8,7 @@ from dbally.views.structured import BaseStructuredView from .__version__ import __version__ +from ._exceptions import DballyError, LLMConnectionError, LLMError, LLMResponseError, LLMStatusError from ._main import create_collection from ._types import NOT_GIVEN, NotGiven from .collection import Collection @@ -22,6 +23,11 @@ "BaseStructuredView", "DataFrameBaseView", "ExecutionResult", + "DballyError", + "LLMError", + "LLMConnectionError", + "LLMResponseError", + "LLMStatusError", "NotGiven", "NOT_GIVEN", ] diff --git a/src/dbally/_exceptions.py b/src/dbally/_exceptions.py new file mode 100644 index 00000000..ea175648 --- /dev/null +++ b/src/dbally/_exceptions.py @@ -0,0 +1,42 @@ +class DballyError(Exception): + """ + Base class for all exceptions raised by dbally. + """ + + +class LLMError(DballyError): + """ + Base class for all exceptions raised by the LLMClient. + """ + + def __init__(self, message: str) -> None: + super().__init__(message) + self.message = message + + +class LLMConnectionError(LLMError): + """ + Raised when there is an error connecting to the LLM API. + """ + + def __init__(self, message: str = "Connection error.") -> None: + super().__init__(message) + + +class LLMStatusError(LLMError): + """ + Raised when an API response has a status code of 4xx or 5xx. 
+ """ + + def __init__(self, message: str, status_code: int) -> None: + super().__init__(message) + self.status_code = status_code + + +class LLMResponseError(LLMError): + """ + Raised when an API response has an invalid schema. + """ + + def __init__(self, message: str = "Data returned by API invalid for expected schema.") -> None: + super().__init__(message) diff --git a/src/dbally/llm_client/anthropic.py b/src/dbally/llm_client/anthropic.py index f3dea548..933717ca 100644 --- a/src/dbally/llm_client/anthropic.py +++ b/src/dbally/llm_client/anthropic.py @@ -2,6 +2,7 @@ from typing import Any, ClassVar, Dict, List, Optional, Union from anthropic import NOT_GIVEN as ANTHROPIC_NOT_GIVEN +from anthropic import APIConnectionError, APIResponseValidationError, APIStatusError from anthropic import NotGiven as AnthropicNotGiven from dbally.data_models.audit import LLMEvent @@ -9,6 +10,7 @@ from dbally.prompts.common_validation_utils import extract_system_prompt from dbally.prompts.prompt_builder import ChatFormat +from .._exceptions import LLMConnectionError, LLMResponseError, LLMStatusError from .._types import NOT_GIVEN, NotGiven @@ -16,7 +18,7 @@ class AnthropicOptions(LLMOptions): """ Dataclass that represents all available LLM call options for the Anthropic API. Each of them is - described in the [Anthropic API documentation](https://docs.anthropic.com/en/api/messages) + described in the [Anthropic API documentation](https://docs.anthropic.com/en/api/messages). """ _not_given: ClassVar[Optional[AnthropicNotGiven]] = ANTHROPIC_NOT_GIVEN @@ -30,7 +32,8 @@ class AnthropicOptions(LLMOptions): def dict(self) -> Dict[str, Any]: """ Returns a dictionary representation of the LLMOptions instance. - If a value is None, it will be replaced with the _not_given value. + If a value is None, it will be replaced with a provider-specific not-given sentinel, + except for max_tokens, which is set to 256 if not provided. Returns: A dictionary representation of the LLMOptions instance. @@ -39,7 +42,7 @@ def dict(self) -> Dict[str, Any]: # Anthropic API requires max_tokens to be set if isinstance(options["max_tokens"], AnthropicNotGiven) or options["max_tokens"] is None: - options["max_tokens"] = 256 + options["max_tokens"] = 4096 return options @@ -50,9 +53,9 @@ class AnthropicClient(LLMClient[AnthropicOptions]): particularly for the Claude models. Args: - model_name: Name of the [Anthropic's model](https://docs.anthropic.com/claude/docs/models-overview) to be used, + model_name: Name of the [Anthropic's model](https://docs.anthropic.com/claude/docs/models-overview) to be used,\ default is "claude-3-opus-20240229". - api_key: Anthropic's API key. If None ANTHROPIC_API_KEY environment variable will be used + api_key: Anthropic's API key. If None ANTHROPIC_API_KEY environment variable will be used. default_options: Default options to be used in the LLM calls. """ @@ -84,21 +87,33 @@ async def call( Args: prompt: Prompt as an Anthropic client style list. - response_format: Optional argument used in the OpenAI API - used to force the json output + response_format: Optional argument used in the OpenAI API - used to force the json output. options: Additional settings used by the LLM. - event: container with the prompt, LLM response and call metrics. + event: Container with the prompt, LLM response and call metrics. Returns: Response string from LLM. + + Raises: + LLMConnectionError: If there was an error connecting to the LLM API. + LLMStatusError: If the LLM API returned an error status. 
+ LLMResponseError: If the LLM API returned an invalid response. """ prompt, system = extract_system_prompt(prompt) - response = await self._client.messages.create( - messages=prompt, - model=self.model_name, - system=system, - **options.dict(), # type: ignore - ) + try: + response = await self._client.messages.create( + messages=prompt, + model=self.model_name, + system=system, + **options.dict(), # type: ignore + ) + except APIConnectionError as exc: + raise LLMConnectionError() from exc + except APIStatusError as exc: + raise LLMStatusError(exc.message, exc.status_code) from exc + except APIResponseValidationError as exc: + raise LLMResponseError() from exc event.completion_tokens = response.usage.output_tokens event.prompt_tokens = response.usage.input_tokens diff --git a/src/dbally/llm_client/openai_client.py b/src/dbally/llm_client/openai_client.py index 4458698b..48fb89eb 100644 --- a/src/dbally/llm_client/openai_client.py +++ b/src/dbally/llm_client/openai_client.py @@ -2,21 +2,22 @@ from typing import ClassVar, Dict, List, Optional, Union from openai import NOT_GIVEN as OPENAI_NOT_GIVEN +from openai import APIConnectionError, APIResponseValidationError, APIStatusError from openai import NotGiven as OpenAINotGiven from dbally.data_models.audit import LLMEvent -from dbally.llm_client.base import LLMClient +from dbally.llm_client.base import LLMClient, LLMOptions from dbally.prompts import ChatFormat +from .._exceptions import LLMConnectionError, LLMResponseError, LLMStatusError from .._types import NOT_GIVEN, NotGiven -from .base import LLMOptions @dataclass class OpenAIOptions(LLMOptions): """ Dataclass that represents all available LLM call options for the OpenAI API. Each of them is - described in the [OpenAI API documentation](https://platform.openai.com/docs/api-reference/chat/create.) + described in the [OpenAI API documentation](https://platform.openai.com/docs/api-reference/chat/create). """ _not_given: ClassVar[Optional[OpenAINotGiven]] = OPENAI_NOT_GIVEN @@ -39,7 +40,7 @@ class OpenAIClient(LLMClient[OpenAIOptions]): Args: model_name: Name of the [OpenAI's model](https://platform.openai.com/docs/models) to be used,\ default is "gpt-3.5-turbo". - api_key: OpenAI's API key. If None OPENAI_API_KEY environment variable will be used + api_key: OpenAI's API key. If None OPENAI_API_KEY environment variable will be used. default_options: Default options to be used in the LLM calls. """ @@ -71,12 +72,17 @@ async def call( Args: prompt: Prompt as an OpenAI client style list. - response_format: Optional argument used in the OpenAI API - used to force the json output + response_format: Optional argument used in the OpenAI API - used to force the json output. options: Additional settings used by the LLM. - event: container with the prompt, LLM response and call metrics. + event: Container with the prompt, LLM response and call metrics. Returns: Response string from LLM. + + Raises: + LLMConnectionError: If there is an issue with the connection to the LLM API. + LLMStatusError: If the LLM API returns an error status. + LLMResponseError: If there is an issue with the response from the LLM API. 
""" # only "turbo" models support response_format argument @@ -84,12 +90,19 @@ async def call( if "turbo" not in self.model_name: response_format = None - response = await self._client.chat.completions.create( - messages=prompt, - model=self.model_name, - response_format=response_format, - **options.dict(), # type: ignore - ) + try: + response = await self._client.chat.completions.create( + messages=prompt, + model=self.model_name, + response_format=response_format, + **options.dict(), # type: ignore + ) + except APIConnectionError as exc: + raise LLMConnectionError() from exc + except APIStatusError as exc: + raise LLMStatusError(exc.message, exc.status_code) from exc + except APIResponseValidationError as exc: + raise LLMResponseError() from exc event.completion_tokens = response.usage.completion_tokens event.prompt_tokens = response.usage.prompt_tokens From 15c72031be0a954d5874b9c4e106e34f05e9da43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= Date: Wed, 15 May 2024 10:20:31 +0200 Subject: [PATCH 3/5] fix docs --- src/dbally/llm_client/anthropic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dbally/llm_client/anthropic.py b/src/dbally/llm_client/anthropic.py index 933717ca..2574e84a 100644 --- a/src/dbally/llm_client/anthropic.py +++ b/src/dbally/llm_client/anthropic.py @@ -33,7 +33,7 @@ def dict(self) -> Dict[str, Any]: """ Returns a dictionary representation of the LLMOptions instance. If a value is None, it will be replaced with a provider-specific not-given sentinel, - except for max_tokens, which is set to 256 if not provided. + except for max_tokens, which is set to 4096 if not provided. Returns: A dictionary representation of the LLMOptions instance. From 382411dffdbbf08156bc81f19689cf4ab0f86a40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= Date: Wed, 15 May 2024 10:23:44 +0200 Subject: [PATCH 4/5] fix docs --- src/dbally/prompts/common_validation_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/dbally/prompts/common_validation_utils.py b/src/dbally/prompts/common_validation_utils.py index 8912ec6c..ef0cabc9 100644 --- a/src/dbally/prompts/common_validation_utils.py +++ b/src/dbally/prompts/common_validation_utils.py @@ -54,10 +54,10 @@ def extract_system_prompt(chat: ChatFormat) -> Tuple[ChatFormat, str]: Extracts the system prompt from the chat. 
Args: - chat: chat to extract the system prompt from + chat: Chat to extract the system prompt from Returns: - chat without the system prompt and system prompt + Chat without the system prompt and system prompt """ system_prompt = "" chat_without_system = [] From 030c8f51f8419b621029e6415ab9eda209603107 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= Date: Wed, 15 May 2024 12:36:19 +0200 Subject: [PATCH 5/5] add async tokenizer + rename Exceptions --- src/dbally/__init__.py | 4 ++-- src/dbally/_exceptions.py | 6 +++--- src/dbally/nl_responder/nl_responder.py | 2 +- src/dbally/nl_responder/token_counters.py | 6 +++--- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/dbally/__init__.py b/src/dbally/__init__.py index 56951849..b1a0e28d 100644 --- a/src/dbally/__init__.py +++ b/src/dbally/__init__.py @@ -8,7 +8,7 @@ from dbally.views.structured import BaseStructuredView from .__version__ import __version__ -from ._exceptions import DballyError, LLMConnectionError, LLMError, LLMResponseError, LLMStatusError +from ._exceptions import DbAllyError, LLMConnectionError, LLMError, LLMResponseError, LLMStatusError from ._main import create_collection from ._types import NOT_GIVEN, NotGiven from .collection import Collection @@ -23,7 +23,7 @@ "BaseStructuredView", "DataFrameBaseView", "ExecutionResult", - "DballyError", + "DbAllyError", "LLMError", "LLMConnectionError", "LLMResponseError", diff --git a/src/dbally/_exceptions.py b/src/dbally/_exceptions.py index ea175648..adeaeb51 100644 --- a/src/dbally/_exceptions.py +++ b/src/dbally/_exceptions.py @@ -1,10 +1,10 @@ -class DballyError(Exception): +class DbAllyError(Exception): """ - Base class for all exceptions raised by dbally. + Base class for all exceptions raised by db-ally. """ -class LLMError(DballyError): +class LLMError(DbAllyError): """ Base class for all exceptions raised by the LLMClient. """ diff --git a/src/dbally/nl_responder/nl_responder.py b/src/dbally/nl_responder/nl_responder.py index 490e98d6..a925b28a 100644 --- a/src/dbally/nl_responder/nl_responder.py +++ b/src/dbally/nl_responder/nl_responder.py @@ -79,7 +79,7 @@ async def generate_response( model=self._llm_client.model_name, ) elif "claude" in self._llm_client.model_name: - tokens_count = count_tokens_for_anthropic( + tokens_count = await count_tokens_for_anthropic( messages=self._nl_responder_prompt_template.chat, fmt={"rows": rows, "question": question}, ) diff --git a/src/dbally/nl_responder/token_counters.py b/src/dbally/nl_responder/token_counters.py index 6656ad73..041a11e0 100644 --- a/src/dbally/nl_responder/token_counters.py +++ b/src/dbally/nl_responder/token_counters.py @@ -34,7 +34,7 @@ def count_tokens_for_openai(messages: ChatFormat, fmt: Dict[str, str], model: st return num_tokens -def count_tokens_for_anthropic(messages: ChatFormat, fmt: Dict[str, str]) -> int: +async def count_tokens_for_anthropic(messages: ChatFormat, fmt: Dict[str, str]) -> int: """ Counts the number of tokens in the messages for Anthropic's models. 
@@ -50,11 +50,11 @@ def count_tokens_for_anthropic(messages: ChatFormat, fmt: Dict[str, str]) -> int """ try: - from anthropic._tokenizers import sync_get_tokenizer # pylint: disable=import-outside-toplevel + from anthropic._tokenizers import async_get_tokenizer # pylint: disable=import-outside-toplevel except ImportError as exc: raise ImportError("You need to install anthropic package to use Claude models") from exc - tokenizer = sync_get_tokenizer() + tokenizer = await async_get_tokenizer() num_tokens = 0 for message in messages: num_tokens += len(tokenizer.encode(message["content"].format(**fmt)))