From 49a33b3c201fb6933b18b161d966668c9fb2fe09 Mon Sep 17 00:00:00 2001 From: Aki Ariga Date: Wed, 8 Nov 2023 16:02:14 -0800 Subject: [PATCH 1/4] Add Amazon Bedrock Embedding function https://docs.aws.amazon.com/bedrock/latest/userguide/embeddings.html --- chromadb/utils/embedding_functions.py | 99 +++++++++++++++++++++++++++ 1 file changed, 99 insertions(+) diff --git a/chromadb/utils/embedding_functions.py b/chromadb/utils/embedding_functions.py index 141964465a5..49cf9dcfdea 100644 --- a/chromadb/utils/embedding_functions.py +++ b/chromadb/utils/embedding_functions.py @@ -581,6 +581,105 @@ def __call__(self, input: Union[Documents, Images]) -> Embeddings: return embeddings +class AmazonBedrockEmbeddingFunction(EmbeddingFunction[Documents]): + def __init__( + self, + profile_name: Optional[str] = None, + region: Optional[str] = None, + model_name: str = "amazon.titan-embed-text-v1", + ): + """Initialize AmazonBedrockEmbeddingFucntion. + + Args: + profile_name (str, optional): The name of a profile to use. If not given, then the default profile is used, defaults to None + region (str, optional): Default region when creating new connections, defaults to None + model_name (str, optional): Identifier of the model, defaults to "amazon.titan-embed-text-v1" + """ + + self._model_name = model_name + + try: + import boto3 + from botocore.config import Config + except ImportError: + raise ValueError( + "The boto3 python package is not installed. 
Please install it with `pip install boto3`" + ) + + if not region: + target_region = os.environ.get( + "AWS_REGION", os.environ.get("AWS_DEFAULT_REGION") + ) + else: + target_region = region + + session_kwargs = {"region_name": target_region} + client_kwargs = {**session_kwargs} + + if profile_name: + target_profile = profile_name + else: + target_profile = os.environ.get("AWS_PROFILE") + + if target_profile: + session_kwargs["profile_name"] = target_profile + + retry_config = Config( + region_name=target_region, + retries={ + "max_attempts": 10, + "mode": "standard", + }, + ) + + session = boto3.Session(**session_kwargs) + self._client = session.client( + service_name="bedrock-runtime", config=retry_config, **client_kwargs + ) + + def __call__(self, input: Documents) -> Embeddings: + """Get the embeddings for a list of texts. + + Args: + input (Documents): A list of texts to get embeddings for. + + Returns: + Embeddings: The embeddings for the texts. + + Example: + >>> bedrock = AmazonBedrockEmbeddingFunction(profile_name="profile") + >>> texts = ["Hello, world!", "How are you?"] + >>> embeddings = bedrock(texts) + """ + embeddings = [] + for text in input: + embeddings.append(self._invoke(text)) + return embeddings + + def _invoke( + self, + text: str, + ) -> Embedding: + import json + + input_body = {"inputText": text} + body = json.dumps(input_body) + accept = "application/json" + content_type = "application/json" + try: + response = self._client.invoke_model( + body=body, + modelId=self._model_name, + accept=accept, + contentType=content_type, + ) + embedding = json.load(response.get("body")).get("embedding") + except Exception as e: + raise ValueError(f"Error raised by bedrock service: {e}") + + return embedding + + # List of all classes in this module _classes = [ name From 0b4bfd0ebc3a95e4df1a08761f283c1d109b6a2d Mon Sep 17 00:00:00 2001 From: Aki Ariga Date: Mon, 18 Dec 2023 16:04:27 -0800 Subject: [PATCH 2/4] Refactor based on review comments - Remove 
unnecessary function - Remove the re-throw of the exception - Move the comment to directly under the class --- chromadb/utils/embedding_functions.py | 40 ++++++++------------------- 1 file changed, 11 insertions(+), 29 deletions(-) diff --git a/chromadb/utils/embedding_functions.py b/chromadb/utils/embedding_functions.py index 49cf9dcfdea..2dfb5892b7c 100644 --- a/chromadb/utils/embedding_functions.py +++ b/chromadb/utils/embedding_functions.py @@ -594,6 +594,11 @@ def __init__( profile_name (str, optional): The name of a profile to use. If not given, then the default profile is used, defaults to None region (str, optional): Default region when creating new connections, defaults to None model_name (str, optional): Identifier of the model, defaults to "amazon.titan-embed-text-v1" + + Example: + >>> bedrock = AmazonBedrockEmbeddingFunction(profile_name="profile") + >>> texts = ["Hello, world!", "How are you?"] + >>> embeddings = bedrock(texts) """ self._model_name = model_name @@ -638,35 +643,14 @@ def __init__( ) def __call__(self, input: Documents) -> Embeddings: - """Get the embeddings for a list of texts. - - Args: - input (Documents): A list of texts to get embeddings for. - - Returns: - Embeddings: The embeddings for the texts. 
- - Example: - >>> bedrock = AmazonBedrockEmbeddingFunction(profile_name="profile") - >>> texts = ["Hello, world!", "How are you?"] - >>> embeddings = bedrock(texts) - """ - embeddings = [] - for text in input: - embeddings.append(self._invoke(text)) - return embeddings - - def _invoke( - self, - text: str, - ) -> Embedding: import json - input_body = {"inputText": text} - body = json.dumps(input_body) accept = "application/json" content_type = "application/json" - try: + embeddings = [] + for text in input: + input_body = {"inputText": text} + body = json.dumps(input_body) response = self._client.invoke_model( body=body, modelId=self._model_name, @@ -674,10 +658,8 @@ def _invoke( contentType=content_type, ) embedding = json.load(response.get("body")).get("embedding") - except Exception as e: - raise ValueError(f"Error raised by bedrock service: {e}") - - return embedding + embeddings.append(embedding) + return embeddings # List of all classes in this module From eedf49d1717b6177563a15a282986236d2681908 Mon Sep 17 00:00:00 2001 From: Aki Ariga Date: Mon, 18 Dec 2023 16:50:39 -0800 Subject: [PATCH 3/4] Pass boto3 session instead of building it in the function --- chromadb/utils/embedding_functions.py | 35 ++++++--------------------- 1 file changed, 7 insertions(+), 28 deletions(-) diff --git a/chromadb/utils/embedding_functions.py b/chromadb/utils/embedding_functions.py index 2dfb5892b7c..00db9d4a355 100644 --- a/chromadb/utils/embedding_functions.py +++ b/chromadb/utils/embedding_functions.py @@ -584,19 +584,19 @@ def __call__(self, input: Union[Documents, Images]) -> Embeddings: class AmazonBedrockEmbeddingFunction(EmbeddingFunction[Documents]): def __init__( self, - profile_name: Optional[str] = None, - region: Optional[str] = None, + session: 'boto3.Session', # Quote for forward reference model_name: str = "amazon.titan-embed-text-v1", ): - """Initialize AmazonBedrockEmbeddingFucntion. + """Initialize AmazonBedrockEmbeddingFunction. 
Args: - profile_name (str, optional): The name of a profile to use. If not given, then the default profile is used, defaults to None - region (str, optional): Default region when creating new connections, defaults to None + session (boto3.Session): The boto3 session to use. model_name (str, optional): Identifier of the model, defaults to "amazon.titan-embed-text-v1" Example: - >>> bedrock = AmazonBedrockEmbeddingFunction(profile_name="profile") + >>> import boto3 + >>> session = boto3.Session(profile_name="profile", region_name="us-east-1") + >>> bedrock = AmazonBedrockEmbeddingFunction(session=session) >>> texts = ["Hello, world!", "How are you?"] >>> embeddings = bedrock(texts) """ @@ -604,42 +604,21 @@ def __init__( self._model_name = model_name try: - import boto3 from botocore.config import Config except ImportError: raise ValueError( "The boto3 python package is not installed. Please install it with `pip install boto3`" ) - if not region: - target_region = os.environ.get( - "AWS_REGION", os.environ.get("AWS_DEFAULT_REGION") - ) - else: - target_region = region - - session_kwargs = {"region_name": target_region} - client_kwargs = {**session_kwargs} - - if profile_name: - target_profile = profile_name - else: - target_profile = os.environ.get("AWS_PROFILE") - - if target_profile: - session_kwargs["profile_name"] = target_profile - retry_config = Config( - region_name=target_region, retries={ "max_attempts": 10, "mode": "standard", }, ) - session = boto3.Session(**session_kwargs) self._client = session.client( - service_name="bedrock-runtime", config=retry_config, **client_kwargs + service_name="bedrock-runtime", config=retry_config, ) def __call__(self, input: Documents) -> Embeddings: From a53e415f8c311cf30ba2b541dc28444f6ed185a3 Mon Sep 17 00:00:00 2001 From: Aki Ariga Date: Tue, 19 Dec 2023 11:54:16 -0800 Subject: [PATCH 4/4] Delegate boto3 client configuration to the user by accepting kwargs --- chromadb/utils/embedding_functions.py | 21 
+++++---------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/chromadb/utils/embedding_functions.py b/chromadb/utils/embedding_functions.py index 00db9d4a355..a78bf38e233 100644 --- a/chromadb/utils/embedding_functions.py +++ b/chromadb/utils/embedding_functions.py @@ -584,14 +584,16 @@ def __call__(self, input: Union[Documents, Images]) -> Embeddings: class AmazonBedrockEmbeddingFunction(EmbeddingFunction[Documents]): def __init__( self, - session: 'boto3.Session', # Quote for forward reference + session: "boto3.Session", # Quote for forward reference model_name: str = "amazon.titan-embed-text-v1", + **kwargs: Any, ): """Initialize AmazonBedrockEmbeddingFunction. Args: session (boto3.Session): The boto3 session to use. model_name (str, optional): Identifier of the model, defaults to "amazon.titan-embed-text-v1" + **kwargs: Additional arguments to pass to the boto3 client. Example: >>> import boto3 @@ -603,22 +605,9 @@ def __init__( self._model_name = model_name - try: - from botocore.config import Config - except ImportError: - raise ValueError( - "The boto3 python package is not installed. Please install it with `pip install boto3`" - ) - - retry_config = Config( - retries={ - "max_attempts": 10, - "mode": "standard", - }, - ) - self._client = session.client( - service_name="bedrock-runtime", config=retry_config, + service_name="bedrock-runtime", + **kwargs, ) def __call__(self, input: Documents) -> Embeddings: