Add Amazon Bedrock Embedding function #1361

Merged
merged 5 commits into from
Dec 20, 2023
49 changes: 49 additions & 0 deletions chromadb/utils/embedding_functions.py
@@ -710,6 +710,55 @@ def __call__(self, input: Union[Documents, Images]) -> Embeddings:
        return embeddings


class AmazonBedrockEmbeddingFunction(EmbeddingFunction[Documents]):
    def __init__(
        self,
        session: "boto3.Session",  # Quote for forward reference
        model_name: str = "amazon.titan-embed-text-v1",
Contributor

I think there's a cleaner API boundary here and I'd like to hear your thoughts.

This constructor could accept a boto3.Session and optional kwargs to pass into the client() call. That way we give users full access to the configuration options for both their Session and their client, and we have less code to maintain (for example, we don't need to read from os.environ). I may be missing something here -- is there a reason we need to accept these and construct the Session ourselves? If not, could we accept a Session and config kwargs?

Contributor Author

@chezou Dec 18, 2023

I'm okay with just passing a boto3.Session and necessary kwargs for the client() call instead of creating a session within the class.

I don't have a specific reason to create a boto3.Session inside the class. I tried to align with the existing code: the other classes create a client internally, and HuggingFaceEmbeddingFunction creates a requests.Session within the class.

Contributor Author

@chezou Dec 19, 2023

btw, do you think we should support as many of the possible config kwargs as we can?

Looking at the boto3 documentation, I think most of the kwargs are covered by the session, so we may need to pass at least api_version, use_ssl, verify, and endpoint_url. Or we could postpone introducing those arguments for now.

Contributor Author

Tentatively, I removed most of the redundant arguments without adding client-specific arguments. Let me know your thoughts on it.
eedf49d

Contributor

This is already looking way better -- thank you!

One more small ask and then we can check it in: could we have the constructor accept **kwargs and pass them directly to the client() we create? That would give us three wins:
1 - We don't have to manage the retry_config which I imagine some folks will want control over.
2 - Users can specify whatever they want to override in their Session config without us needing to make any code changes.
3 - We don't need to import boto3 in this codebase anymore.

Contributor Author

Fair point. Fixed by using **kwargs in a53e415.
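
For readers following the thread, here is a minimal sketch of the forwarding pattern agreed on above. It is not the merged code: `FakeSession` and `BedrockLikeEmbeddingFunction` are hypothetical names used so the snippet runs without boto3 or AWS credentials. `endpoint_url` and `verify` are two of the client options mentioned earlier in the discussion.

```python
from typing import Any, Dict, List, Tuple


class FakeSession:
    """Hypothetical stand-in for boto3.Session; records what client() receives."""

    def __init__(self) -> None:
        self.calls: List[Tuple[str, Dict[str, Any]]] = []

    def client(self, service_name: str, **kwargs: Any) -> Dict[str, Any]:
        self.calls.append((service_name, kwargs))
        return {"service_name": service_name, **kwargs}


class BedrockLikeEmbeddingFunction:
    """Mirrors only the constructor shape discussed in the review."""

    def __init__(
        self,
        session: Any,
        model_name: str = "amazon.titan-embed-text-v1",
        **kwargs: Any,
    ) -> None:
        self._model_name = model_name
        # Every extra keyword argument goes straight to client(); the class
        # itself never needs to know which options exist.
        self._client = session.client(service_name="bedrock-runtime", **kwargs)


session = FakeSession()
ef = BedrockLikeEmbeddingFunction(
    session, endpoint_url="https://example.invalid", verify=False
)
forwarded_service, forwarded_kwargs = session.calls[0]
```

Because the constructor only calls a method on the object it is handed, nothing in this shape requires importing boto3 in the library itself.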

        **kwargs: Any,
    ):
        """Initialize AmazonBedrockEmbeddingFunction.

        Args:
            session (boto3.Session): The boto3 session to use.
            model_name (str, optional): Identifier of the model, defaults to "amazon.titan-embed-text-v1"
            **kwargs: Additional arguments to pass to the boto3 client.

        Example:
            >>> import boto3
            >>> session = boto3.Session(profile_name="profile", region_name="us-east-1")
            >>> bedrock = AmazonBedrockEmbeddingFunction(session=session)
            >>> texts = ["Hello, world!", "How are you?"]
            >>> embeddings = bedrock(texts)
        """

Contributor

@chezou, it would be nice to check whether the boto3 package is installed and, if not, guide the user with an appropriate error message. Check the other EFs for reference.

Contributor Author

@chezou Dec 21, 2023

I originally did check for it, but I followed @beggers's suggestion to remove the boto3 dependency from chroma itself.
See also #1361 (comment)
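
For context, the kind of guard being asked about could look roughly like the sketch below. `require_package` is a hypothetical helper, not something in chromadb, and the error wording is an assumption. As the reply above notes, the merged constructor does not need such a check, since the caller already imported boto3 to build the session.

```python
import importlib
from types import ModuleType


def require_package(name: str, install_hint: str) -> ModuleType:
    """Import `name`, or raise an error that tells the user how to install it."""
    try:
        return importlib.import_module(name)
    except ImportError as exc:
        raise ValueError(
            f"The {name} python package is not installed. "
            f"Please install it with `{install_hint}`"
        ) from exc


# A package that exists is returned as a module object.
json_module = require_package("json", "(part of the standard library)")

# A missing package produces the guiding message instead of a bare ImportError.
try:
    require_package("surely_not_installed_pkg_xyz", "pip install surely_not_installed_pkg_xyz")
    missing_message = ""
except ValueError as err:
    missing_message = str(err)
```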

        self._model_name = model_name

        self._client = session.client(
            service_name="bedrock-runtime",
            **kwargs,
        )

    def __call__(self, input: Documents) -> Embeddings:
        import json
Contributor

Do we need to import json here instead of at top-level?

Contributor

Good call, #1562


        accept = "application/json"
        content_type = "application/json"
        embeddings = []
        for text in input:
            input_body = {"inputText": text}
            body = json.dumps(input_body)
            response = self._client.invoke_model(
                body=body,
                modelId=self._model_name,
                accept=accept,
                contentType=content_type,
            )
            embedding = json.load(response.get("body")).get("embedding")
            embeddings.append(embedding)
        return embeddings
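
The request loop above can be exercised offline with a stub, as in the following sketch. `FakeBedrockClient` and `embed` are hypothetical names, the returned numbers are made up (text length and word count) rather than real model output, and `io.BytesIO` stands in for the file-like response body that `json.load` reads from.

```python
import io
import json
from typing import Any, Dict, List


class FakeBedrockClient:
    """Hypothetical stand-in for the bedrock-runtime client."""

    def invoke_model(
        self, body: str, modelId: str, accept: str, contentType: str
    ) -> Dict[str, Any]:
        text = json.loads(body)["inputText"]
        # Made-up "embedding": [character count, word count].
        payload = json.dumps(
            {"embedding": [float(len(text)), float(len(text.split()))]}
        )
        return {"body": io.BytesIO(payload.encode("utf-8"))}


def embed(client: Any, model_name: str, texts: List[str]) -> List[List[float]]:
    """One invoke_model call per text, mirroring the loop in the diff."""
    embeddings = []
    for text in texts:
        response = client.invoke_model(
            body=json.dumps({"inputText": text}),
            modelId=model_name,
            accept="application/json",
            contentType="application/json",
        )
        # The body is file-like, so json.load can read and parse it directly.
        embeddings.append(json.load(response["body"])["embedding"])
    return embeddings


vectors = embed(
    FakeBedrockClient(),
    "amazon.titan-embed-text-v1",
    ["Hello, world!", "How are you?"],
)
```

Note that the loop issues one request per document, so the number of calls grows linearly with the input size.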


class HuggingFaceEmbeddingServer(EmbeddingFunction[Documents]):
    """
    This class is used to get embeddings for a list of texts using the HuggingFace Embedding server (https://github.com/huggingface/text-embeddings-inference).