feat(llm): add Anthropic LLM client #31

Closed · wants to merge 5 commits
2 changes: 2 additions & 0 deletions setup.cfg
@@ -44,6 +44,8 @@ install_requires =
openai =
openai>=1.10.0
tiktoken>=0.6.0
anthropic =
anthropic>=0.25.8
faiss =
faiss-cpu>=1.8.0
examples =
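With this extra in place, the Anthropic SDK can be installed together with db-ally, e.g. pip install "dbally[anthropic]" (assuming the package is published under the name dbally; the extra pins anthropic>=0.25.8 as shown above).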
6 changes: 6 additions & 0 deletions src/dbally/__init__.py
@@ -8,6 +8,7 @@
from dbally.views.structured import BaseStructuredView

from .__version__ import __version__
from ._exceptions import DbAllyError, LLMConnectionError, LLMError, LLMResponseError, LLMStatusError
from ._main import create_collection
from ._types import NOT_GIVEN, NotGiven
from .collection import Collection
@@ -22,6 +23,11 @@
"BaseStructuredView",
"DataFrameBaseView",
"ExecutionResult",
"DbAllyError",
"LLMError",
"LLMConnectionError",
"LLMResponseError",
"LLMStatusError",
"NotGiven",
"NOT_GIVEN",
]
42 changes: 42 additions & 0 deletions src/dbally/_exceptions.py
@@ -0,0 +1,42 @@
class DbAllyError(Exception):
"""
Base class for all exceptions raised by db-ally.
"""


class LLMError(DbAllyError):
"""
Base class for all exceptions raised by the LLMClient.
"""

def __init__(self, message: str) -> None:
super().__init__(message)
self.message = message


class LLMConnectionError(LLMError):
"""
Raised when there is an error connecting to the LLM API.
"""

def __init__(self, message: str = "Connection error.") -> None:
super().__init__(message)


class LLMStatusError(LLMError):
"""
Raised when an API response has a status code of 4xx or 5xx.
"""

def __init__(self, message: str, status_code: int) -> None:
super().__init__(message)
self.status_code = status_code


class LLMResponseError(LLMError):
"""
Raised when an API response has an invalid schema.
"""

def __init__(self, message: str = "Data returned by API invalid for expected schema.") -> None:
super().__init__(message)
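
Because these exceptions are also re-exported from the top-level dbally package (see the __init__.py change above), downstream code can branch on them directly. A minimal sketch, with the try body left as a placeholder for whatever code path ends up calling the LLM client:

import dbally

try:
    ...  # any operation that ends up calling the underlying LLM client
except dbally.LLMConnectionError:
    ...  # could not reach the LLM API; retry or report a connectivity problem
except dbally.LLMStatusError as exc:
    ...  # exc.status_code carries the 4xx/5xx status returned by the API
except dbally.LLMResponseError:
    ...  # the API response did not match the expected schema
except dbally.LLMError:
    ...  # any other LLM-related failure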
122 changes: 122 additions & 0 deletions src/dbally/llm_client/anthropic.py
@@ -0,0 +1,122 @@
from dataclasses import dataclass
from typing import Any, ClassVar, Dict, List, Optional, Union

from anthropic import NOT_GIVEN as ANTHROPIC_NOT_GIVEN
from anthropic import APIConnectionError, APIResponseValidationError, APIStatusError
from anthropic import NotGiven as AnthropicNotGiven

from dbally.data_models.audit import LLMEvent
from dbally.llm_client.base import LLMClient, LLMOptions
from dbally.prompts.common_validation_utils import extract_system_prompt
from dbally.prompts.prompt_builder import ChatFormat

from .._exceptions import LLMConnectionError, LLMResponseError, LLMStatusError
from .._types import NOT_GIVEN, NotGiven


@dataclass
class AnthropicOptions(LLMOptions):
"""
Dataclass that represents all available LLM call options for the Anthropic API. Each of them is
described in the [Anthropic API documentation](https://docs.anthropic.com/en/api/messages).
"""

_not_given: ClassVar[Optional[AnthropicNotGiven]] = ANTHROPIC_NOT_GIVEN

max_tokens: Union[int, NotGiven] = NOT_GIVEN
stop_sequences: Union[Optional[List[str]], NotGiven] = NOT_GIVEN
temperature: Union[Optional[float], NotGiven] = NOT_GIVEN
top_k: Union[Optional[int], NotGiven] = NOT_GIVEN
top_p: Union[Optional[float], NotGiven] = NOT_GIVEN

def dict(self) -> Dict[str, Any]:
"""
Returns a dictionary representation of the LLMOptions instance.
If a value is None, it will be replaced with a provider-specific not-given sentinel,
except for max_tokens, which is set to 4096 if not provided.

Returns:
A dictionary representation of the LLMOptions instance.
"""
options = super().dict()

# Anthropic API requires max_tokens to be set
if isinstance(options["max_tokens"], AnthropicNotGiven) or options["max_tokens"] is None:
options["max_tokens"] = 4096

return options


class AnthropicClient(LLMClient[AnthropicOptions]):
"""
`AnthropicClient` is a class designed to interact with Anthropic's language model (LLM) endpoints,
particularly for the Claude models.

Args:
model_name: Name of the [Anthropic's model](https://docs.anthropic.com/claude/docs/models-overview) to be used,\
default is "claude-3-opus-20240229".
api_key: Anthropic's API key. If None ANTHROPIC_API_KEY environment variable will be used.
default_options: Default options to be used in the LLM calls.
"""

_options_cls = AnthropicOptions

def __init__(
self,
model_name: str = "claude-3-opus-20240229",
api_key: Optional[str] = None,
default_options: Optional[AnthropicOptions] = None,
) -> None:
try:
from anthropic import AsyncAnthropic # pylint: disable=import-outside-toplevel
except ImportError as exc:
raise ImportError("You need to install anthropic package to use Claude models") from exc

super().__init__(model_name=model_name, default_options=default_options)
self._client = AsyncAnthropic(api_key=api_key)

async def call(
self,
prompt: Union[str, ChatFormat],
response_format: Optional[Dict[str, str]],
options: AnthropicOptions,
event: LLMEvent,
) -> str:
"""
Calls the Anthropic API endpoint.

Args:
prompt: Prompt as an Anthropic client style list.
response_format: Optional argument used in the OpenAI API - used to force the json output.
[Review thread on the response_format line]

@ludwiktrammer (Contributor), May 16, 2024:
Not really a comment on this PR, but having an argument in AnthropicClient that is specific to OpenAI clearly illustrates that we should refactor this option away (both here and in the prompt object itself). In a separate PR/ticket, of course.

Maybe, for example, we could have an expect_json flag as a general boolean that each LLM client could implement in a way that makes sense for the particular model/API (by passing the appropriate option for that API, or even by post-processing).

@micpst (Collaborator, Author), May 16, 2024:
You're right. I was also thinking of moving the response_format argument to LLMOptions, which would make the interface cleaner. What do you think?

cc @mhordynski

@ludwiktrammer (Contributor), May 16, 2024:
I think response_format (or its non-OpenAI-specific equivalent) should remain coupled with the prompt (i.e., be part of the PromptTemplate class), because whether we expect JSON fully depends on the prompt itself.

[End of review thread]

options: Additional settings used by the LLM.
event: Container with the prompt, LLM response and call metrics.

Returns:
Response string from LLM.

Raises:
LLMConnectionError: If there was an error connecting to the LLM API.
LLMStatusError: If the LLM API returned an error status.
LLMResponseError: If the LLM API returned an invalid response.
"""
prompt, system = extract_system_prompt(prompt)

try:
response = await self._client.messages.create(
messages=prompt,
model=self.model_name,
system=system,
**options.dict(), # type: ignore
)
except APIConnectionError as exc:
raise LLMConnectionError() from exc
except APIStatusError as exc:
raise LLMStatusError(exc.message, exc.status_code) from exc
except APIResponseValidationError as exc:
raise LLMResponseError() from exc

event.completion_tokens = response.usage.output_tokens
event.prompt_tokens = response.usage.input_tokens
event.total_tokens = response.usage.output_tokens + response.usage.input_tokens

return response.content[0].text # type: ignore
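
For context, a minimal usage sketch of the new client. The wiring into a collection is mentioned only in a comment, because it is not part of this diff:

from dbally.llm_client.anthropic import AnthropicClient, AnthropicOptions

# Fields left as NOT_GIVEN fall back to the Anthropic API defaults, except
# max_tokens, which AnthropicOptions.dict() forces to 4096 when not provided,
# because the Messages API requires it.
options = AnthropicOptions(temperature=0.0, max_tokens=1024)

# With api_key=None the client falls back to the ANTHROPIC_API_KEY environment variable.
llm = AnthropicClient(model_name="claude-3-opus-20240229", default_options=options)

# The instance can then be passed wherever db-ally expects an LLMClient,
# for example when creating a collection with dbally.create_collection (illustrative).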
37 changes: 25 additions & 12 deletions src/dbally/llm_client/openai_client.py
@@ -2,21 +2,22 @@
from typing import ClassVar, Dict, List, Optional, Union

from openai import NOT_GIVEN as OPENAI_NOT_GIVEN
from openai import APIConnectionError, APIResponseValidationError, APIStatusError
from openai import NotGiven as OpenAINotGiven

from dbally.data_models.audit import LLMEvent
from dbally.llm_client.base import LLMClient
from dbally.llm_client.base import LLMClient, LLMOptions
from dbally.prompts import ChatFormat

from .._exceptions import LLMConnectionError, LLMResponseError, LLMStatusError
from .._types import NOT_GIVEN, NotGiven
from .base import LLMOptions


@dataclass
class OpenAIOptions(LLMOptions):
"""
Dataclass that represents all available LLM call options for the OpenAI API. Each of them is
described in the [OpenAI API documentation](https://platform.openai.com/docs/api-reference/chat/create.)
described in the [OpenAI API documentation](https://platform.openai.com/docs/api-reference/chat/create).
"""

_not_given: ClassVar[Optional[OpenAINotGiven]] = OPENAI_NOT_GIVEN
Expand All @@ -39,7 +40,7 @@ class OpenAIClient(LLMClient[OpenAIOptions]):
Args:
model_name: Name of the [OpenAI's model](https://platform.openai.com/docs/models) to be used,\
default is "gpt-3.5-turbo".
api_key: OpenAI's API key. If None OPENAI_API_KEY environment variable will be used
api_key: OpenAI's API key. If None OPENAI_API_KEY environment variable will be used.
default_options: Default options to be used in the LLM calls.
"""

@@ -71,25 +72,37 @@ async def call(

Args:
prompt: Prompt as an OpenAI client style list.
response_format: Optional argument used in the OpenAI API - used to force the json output
response_format: Optional argument used in the OpenAI API - used to force the json output.
options: Additional settings used by the LLM.
event: container with the prompt, LLM response and call metrics.
event: Container with the prompt, LLM response and call metrics.

Returns:
Response string from LLM.

Raises:
LLMConnectionError: If there is an issue with the connection to the LLM API.
LLMStatusError: If the LLM API returns an error status.
LLMResponseError: If there is an issue with the response from the LLM API.
"""

# only "turbo" models support response_format argument
# https://platform.openai.com/docs/api-reference/chat/create#chat-create-response_format
if "turbo" not in self.model_name:
response_format = None

response = await self._client.chat.completions.create(
messages=prompt,
model=self.model_name,
response_format=response_format,
**options.dict(), # type: ignore
)
try:
response = await self._client.chat.completions.create(
messages=prompt,
model=self.model_name,
response_format=response_format,
**options.dict(), # type: ignore
)
except APIConnectionError as exc:
raise LLMConnectionError() from exc
except APIStatusError as exc:
raise LLMStatusError(exc.message, exc.status_code) from exc
except APIResponseValidationError as exc:
raise LLMResponseError() from exc

event.completion_tokens = response.usage.completion_tokens
event.prompt_tokens = response.usage.prompt_tokens
12 changes: 10 additions & 2 deletions src/dbally/nl_responder/nl_responder.py
@@ -11,7 +11,11 @@
QueryExplainerPromptTemplate,
default_query_explainer_template,
)
from dbally.nl_responder.token_counters import count_tokens_for_huggingface, count_tokens_for_openai
from dbally.nl_responder.token_counters import (
count_tokens_for_anthropic,
count_tokens_for_huggingface,
count_tokens_for_openai,
)


class NLResponder:
@@ -74,7 +78,11 @@ async def generate_response(
fmt={"rows": rows, "question": question},
model=self._llm_client.model_name,
)

elif "claude" in self._llm_client.model_name:
tokens_count = await count_tokens_for_anthropic(
messages=self._nl_responder_prompt_template.chat,
fmt={"rows": rows, "question": question},
)
else:
tokens_count = count_tokens_for_huggingface(
messages=self._nl_responder_prompt_template.chat,
29 changes: 28 additions & 1 deletion src/dbally/nl_responder/token_counters.py
@@ -5,7 +5,7 @@

def count_tokens_for_openai(messages: ChatFormat, fmt: Dict[str, str], model: str) -> int:
"""
Counts the number of tokens in the messages for OpenAIs' models.
Counts the number of tokens in the messages for OpenAI's models.

Args:
messages: Messages to count tokens for.
@@ -34,6 +34,33 @@ def count_tokens_for_openai(messages: ChatFormat, fmt: Dict[str, str], model: st
return num_tokens


async def count_tokens_for_anthropic(messages: ChatFormat, fmt: Dict[str, str]) -> int:
"""
Counts the number of tokens in the messages for Anthropic's models.

Args:
messages: Messages to count tokens for.
fmt: Arguments to be used with prompt.

Returns:
Number of tokens in the messages.

Raises:
ImportError: If anthropic package is not installed.
"""

try:
from anthropic._tokenizers import async_get_tokenizer # pylint: disable=import-outside-toplevel
except ImportError as exc:
raise ImportError("You need to install anthropic package to use Claude models") from exc

tokenizer = await async_get_tokenizer()
num_tokens = 0
for message in messages:
num_tokens += len(tokenizer.encode(message["content"].format(**fmt)))
return num_tokens


def count_tokens_for_huggingface(messages: ChatFormat, fmt: Dict[str, str], model: str) -> int:
"""
Counts the number of tokens in the messages for models available on HuggingFace.
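Unlike its OpenAI counterpart, count_tokens_for_anthropic is a coroutine (the tokenizer is fetched asynchronously), so callers have to await it. A minimal sketch with illustrative messages and format arguments:

import asyncio

from dbally.nl_responder.token_counters import count_tokens_for_anthropic

messages = [
    {"role": "system", "content": "Answer based on the following rows: {rows}"},
    {"role": "user", "content": "{question}"},
]

tokens = asyncio.run(
    count_tokens_for_anthropic(
        messages=messages,
        fmt={"rows": "id=1, name=Alice", "question": "Who is in the table?"},
    )
)
print(tokens)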
20 changes: 20 additions & 0 deletions src/dbally/prompts/common_validation_utils.py
@@ -47,3 +47,23 @@ def check_prompt_variables(chat: ChatFormat, variables_to_check: Set[str]) -> Ch
"You need to format the following variables: {variables_to_check}"
)
return chat


def extract_system_prompt(chat: ChatFormat) -> Tuple[ChatFormat, str]:
"""
Extracts the system prompt from the chat.

Args:
chat: Chat to extract the system prompt from

Returns:
Chat without the system prompt and system prompt
"""
system_prompt = ""
chat_without_system = []
for message in chat:
if message["role"] == "system":
system_prompt = message["content"]
else:
chat_without_system.append(message)
return tuple(chat_without_system), system_prompt
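
A minimal sketch of the helper's behaviour; note that the remaining messages come back as a tuple rather than a list:

from dbally.prompts.common_validation_utils import extract_system_prompt

chat = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "List all matching candidates."},
]

remaining, system = extract_system_prompt(chat)
# remaining == ({"role": "user", "content": "List all matching candidates."},)
# system == "You are a helpful assistant."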