Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Integration with Chromadb #18

Merged
merged 14 commits into from
Apr 24, 2024
3 changes: 2 additions & 1 deletion .libraries-whitelist.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pkg_resources
psycopg2-binary
tiktoken
tiktoken
chroma-hnswlib
3 changes: 2 additions & 1 deletion .license-whitelist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,5 @@ Python Software Foundation License, MIT License
Unlicense
Proprietary License
Historical Permission Notice and Disclaimer (HPND)
ISC
ISC
Apache License v2.0
70 changes: 70 additions & 0 deletions docs/how-to/use_chromadb_store.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# How-To: Use Chromadb to Store Similarity Index

[ChromadbStore][dbally.similarity.ChromadbStore] allows to use [Chroma vector database](https://docs.trychroma.com/api-reference#methods-on-collection) as a store inside the [SimilarityIndex][dbally.similarity.SimilarityIndex]. With this feature, when someone searches for 'Show my flights to the USA' and we have 'United States' stored in our database as the country's value, the system will recognize the similarity and convert the query from 'get_flights(to="USA")' to 'get_flights(to="United States")'.


## Prerequisites

To use Chromadb with db-ally you need to install the chromadb extension

```python
pip install dbally[chromadb]
```

Let's say we have already implemented our [SimilarityFetcher](../how-to/use_custom_similarity_fetcher.md)

```python
class DummyCountryFetcher(SimilarityFetcher):
async def fetch(self):
return ["United States", "Canada", "Mexico"]
```

Next, we need to define `Chromadb.Client`. You can [run Chromadb on your local machine](https://docs.trychroma.com/usage-guide#initiating-a-persistent-chroma-client)

```python
import chromadb

chroma_client = chromadb.PersistentClient(path="/path/to/save/to")
```

or [set up Chromadb in the client/server mode](https://docs.trychroma.com/usage-guide#running-chroma-in-clientserver-mode)

```python
chroma_client = chromadb.HttpClient(host='localhost', port=8000)
```

Next, you can either use [one of dbally embedding clients][dbally.embedding_client.EmbeddingClient], such as [OpenAiEmbeddingClient][dbally.embedding_client.OpenAiEmbeddingClient]

```python
from dbally.embedding_client import OpenAiEmbeddingClient

embedding_client=OpenAiEmbeddingClient(
api_key="your-api-key",
)

```

or [Chromadb embedding functions](https://docs.trychroma.com/embeddings)

```
from chromadb.utils import embedding_functions
embedding_client = embedding_functions.DefaultEmbeddingFunction()
```

to define your [`ChromadbStore`][dbally.similarity.ChromadbStore].

```python
store = ChromadbStore(index_name="myChromaIndex", chroma_client=chroma_client, embedding_function=embedding_client)
```

After this setup, you can initialize the SimilarityIndex

```python
from typing import Annotated

country_similarity = SimilarityIndex(store, DummyCountryFetcher())


```
ds-sebastianchwilczynski marked this conversation as resolved.
Show resolved Hide resolved

and [update it and find the closest matches in the same way as in built-in similarity indices](./use_custom_similarity_store.md/#using-the-similarity-index) .
6 changes: 3 additions & 3 deletions docs/how-to/use_custom_similarity_fetcher.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@ To use a similarity index with data from a custom source, you need to create a c

## Creating a Custom Fetcher

To craft a custom fetcher, you need to create a class extending the `AbstractFetcher` class provided by db-ally. The `AbstractFetcher` class possesses a single asynchronous method, `fetch`, which you need to implement. This method should give back a list of strings representing all possible values from your data source.
To craft a custom fetcher, you need to create a class extending the `SimilarityFetcher` class provided by db-ally. The `SimilarityFetcher` class possesses a single asynchronous method, `fetch`, which you need to implement. This method should give back a list of strings representing all possible values from your data source.

For example, if you wish to index the list of dog breeds from the web API provided by [dog.ceo](https://dog.ceo/), you can create a fetcher like this:

```python
from dbally.similarity.fetcher import AbstractFetcher
from dbally.similarity.fetcher import SimilarityFetcher
import requests

class DogBreedsFetcher(AbstractFetcher):
class DogBreedsFetcher(SimilarityFetcher):
async def fetch(self):
response = requests.get('https://dog.ceo/api/breeds/list/all').json()
breeds = response['message'].keys()
Expand Down
7 changes: 7 additions & 0 deletions docs/reference/similarity/similarity_store/chroma.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#ChromadbStore

!!! info
To see example of using ChromadbStore visit: [How-To: Use Chromadb to Store Similarity Index](../../../how-to/use_chromadb_store.md)


::: dbally.similarity.ChromadbStore
2 changes: 2 additions & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ nav:
- how-to/custom_views.md
- Using similarity indexes:
- how-to/use_custom_similarity_fetcher.md
- how-to/use_chromadb_store.md
- how-to/use_custom_similarity_store.md
- how-to/update_similarity_indexes.md
- how-to/log_runs_to_langsmith.md
Expand Down Expand Up @@ -54,6 +55,7 @@ nav:
- Store:
- reference/similarity/similarity_store/index.md
- reference/similarity/similarity_store/faiss.md
- reference/similarity/similarity_store/chroma.md
- Fetcher:
- reference/similarity/similarity_fetcher/index.md
- reference/similarity/similarity_fetcher/sqlalchemy.md
Expand Down
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Requirements as needed for development for this project.
# ---------------------------------------------------------
# Install current project
-e.[openai,transformers]
-e.[openai,transformers,chromadb]
# developer tools:
pre-commit
pytest>=6.2.5
Expand Down
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ benchmark =
pydantic-core~=2.16.2
pydantic-settings~=2.0.3
psycopg2-binary~=2.9.9
chromadb =
chromadb>=0.4.24

[options.packages.find]
where = src
Expand Down
6 changes: 6 additions & 0 deletions src/dbally/similarity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@
except ImportError:
pass

try:
from .chroma_store import ChromadbStore
except ImportError:
pass

__all__ = [
"AbstractSimilarityIndex",
"SimilarityIndex",
Expand All @@ -17,4 +22,5 @@
"SimilarityStore",
"SimilarityFetcher",
"FaissStore",
"ChromadbStore",
]
96 changes: 96 additions & 0 deletions src/dbally/similarity/chroma_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
from hashlib import sha256
from typing import List, Literal, Optional, Union

import chromadb

from dbally.embedding_client.base import EmbeddingClient
from dbally.similarity.store import SimilarityStore


class ChromadbStore(SimilarityStore):
"""Class that stores text embeddings using [Chroma](https://docs.trychroma.com/)"""

def __init__(
self,
index_name: str,
chroma_client: chromadb.Client,
embedding_function: Union[EmbeddingClient, chromadb.EmbeddingFunction],
max_distance: Optional[float] = None,
distance_method: Literal["l2", "ip", "cosine"] = "l2",
):
super().__init__()
self.index_name = index_name
self.chroma_client = chroma_client
self.embedding_function = embedding_function
self.max_distance = max_distance

self._metadata = {"hnsw:space": distance_method}

def _get_chroma_collection(self) -> chromadb.Collection:
"""Based on the selected embedding_function chooses how to retrieve the Chromadb collection.
If collection doesn't exist it creates one.

Returns:
chromadb.Collection: Retrieved collection
"""
if isinstance(self.embedding_function, EmbeddingClient):
return self.chroma_client.get_or_create_collection(name=self.index_name, metadata=self._metadata)

return self.chroma_client.get_or_create_collection(
name=self.index_name, metadata=self._metadata, embedding_function=self.embedding_function
)

def _return_best_match(self, retrieved: dict) -> Optional[str]:
"""Based on the retrieved data returns the best match or None if no match is found.

Args:
retrieved: Retrieved data, with a column first format

Returns:
The best match or None if no match is found
"""
if self.max_distance is None or retrieved["distances"][0][0] <= self.max_distance:
return retrieved["documents"][0][0]

return None

async def store(self, data: List[str]) -> None:
"""
Fills chroma collection with embeddings of provided string. As the id uses hash value of the string.

Args:
data: The data to store.
"""

ids = [sha256(x.encode("utf-8")).hexdigest() for x in data]

collection = self._get_chroma_collection()

if isinstance(self.embedding_function, EmbeddingClient):
embeddings = await self.embedding_function.get_embeddings(data)

collection.add(ids=ids, embeddings=embeddings, documents=data)
else:
collection.add(ids=ids, documents=data)

async def find_similar(self, text: str) -> Optional[str]:
"""
Finds the most similar text in the chroma collection or returns None if the most similar text
has distance bigger than `self.max_distance`.

Args:
text: The text to find similar to.

Returns:
The most similar text or None if no similar text is found.
"""

collection = self._get_chroma_collection()

if isinstance(self.embedding_function, EmbeddingClient):
embedding = await self.embedding_function.get_embeddings([text])
retrieved = collection.query(query_embeddings=embedding, n_results=1)
else:
retrieved = collection.query(query_texts=[text], n_results=1)

return self._return_best_match(retrieved)
45 changes: 45 additions & 0 deletions tests/integration/test_index_with_chroma.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import chromadb
import pytest
from chromadb import Documents, EmbeddingFunction, Embeddings

from dbally.embedding_client.base import EmbeddingClient
from dbally.similarity import ChromadbStore
from dbally.similarity.fetcher import SimilarityFetcher
from dbally.similarity.index import SimilarityIndex


class DummyCountryFetcher(SimilarityFetcher):
async def fetch(self):
return ["United States", "Canada", "Mexico"]


MAPPING = {"United States": [1, 1, 1], "Canada": [-1, -1, -1], "Mexico": [1, 1, -2], "USA": [0.2, 1, 1]}


class DummyEmbeddingClient(EmbeddingClient):
async def get_embeddings(self, data):
return [MAPPING[d] for d in data]


class DummyEmbeddingFunction(EmbeddingFunction):
def __call__(self, input: Documents) -> Embeddings:
return [MAPPING[d] for d in input]


@pytest.mark.asyncio
@pytest.mark.parametrize("embedding_function", [DummyEmbeddingClient(), DummyEmbeddingFunction()])
async def test_integration_embedding_client(embedding_function):
chroma_client = chromadb.Client()

store = ChromadbStore(index_name="test", chroma_client=chroma_client, embedding_function=embedding_function)
fetcher = DummyCountryFetcher()

index = SimilarityIndex(store, fetcher)

await index.update()

assert store._get_chroma_collection().count() == 3

similar = await index.similar("USA")

assert similar == "United States"
Loading
Loading