diff --git a/.libraries-whitelist.txt b/.libraries-whitelist.txt index ac3a8761..468e4175 100644 --- a/.libraries-whitelist.txt +++ b/.libraries-whitelist.txt @@ -1,3 +1,4 @@ pkg_resources psycopg2-binary -tiktoken \ No newline at end of file +tiktoken +chroma-hnswlib \ No newline at end of file diff --git a/.license-whitelist.txt b/.license-whitelist.txt index 8a46f385..c01edb7b 100644 --- a/.license-whitelist.txt +++ b/.license-whitelist.txt @@ -22,4 +22,5 @@ Python Software Foundation License, MIT License Unlicense Proprietary License Historical Permission Notice and Disclaimer (HPND) -ISC \ No newline at end of file +ISC +Apache License v2.0 \ No newline at end of file diff --git a/docs/how-to/use_chromadb_store.md b/docs/how-to/use_chromadb_store.md new file mode 100644 index 00000000..b715d8e8 --- /dev/null +++ b/docs/how-to/use_chromadb_store.md @@ -0,0 +1,70 @@ +# How-To: Use Chromadb to Store Similarity Index + +[ChromadbStore][dbally.similarity.ChromadbStore] allows to use [Chroma vector database](https://docs.trychroma.com/api-reference#methods-on-collection) as a store inside the [SimilarityIndex][dbally.similarity.SimilarityIndex]. With this feature, when someone searches for 'Show my flights to the USA' and we have 'United States' stored in our database as the country's value, the system will recognize the similarity and convert the query from 'get_flights(to="USA")' to 'get_flights(to="United States")'. + + +## Prerequisites + +To use Chromadb with db-ally you need to install the chromadb extension + +```python +pip install dbally[chromadb] +``` + +Let's say we have already implemented our [SimilarityFetcher](../how-to/use_custom_similarity_fetcher.md) + +```python +class DummyCountryFetcher(SimilarityFetcher): + async def fetch(self): + return ["United States", "Canada", "Mexico"] +``` + +Next, we need to define `Chromadb.Client`. You can [run Chromadb on your local machine](https://docs.trychroma.com/usage-guide#initiating-a-persistent-chroma-client) + +```python +import chromadb + +chroma_client = chromadb.PersistentClient(path="/path/to/save/to") +``` + +or [set up Chromadb in the client/server mode](https://docs.trychroma.com/usage-guide#running-chroma-in-clientserver-mode) + +```python +chroma_client = chromadb.HttpClient(host='localhost', port=8000) +``` + +Next, you can either use [one of dbally embedding clients][dbally.embedding_client.EmbeddingClient], such as [OpenAiEmbeddingClient][dbally.embedding_client.OpenAiEmbeddingClient] + +```python +from dbally.embedding_client import OpenAiEmbeddingClient + +embedding_client=OpenAiEmbeddingClient( + api_key="your-api-key", + ) + +``` + +or [Chromadb embedding functions](https://docs.trychroma.com/embeddings) + +``` +from chromadb.utils import embedding_functions +embedding_client = embedding_functions.DefaultEmbeddingFunction() +``` + +to define your [`ChromadbStore`][dbally.similarity.ChromadbStore]. + +```python +store = ChromadbStore(index_name="myChromaIndex", chroma_client=chroma_client, embedding_function=embedding_client) +``` + +After this setup, you can initialize the SimilarityIndex + +```python +from typing import Annotated + +country_similarity = SimilarityIndex(store, DummyCountryFetcher()) + + +``` + +and [update it and find the closest matches in the same way as in built-in similarity indices](./use_custom_similarity_store.md/#using-the-similarity-index) . diff --git a/docs/how-to/use_custom_similarity_fetcher.md b/docs/how-to/use_custom_similarity_fetcher.md index 0e5feb53..06a01197 100644 --- a/docs/how-to/use_custom_similarity_fetcher.md +++ b/docs/how-to/use_custom_similarity_fetcher.md @@ -13,15 +13,15 @@ To use a similarity index with data from a custom source, you need to create a c ## Creating a Custom Fetcher -To craft a custom fetcher, you need to create a class extending the `AbstractFetcher` class provided by db-ally. The `AbstractFetcher` class possesses a single asynchronous method, `fetch`, which you need to implement. This method should give back a list of strings representing all possible values from your data source. +To craft a custom fetcher, you need to create a class extending the `SimilarityFetcher` class provided by db-ally. The `SimilarityFetcher` class possesses a single asynchronous method, `fetch`, which you need to implement. This method should give back a list of strings representing all possible values from your data source. For example, if you wish to index the list of dog breeds from the web API provided by [dog.ceo](https://dog.ceo/), you can create a fetcher like this: ```python -from dbally.similarity.fetcher import AbstractFetcher +from dbally.similarity.fetcher import SimilarityFetcher import requests -class DogBreedsFetcher(AbstractFetcher): +class DogBreedsFetcher(SimilarityFetcher): async def fetch(self): response = requests.get('https://dog.ceo/api/breeds/list/all').json() breeds = response['message'].keys() diff --git a/docs/reference/similarity/similarity_store/chroma.md b/docs/reference/similarity/similarity_store/chroma.md new file mode 100644 index 00000000..542eb3fe --- /dev/null +++ b/docs/reference/similarity/similarity_store/chroma.md @@ -0,0 +1,7 @@ +#ChromadbStore + +!!! info + To see example of using ChromadbStore visit: [How-To: Use Chromadb to Store Similarity Index](../../../how-to/use_chromadb_store.md) + + +::: dbally.similarity.ChromadbStore \ No newline at end of file diff --git a/mkdocs.yml b/mkdocs.yml index abd95262..978efa67 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -21,6 +21,7 @@ nav: - how-to/custom_views.md - Using similarity indexes: - how-to/use_custom_similarity_fetcher.md + - how-to/use_chromadb_store.md - how-to/use_custom_similarity_store.md - how-to/update_similarity_indexes.md - how-to/log_runs_to_langsmith.md @@ -54,6 +55,7 @@ nav: - Store: - reference/similarity/similarity_store/index.md - reference/similarity/similarity_store/faiss.md + - reference/similarity/similarity_store/chroma.md - Fetcher: - reference/similarity/similarity_fetcher/index.md - reference/similarity/similarity_fetcher/sqlalchemy.md diff --git a/requirements-dev.txt b/requirements-dev.txt index 07b7382b..6f92eed8 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,7 +1,7 @@ # Requirements as needed for development for this project. # --------------------------------------------------------- # Install current project --e.[openai,transformers] +-e.[openai,transformers,chromadb] # developer tools: pre-commit pytest>=6.2.5 diff --git a/setup.cfg b/setup.cfg index 471f708e..ebcb0171 100644 --- a/setup.cfg +++ b/setup.cfg @@ -62,6 +62,8 @@ benchmark = pydantic-core~=2.16.2 pydantic-settings~=2.0.3 psycopg2-binary~=2.9.9 +chromadb = + chromadb>=0.4.24 [options.packages.find] where = src diff --git a/src/dbally/similarity/__init__.py b/src/dbally/similarity/__init__.py index 2c6f0d26..5ce05035 100644 --- a/src/dbally/similarity/__init__.py +++ b/src/dbally/similarity/__init__.py @@ -9,6 +9,11 @@ except ImportError: pass +try: + from .chroma_store import ChromadbStore +except ImportError: + pass + __all__ = [ "AbstractSimilarityIndex", "SimilarityIndex", @@ -17,4 +22,5 @@ "SimilarityStore", "SimilarityFetcher", "FaissStore", + "ChromadbStore", ] diff --git a/src/dbally/similarity/chroma_store.py b/src/dbally/similarity/chroma_store.py new file mode 100644 index 00000000..f1684a44 --- /dev/null +++ b/src/dbally/similarity/chroma_store.py @@ -0,0 +1,96 @@ +from hashlib import sha256 +from typing import List, Literal, Optional, Union + +import chromadb + +from dbally.embedding_client.base import EmbeddingClient +from dbally.similarity.store import SimilarityStore + + +class ChromadbStore(SimilarityStore): + """Class that stores text embeddings using [Chroma](https://docs.trychroma.com/)""" + + def __init__( + self, + index_name: str, + chroma_client: chromadb.Client, + embedding_function: Union[EmbeddingClient, chromadb.EmbeddingFunction], + max_distance: Optional[float] = None, + distance_method: Literal["l2", "ip", "cosine"] = "l2", + ): + super().__init__() + self.index_name = index_name + self.chroma_client = chroma_client + self.embedding_function = embedding_function + self.max_distance = max_distance + + self._metadata = {"hnsw:space": distance_method} + + def _get_chroma_collection(self) -> chromadb.Collection: + """Based on the selected embedding_function chooses how to retrieve the Chromadb collection. + If collection doesn't exist it creates one. + + Returns: + chromadb.Collection: Retrieved collection + """ + if isinstance(self.embedding_function, EmbeddingClient): + return self.chroma_client.get_or_create_collection(name=self.index_name, metadata=self._metadata) + + return self.chroma_client.get_or_create_collection( + name=self.index_name, metadata=self._metadata, embedding_function=self.embedding_function + ) + + def _return_best_match(self, retrieved: dict) -> Optional[str]: + """Based on the retrieved data returns the best match or None if no match is found. + + Args: + retrieved: Retrieved data, with a column first format + + Returns: + The best match or None if no match is found + """ + if self.max_distance is None or retrieved["distances"][0][0] <= self.max_distance: + return retrieved["documents"][0][0] + + return None + + async def store(self, data: List[str]) -> None: + """ + Fills chroma collection with embeddings of provided string. As the id uses hash value of the string. + + Args: + data: The data to store. + """ + + ids = [sha256(x.encode("utf-8")).hexdigest() for x in data] + + collection = self._get_chroma_collection() + + if isinstance(self.embedding_function, EmbeddingClient): + embeddings = await self.embedding_function.get_embeddings(data) + + collection.add(ids=ids, embeddings=embeddings, documents=data) + else: + collection.add(ids=ids, documents=data) + + async def find_similar(self, text: str) -> Optional[str]: + """ + Finds the most similar text in the chroma collection or returns None if the most similar text + has distance bigger than `self.max_distance`. + + Args: + text: The text to find similar to. + + Returns: + The most similar text or None if no similar text is found. + """ + + collection = self._get_chroma_collection() + + if isinstance(self.embedding_function, EmbeddingClient): + embedding = await self.embedding_function.get_embeddings([text]) + retrieved = collection.query(query_embeddings=embedding, n_results=1) + else: + retrieved = collection.query(query_texts=[text], n_results=1) + + return self._return_best_match(retrieved) diff --git a/tests/integration/test_index_with_chroma.py b/tests/integration/test_index_with_chroma.py new file mode 100644 index 00000000..b28ec1db --- /dev/null +++ b/tests/integration/test_index_with_chroma.py @@ -0,0 +1,45 @@ +import chromadb +import pytest +from chromadb import Documents, EmbeddingFunction, Embeddings + +from dbally.embedding_client.base import EmbeddingClient +from dbally.similarity import ChromadbStore +from dbally.similarity.fetcher import SimilarityFetcher +from dbally.similarity.index import SimilarityIndex + + +class DummyCountryFetcher(SimilarityFetcher): + async def fetch(self): + return ["United States", "Canada", "Mexico"] + + +MAPPING = {"United States": [1, 1, 1], "Canada": [-1, -1, -1], "Mexico": [1, 1, -2], "USA": [0.2, 1, 1]} + + +class DummyEmbeddingClient(EmbeddingClient): + async def get_embeddings(self, data): + return [MAPPING[d] for d in data] + + +class DummyEmbeddingFunction(EmbeddingFunction): + def __call__(self, input: Documents) -> Embeddings: + return [MAPPING[d] for d in input] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("embedding_function", [DummyEmbeddingClient(), DummyEmbeddingFunction()]) +async def test_integration_embedding_client(embedding_function): + chroma_client = chromadb.Client() + + store = ChromadbStore(index_name="test", chroma_client=chroma_client, embedding_function=embedding_function) + fetcher = DummyCountryFetcher() + + index = SimilarityIndex(store, fetcher) + + await index.update() + + assert store._get_chroma_collection().count() == 3 + + similar = await index.similar("USA") + + assert similar == "United States" diff --git a/tests/unit/similarity/test_chroma.py b/tests/unit/similarity/test_chroma.py new file mode 100644 index 00000000..f54301ed --- /dev/null +++ b/tests/unit/similarity/test_chroma.py @@ -0,0 +1,114 @@ +from hashlib import sha256 +from unittest.mock import AsyncMock, Mock + +import chromadb +import pytest + +from dbally.embedding_client import EmbeddingClient +from dbally.similarity import ChromadbStore + +DEFAULT_METADATA = {"hnsw:space": "l2"} +TEST_NAME = "test" + + +@pytest.fixture +def chroma_store_client(): + store = ChromadbStore(index_name=TEST_NAME, chroma_client=Mock(), embedding_function=Mock(spec=EmbeddingClient)) + store.embedding_function.get_embeddings = AsyncMock(return_value="test_embedding") + return store + + +@pytest.fixture +def chroma_store_function(): + return ChromadbStore( + index_name=TEST_NAME, chroma_client=Mock(), embedding_function=Mock(spec=chromadb.EmbeddingFunction) + ) + + +def test_chroma_get_chroma_collection_embedding_chroma_client(chroma_store_client): + chroma_store_client._get_chroma_collection() + chroma_store_client.chroma_client.get_or_create_collection.assert_called_with( + name=TEST_NAME, metadata=DEFAULT_METADATA + ) + + +def test_chroma_get_chroma_collection_chroma_embedding_function(chroma_store_function): + chroma_store_function._get_chroma_collection() + chroma_store_function.chroma_client.get_or_create_collection.assert_called_with( + name=TEST_NAME, metadata=DEFAULT_METADATA, embedding_function=chroma_store_function.embedding_function + ) + + +RETRIEVED = {"distances": [[0.4]], "documents": [["test"]]} + + +def get_mocked_collection(mock_client_store: Mock): + mock_collection = Mock() + mock_collection.query = Mock(return_value=RETRIEVED) + + mock_client_store._get_chroma_collection = Mock(return_value=mock_collection) + + return mock_collection + + +@pytest.mark.asyncio +async def test_store_embedding_client(chroma_store_client): + mock_collection = get_mocked_collection(chroma_store_client) + + await chroma_store_client.store(["test"]) + chroma_store_client.embedding_function.get_embeddings.assert_called_with(["test"]) + mock_collection.add.assert_called_with( + ids=[sha256(b"test").hexdigest()], embeddings="test_embedding", documents=["test"] + ) + + +@pytest.mark.asyncio +async def test_store_chroma_embedding_function(chroma_store_function): + mock_collection = get_mocked_collection(chroma_store_function) + + await chroma_store_function.store(["test"]) + mock_collection.add.assert_called_with(ids=[sha256(b"test").hexdigest()], documents=["test"]) + + +@pytest.mark.asyncio +async def test_find_similar_embedding_client(chroma_store_client): + mock_collection = get_mocked_collection(chroma_store_client) + + result = await chroma_store_client.find_similar("test") + + chroma_store_client.embedding_function.get_embeddings.assert_called_with(["test"]) + mock_collection.query.assert_called_with(query_embeddings="test_embedding", n_results=1) + + assert result == "test" + + +@pytest.mark.asyncio +async def test_find_similar_chroma_embedding_function(chroma_store_function): + mock_collection = get_mocked_collection(chroma_store_function) + + result = await chroma_store_function.find_similar("test") + + mock_collection.query.assert_called_with(query_texts=["test"], n_results=1) + + assert result == "test" + + +def test_return_best_match_max_distance_is_none(chroma_store_client): + chroma_store_client.max_distance = None + result = chroma_store_client._return_best_match(RETRIEVED) + + assert result == "test" + + +def test_return_best_match_max_distance_is_not_acceptable(chroma_store_client): + chroma_store_client.max_distance = 0.3 + result = chroma_store_client._return_best_match(RETRIEVED) + + assert result is None + + +def test_return_best_match_max_distance_is_acceptable(chroma_store_client): + chroma_store_client.max_distance = 0.5 + result = chroma_store_client._return_best_match(RETRIEVED) + + assert result == "test"