feat(embedding): add litellm embeddings #37

Merged · 4 commits · May 21, 2024
6 changes: 3 additions & 3 deletions docs/how-to/use_chromadb_store.md
@@ -33,12 +33,12 @@ or [set up Chromadb in the client/server mode](https://docs.trychroma.com/usage-
 chroma_client = chromadb.HttpClient(host='localhost', port=8000)
 ```
 
-Next, you can either use [one of dbally embedding clients][dbally.embedding_client.EmbeddingClient], such as [OpenAiEmbeddingClient][dbally.embedding_client.OpenAiEmbeddingClient]
+Next, you can either use [one of dbally embedding clients][dbally.embeddings.EmbeddingClient], such as [LiteLLMEmbeddingClient][dbally.embeddings.LiteLLMEmbeddingClient]
 
 ```python
-from dbally.embedding_client import OpenAiEmbeddingClient
+from dbally.embeddings.litellm import LiteLLMEmbeddingClient
 
-embedding_client=OpenAiEmbeddingClient(
+embedding_client=LiteLLMEmbeddingClient(
     api_key="your-api-key",
 )
4 changes: 2 additions & 2 deletions docs/how-to/use_custom_similarity_fetcher.md
@@ -43,13 +43,13 @@ breeds_similarity = SimilarityIndex(
     store=FaissStore(
         index_dir="./similarity_indexes",
         index_name="breeds_similarity",
-        embedding_client=OpenAiEmbeddingClient(
+        embedding_client=LiteLLMEmbeddingClient(
             api_key="your-api-key",
         )
     )
 ```
 
-In this example, we used the FaissStore, which utilizes the `faiss` library for rapid similarity search. We also employed the `OpenAiEmbeddingClient` to get the semantic embeddings for the dog breeds. Depending on your needs, you can use a different built-in store or create [a custom one](../how-to/use_custom_similarity_store.md).
+In this example, we used the FaissStore, which utilizes the `faiss` library for rapid similarity search. We also employed the `LiteLLMEmbeddingClient` to get the semantic embeddings for the dog breeds. Depending on your needs, you can use a different built-in store or create [a custom one](../how-to/use_custom_similarity_store.md).
 
 ## Using the Similarity Index
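The similarity-store idea behind this hunk can be illustrated standalone: embed the stored strings, then return the entry nearest to a query embedding. The `cosine`/`nearest` helpers and the toy 2-d vectors below are hypothetical illustrations, not the faiss or dbally API.

```python
# Conceptual sketch of what a similarity store does: keep an embedding per
# stored string and return the stored string closest to a query embedding.
# Plain Python for illustration only; faiss does this at scale.
from math import sqrt
from typing import Dict, List


def cosine(a: List[float], b: List[float]) -> float:
    # Cosine similarity: dot product divided by the vector norms.
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (sqrt(sum(x * x for x in a)) * sqrt(sum(x * x for x in b)))


def nearest(query: List[float], index: Dict[str, List[float]]) -> str:
    # Pick the stored key whose embedding is most similar to the query.
    return max(index, key=lambda key: cosine(query, index[key]))


# Hypothetical 2-d "embeddings" for three dog breeds.
index = {"labrador": [1.0, 0.1], "poodle": [0.2, 1.0], "beagle": [0.9, 0.3]}
print(nearest([1.0, 0.2], index))  # labrador
```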
6 changes: 3 additions & 3 deletions docs/quickstart/quickstart2.md
@@ -60,18 +60,18 @@ Next, let's define a store that will store the country names and can be used to
 
 ```python
 from dbally.similarity import FaissStore
-from dbally.embedding_client.openai import OpenAiEmbeddingClient
+from dbally.embeddings.litellm import LiteLLMEmbeddingClient
 
 country_store = FaissStore(
     index_dir="./similarity_indexes",
     index_name="country_similarity",
-    embedding_client=OpenAiEmbeddingClient(
+    embedding_client=LiteLLMEmbeddingClient(
         api_key="your-api-key",
     )
 )
 ```
 
-In this example, we used the `FaissStore` store, which employs the `faiss` library for fast similarity search. We also used the `OpenAiEmbeddingClient` to get the semantic embeddings for the country names. Replace `your-api-key` with your OpenAI API key.
+In this example, we used the `FaissStore` store, which employs the `faiss` library for fast similarity search. We also used the `LiteLLMEmbeddingClient` to get the semantic embeddings for the country names. Replace `your-api-key` with your OpenAI API key.
 
 Finally, let's define the similarity index:
4 changes: 2 additions & 2 deletions docs/quickstart/quickstart2_code.py
@@ -11,7 +11,7 @@
 from dbally import decorators, SqlAlchemyBaseView
 from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler
 from dbally.similarity import SimpleSqlAlchemyFetcher, FaissStore, SimilarityIndex
-from dbally.embedding_client.openai import OpenAiEmbeddingClient
+from dbally.embeddings.litellm import LiteLLMEmbeddingClient
 from dbally.llms.litellm import LiteLLM
 
 engine = create_engine('sqlite:///candidates.db')
@@ -30,7 +30,7 @@
     store=FaissStore(
         index_dir="./similarity_indexes",
         index_name="country_similarity",
-        embedding_client=OpenAiEmbeddingClient(
+        embedding_client=LiteLLMEmbeddingClient(
             api_key=os.environ["OPENAI_API_KEY"],
         )
     ),
4 changes: 2 additions & 2 deletions docs/quickstart/quickstart3_code.py
@@ -11,7 +11,7 @@
 
 from dbally import decorators, SqlAlchemyBaseView, DataFrameBaseView, ExecutionResult
 from dbally.similarity import SimpleSqlAlchemyFetcher, FaissStore, SimilarityIndex
-from dbally.embedding_client.openai import OpenAiEmbeddingClient
+from dbally.embeddings.litellm import LiteLLMEmbeddingClient
 from dbally.llms.litellm import LiteLLM
 
 engine = create_engine('sqlite:///candidates.db')
@@ -31,7 +31,7 @@
     store=FaissStore(
         index_dir="./similarity_indexes",
         index_name="country_similarity",
-        embedding_client=OpenAiEmbeddingClient(
+        embedding_client=LiteLLMEmbeddingClient(
             api_key=os.environ["OPENAI_API_KEY"],
         )
     ),
23 changes: 23 additions & 0 deletions src/dbally/__init__.py
@@ -12,6 +12,12 @@
 from ._main import create_collection
 from ._types import NOT_GIVEN, NotGiven
 from .collection import Collection
+from .embeddings._exceptions import (
+    EmbeddingConnectionError,
+    EmbeddingError,
+    EmbeddingResponseError,
+    EmbeddingStatusError,
+)
 from .llms.clients._exceptions import LLMConnectionError, LLMError, LLMResponseError, LLMStatusError
 
 __all__ = [
@@ -25,10 +31,27 @@
     "DataFrameBaseView",
     "ExecutionResult",
     "DbAllyError",
+    "EmbeddingError",
+    "EmbeddingConnectionError",
+    "EmbeddingResponseError",
+    "EmbeddingStatusError",
     "LLMError",
     "LLMConnectionError",
     "LLMResponseError",
     "LLMStatusError",
     "NotGiven",
     "NOT_GIVEN",
 ]
+
+# Update the __module__ attribute for exported symbols so that
+# error messages point to this module instead of the module
+# it was originally defined in, e.g.
+# dbally._exceptions.LLMError -> dbally.LLMError
+__locals = locals()
+for __name in __all__:
+    if not __name.startswith("__"):
+        try:
+            __locals[__name].__module__ = "dbally"
+        except (TypeError, AttributeError):
+            # Some of our exported symbols are builtins which we can't set attributes for.
+            pass
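The `__module__` rewrite added above can be exercised in isolation. In the sketch below, `Boom` is a hypothetical stand-in for one of the re-exported exception classes; the loop mirrors the one in `dbally/__init__.py`.

```python
# Standalone sketch of the __module__ rewrite performed in dbally/__init__.py.
# Boom is a hypothetical stand-in for a re-exported exception class.

class Boom(Exception):
    """Pretend this was defined in a private submodule."""

__all__ = ["Boom"]

# Point each exported symbol at the public package name, as the loop above does.
_locals = locals()
for _name in __all__:
    if not _name.startswith("__"):
        try:
            _locals[_name].__module__ = "dbally"
        except (TypeError, AttributeError):
            # Builtins don't allow setting attributes; skip them.
            pass

# The class now reports the public module path in reprs and tracebacks.
print(Boom.__module__)  # -> dbally
```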
4 changes: 0 additions & 4 deletions src/dbally/embedding_client/__init__.py

This file was deleted.

19 changes: 0 additions & 19 deletions src/dbally/embedding_client/base.py

This file was deleted.

52 changes: 0 additions & 52 deletions src/dbally/embedding_client/openai.py

This file was deleted.

4 changes: 4 additions & 0 deletions src/dbally/embeddings/__init__.py
@@ -0,0 +1,4 @@
+from .base import EmbeddingClient
+from .litellm import LiteLLMEmbeddingClient
+
+__all__ = ["EmbeddingClient", "LiteLLMEmbeddingClient"]
39 changes: 39 additions & 0 deletions src/dbally/embeddings/_exceptions.py
@@ -0,0 +1,39 @@
+from .._exceptions import DbAllyError
+
+
+class EmbeddingError(DbAllyError):
+    """
+    Base class for all exceptions raised by the EmbeddingClient.
+    """
+
+    def __init__(self, message: str) -> None:
+        super().__init__(message)
+        self.message = message
+
+
+class EmbeddingConnectionError(EmbeddingError):
+    """
+    Raised when there is an error connecting to the embedding API.
+    """
+
+    def __init__(self, message: str = "Connection error.") -> None:
+        super().__init__(message)
+
+
+class EmbeddingStatusError(EmbeddingError):
+    """
+    Raised when an API response has a status code of 4xx or 5xx.
+    """
+
+    def __init__(self, message: str, status_code: int) -> None:
+        super().__init__(message)
+        self.status_code = status_code
+
+
+class EmbeddingResponseError(EmbeddingError):
+    """
+    Raised when an API response has an invalid schema.
+    """
+
+    def __init__(self, message: str = "Data returned by API invalid for expected schema.") -> None:
+        super().__init__(message)
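To illustrate how callers can branch on the new hierarchy, here is a self-contained mirror of the classes above; `DbAllyError` is stubbed as a plain `Exception` and `classify` is a hypothetical helper, both for illustration only.

```python
# Self-contained mirror of the exception hierarchy above; in the package,
# DbAllyError comes from dbally._exceptions.

class DbAllyError(Exception):
    pass

class EmbeddingError(DbAllyError):
    def __init__(self, message: str) -> None:
        super().__init__(message)
        self.message = message

class EmbeddingStatusError(EmbeddingError):
    def __init__(self, message: str, status_code: int) -> None:
        super().__init__(message)
        self.status_code = status_code

def classify(exc: EmbeddingError) -> str:
    # A single `except EmbeddingError` catches every subclass; the HTTP
    # status code is only available on EmbeddingStatusError.
    if isinstance(exc, EmbeddingStatusError):
        return f"status {exc.status_code}"
    return "generic"

print(classify(EmbeddingStatusError("rate limited", 429)))  # status 429
print(classify(EmbeddingError("boom")))                     # generic
```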
20 changes: 20 additions & 0 deletions src/dbally/embeddings/base.py
@@ -0,0 +1,20 @@
+from abc import ABC, abstractmethod
+from typing import List
+
+
+class EmbeddingClient(ABC):
+    """
+    Abstract client for creating text embeddings.
+    """
+
+    @abstractmethod
+    async def get_embeddings(self, data: List[str]) -> List[List[float]]:
+        """
+        Creates embeddings for the given strings.
+
+        Args:
+            data: List of strings to get embeddings for.
+
+        Returns:
+            List of embeddings for the given strings.
+        """
85 changes: 85 additions & 0 deletions src/dbally/embeddings/litellm.py
@@ -0,0 +1,85 @@
+from typing import Dict, List, Optional
+
+try:
+    import litellm
+
+    HAVE_LITELLM = True
+except ImportError:
+    HAVE_LITELLM = False
+
+from dbally.embeddings.base import EmbeddingClient
+
+from ._exceptions import EmbeddingConnectionError, EmbeddingResponseError, EmbeddingStatusError
+
+
+class LiteLLMEmbeddingClient(EmbeddingClient):
+    """
+    Client for creating text embeddings using LiteLLM API.
+    """
+
+    def __init__(
+        self,
+        model: str = "text-embedding-3-small",
+        options: Optional[Dict] = None,
+        api_base: Optional[str] = None,
+        api_key: Optional[str] = None,
+        api_version: Optional[str] = None,
+    ) -> None:
+        """
+        Constructs the LiteLLMEmbeddingClient.
+
+        Args:
+            model: Name of the [LiteLLM supported model](https://docs.litellm.ai/docs/embedding/supported_embedding)\
+                to be used. Default is "text-embedding-3-small".
+            options: Additional options to pass to the LiteLLM API.
+            api_base: The API endpoint you want to call the model with.
+            api_key: API key to be used. If not specified, an environment variable will be used;\
+                for more information, follow the instructions for your specific vendor in the\
+                [LiteLLM documentation](https://docs.litellm.ai/docs/embedding/supported_embedding).
+            api_version: The API version for the call.
+
+        Raises:
+            ImportError: If the litellm package is not installed.
+        """
+        if not HAVE_LITELLM:
+            raise ImportError("You need to install litellm package to use LiteLLM models")
+
+        super().__init__()
+        self.model = model
+        self.options = options or {}
+        self.api_base = api_base
+        self.api_key = api_key
+        self.api_version = api_version
+
+    async def get_embeddings(self, data: List[str]) -> List[List[float]]:
+        """
+        Creates embeddings for the given strings.
+
+        Args:
+            data: List of strings to get embeddings for.
+
+        Returns:
+            List of embeddings for the given strings.
+
+        Raises:
+            EmbeddingConnectionError: If there is a connection error with the embedding API.
+            EmbeddingStatusError: If the embedding API returns an error status code.
+            EmbeddingResponseError: If the embedding API response is invalid.
+        """
+        try:
+            response = await litellm.aembedding(
+                input=data,
+                model=self.model,
+                api_base=self.api_base,
+                api_key=self.api_key,
+                api_version=self.api_version,
+                **self.options,
+            )
+        except litellm.openai.APIConnectionError as exc:
+            raise EmbeddingConnectionError() from exc
+        except litellm.openai.APIStatusError as exc:
+            raise EmbeddingStatusError(exc.message, exc.status_code) from exc
+        except litellm.openai.APIResponseValidationError as exc:
+            raise EmbeddingResponseError() from exc
+
+        return [embedding["embedding"] for embedding in response.data]
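The error-translation pattern in `get_embeddings` can be sketched standalone. Below, `FakeAPIStatusError` and `fake_aembedding` are hypothetical stand-ins for the litellm/OpenAI exception types and the real API call; the point is how vendor errors are re-raised as dbally-level errors with the original traceback chained.

```python
import asyncio

# dbally-level exceptions (simplified mirrors of the ones in this PR).
class EmbeddingError(Exception):
    pass

class EmbeddingStatusError(EmbeddingError):
    def __init__(self, message: str, status_code: int) -> None:
        super().__init__(message)
        self.status_code = status_code

# Hypothetical stand-in for a vendor exception such as openai.APIStatusError.
class FakeAPIStatusError(Exception):
    def __init__(self, message: str, status_code: int) -> None:
        super().__init__(message)
        self.message = message
        self.status_code = status_code

async def fake_aembedding(input):
    # Stand-in for litellm.aembedding; always fails with a 429 here.
    raise FakeAPIStatusError("quota exceeded", 429)

async def get_embeddings(data):
    try:
        return await fake_aembedding(input=data)
    except FakeAPIStatusError as exc:
        # Re-raise as a dbally-level error; `from exc` chains the original
        # exception so the vendor traceback is preserved.
        raise EmbeddingStatusError(exc.message, exc.status_code) from exc

try:
    asyncio.run(get_embeddings(["hello"]))
except EmbeddingStatusError as exc:
    print(exc.status_code)  # 429
```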