-
Notifications
You must be signed in to change notification settings - Fork 5
Feat: Integration with Chromadb #18
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
eac324c
feat: chroma db support
0927bdd
fix: hash issues
d8365e6
test: integration chroma testing
00e9644
fix: how-to old class naming
97db953
docs: add reference and how-to
450c538
fix: missing dependencies
fecedee
fix: test name repetition
fdbc730
fix: chroma dependency unknown licence
3776682
fix: chroma dependency Apache License v2.0
3673e0c
fix: pr issues
f7cb812
Merge branch 'main' into sc/chromadb
04d6de1
fix: pr issues 2
7dfa219
Update docs/how-to/use_chromadb_store.md
ds-sebastianchwilczynski 13c5a33
rename embedding calculator
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
pkg_resources | ||
psycopg2-binary | ||
tiktoken | ||
tiktoken | ||
chroma-hnswlib |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
# How-To: Use Chromadb to Store Similarity Index | ||
|
||
[ChromadbStore][dbally.similarity.ChromadbStore] allows to use [Chroma vector database](https://docs.trychroma.com/api-reference#methods-on-collection) as a store inside the [SimilarityIndex][dbally.similarity.SimilarityIndex]. With this feature, when someone searches for 'Show my flights to the USA' and we have 'United States' stored in our database as the country's value, the system will recognize the similarity and convert the query from 'get_flights(to="USA")' to 'get_flights(to="United States")'. | ||
|
||
|
||
## Prerequisites | ||
|
||
To use Chromadb with db-ally you need to install the chromadb extension | ||
|
||
```python | ||
pip install dbally[chromadb] | ||
``` | ||
|
||
Let's say we have already implemented our [SimilarityFetcher](../how-to/use_custom_similarity_fetcher.md) | ||
|
||
```python | ||
class DummyCountryFetcher(SimilarityFetcher): | ||
async def fetch(self): | ||
return ["United States", "Canada", "Mexico"] | ||
``` | ||
|
||
Next, we need to define `Chromadb.Client`. You can [run Chromadb on your local machine](https://docs.trychroma.com/usage-guide#initiating-a-persistent-chroma-client) | ||
|
||
```python | ||
import chromadb | ||
|
||
chroma_client = chromadb.PersistentClient(path="/path/to/save/to") | ||
``` | ||
|
||
or [set up Chromadb in the client/server mode](https://docs.trychroma.com/usage-guide#running-chroma-in-clientserver-mode) | ||
|
||
```python | ||
chroma_client = chromadb.HttpClient(host='localhost', port=8000) | ||
``` | ||
|
||
Next, you can either use [one of dbally embedding clients][dbally.embedding_client.EmbeddingClient], such as [OpenAiEmbeddingClient][dbally.embedding_client.OpenAiEmbeddingClient] | ||
|
||
```python | ||
from dbally.embedding_client import OpenAiEmbeddingClient | ||
|
||
embedding_client=OpenAiEmbeddingClient( | ||
api_key="your-api-key", | ||
) | ||
|
||
``` | ||
|
||
or [Chromadb embedding functions](https://docs.trychroma.com/embeddings) | ||
|
||
``` | ||
from chromadb.utils import embedding_functions | ||
embedding_client = embedding_functions.DefaultEmbeddingFunction() | ||
``` | ||
|
||
to define your [`ChromadbStore`][dbally.similarity.ChromadbStore]. | ||
|
||
```python | ||
store = ChromadbStore(index_name="myChromaIndex", chroma_client=chroma_client, embedding_function=embedding_client) | ||
``` | ||
|
||
After this setup, you can initialize the SimilarityIndex | ||
|
||
```python | ||
from typing import Annotated | ||
|
||
country_similarity = SimilarityIndex(store, DummyCountryFetcher()) | ||
|
||
|
||
``` | ||
|
||
and [update it and find the closest matches in the same way as in built-in similarity indices](./use_custom_similarity_store.md/#using-the-similarity-index) . |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
#ChromadbStore | ||
|
||
!!! info | ||
To see example of using ChromadbStore visit: [How-To: Use Chromadb to Store Similarity Index](../../../how-to/use_chromadb_store.md) | ||
|
||
|
||
::: dbally.similarity.ChromadbStore |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
from hashlib import sha256 | ||
from typing import List, Literal, Optional, Union | ||
|
||
import chromadb | ||
|
||
from dbally.embedding_client.base import EmbeddingClient | ||
from dbally.similarity.store import SimilarityStore | ||
|
||
|
||
class ChromadbStore(SimilarityStore): | ||
"""Class that stores text embeddings using [Chroma](https://docs.trychroma.com/)""" | ||
|
||
def __init__( | ||
self, | ||
index_name: str, | ||
chroma_client: chromadb.Client, | ||
embedding_function: Union[EmbeddingClient, chromadb.EmbeddingFunction], | ||
max_distance: Optional[float] = None, | ||
distance_method: Literal["l2", "ip", "cosine"] = "l2", | ||
): | ||
super().__init__() | ||
self.index_name = index_name | ||
self.chroma_client = chroma_client | ||
self.embedding_function = embedding_function | ||
self.max_distance = max_distance | ||
|
||
self._metadata = {"hnsw:space": distance_method} | ||
|
||
def _get_chroma_collection(self) -> chromadb.Collection: | ||
"""Based on the selected embedding_function chooses how to retrieve the Chromadb collection. | ||
If collection doesn't exist it creates one. | ||
|
||
Returns: | ||
chromadb.Collection: Retrieved collection | ||
""" | ||
if isinstance(self.embedding_function, EmbeddingClient): | ||
return self.chroma_client.get_or_create_collection(name=self.index_name, metadata=self._metadata) | ||
|
||
return self.chroma_client.get_or_create_collection( | ||
name=self.index_name, metadata=self._metadata, embedding_function=self.embedding_function | ||
) | ||
|
||
def _return_best_match(self, retrieved: dict) -> Optional[str]: | ||
"""Based on the retrieved data returns the best match or None if no match is found. | ||
|
||
Args: | ||
retrieved: Retrieved data, with a column first format | ||
|
||
Returns: | ||
The best match or None if no match is found | ||
""" | ||
if self.max_distance is None or retrieved["distances"][0][0] <= self.max_distance: | ||
return retrieved["documents"][0][0] | ||
|
||
return None | ||
|
||
async def store(self, data: List[str]) -> None: | ||
""" | ||
Fills chroma collection with embeddings of provided string. As the id uses hash value of the string. | ||
|
||
Args: | ||
data: The data to store. | ||
""" | ||
|
||
ids = [sha256(x.encode("utf-8")).hexdigest() for x in data] | ||
|
||
collection = self._get_chroma_collection() | ||
|
||
if isinstance(self.embedding_function, EmbeddingClient): | ||
embeddings = await self.embedding_function.get_embeddings(data) | ||
|
||
collection.add(ids=ids, embeddings=embeddings, documents=data) | ||
else: | ||
collection.add(ids=ids, documents=data) | ||
|
||
async def find_similar(self, text: str) -> Optional[str]: | ||
""" | ||
Finds the most similar text in the chroma collection or returns None if the most similar text | ||
has distance bigger than `self.max_distance`. | ||
|
||
Args: | ||
text: The text to find similar to. | ||
|
||
Returns: | ||
The most similar text or None if no similar text is found. | ||
""" | ||
|
||
collection = self._get_chroma_collection() | ||
|
||
if isinstance(self.embedding_function, EmbeddingClient): | ||
embedding = await self.embedding_function.get_embeddings([text]) | ||
retrieved = collection.query(query_embeddings=embedding, n_results=1) | ||
else: | ||
retrieved = collection.query(query_texts=[text], n_results=1) | ||
|
||
return self._return_best_match(retrieved) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import chromadb | ||
import pytest | ||
from chromadb import Documents, EmbeddingFunction, Embeddings | ||
|
||
from dbally.embedding_client.base import EmbeddingClient | ||
from dbally.similarity import ChromadbStore | ||
from dbally.similarity.fetcher import SimilarityFetcher | ||
from dbally.similarity.index import SimilarityIndex | ||
|
||
|
||
class DummyCountryFetcher(SimilarityFetcher): | ||
async def fetch(self): | ||
return ["United States", "Canada", "Mexico"] | ||
|
||
|
||
MAPPING = {"United States": [1, 1, 1], "Canada": [-1, -1, -1], "Mexico": [1, 1, -2], "USA": [0.2, 1, 1]} | ||
|
||
|
||
class DummyEmbeddingClient(EmbeddingClient): | ||
async def get_embeddings(self, data): | ||
return [MAPPING[d] for d in data] | ||
|
||
|
||
class DummyEmbeddingFunction(EmbeddingFunction): | ||
def __call__(self, input: Documents) -> Embeddings: | ||
return [MAPPING[d] for d in input] | ||
|
||
|
||
@pytest.mark.asyncio | ||
@pytest.mark.parametrize("embedding_function", [DummyEmbeddingClient(), DummyEmbeddingFunction()]) | ||
async def test_integration_embedding_client(embedding_function): | ||
chroma_client = chromadb.Client() | ||
|
||
store = ChromadbStore(index_name="test", chroma_client=chroma_client, embedding_function=embedding_function) | ||
fetcher = DummyCountryFetcher() | ||
|
||
index = SimilarityIndex(store, fetcher) | ||
|
||
await index.update() | ||
|
||
assert store._get_chroma_collection().count() == 3 | ||
|
||
similar = await index.similar("USA") | ||
|
||
assert similar == "United States" |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.