Skip to content

Add AWS client generator utility #360

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 27 commits into from
Mar 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
6d820d8
Add AWS client creation util
michaelnchin Feb 14, 2025
c485bb8
Merge branch 'main' into client-helper
michaelnchin Feb 14, 2025
89bfaab
Merge branch 'main' into client-helper
michaelnchin Feb 19, 2025
3724a27
Remove tests no longer accurate to new implementation
michaelnchin Feb 19, 2025
7251b77
Merge branch 'main' into client-helper
michaelnchin Feb 22, 2025
c989ba1
Merge branch 'main' into client-helper
michaelnchin Feb 25, 2025
4246497
Merge branch 'main' into client-helper
michaelnchin Mar 5, 2025
6f37abe
Add unit tests in new test_utils
michaelnchin Mar 6, 2025
1404c3d
Explicitly type return/params on tests
michaelnchin Mar 6, 2025
e1a225c
more test typing fixes
michaelnchin Mar 6, 2025
bd77a51
Merge branch 'main' into client-helper
michaelnchin Mar 11, 2025
2d7bf04
Merge branch 'main' into client-helper
michaelnchin Mar 12, 2025
8edd2ad
fmt
michaelnchin Mar 12, 2025
0f3edaf
missing import from merge
michaelnchin Mar 12, 2025
f9bc06a
Merge branch 'main' into client-helper
michaelnchin Mar 14, 2025
79edd12
Merge branch 'main' into client-helper
michaelnchin Mar 17, 2025
ce4f95f
Merge branch 'main' into client-helper
michaelnchin Mar 17, 2025
01b5334
Address #402: Create session only if creds passed
michaelnchin Mar 18, 2025
349a8ea
Revise unit tests for session creation change
michaelnchin Mar 18, 2025
4ecb005
Merge branch 'client-helper' of https://github.com/michaelnchin/langc…
michaelnchin Mar 18, 2025
396561a
Merge branch 'main' into client-helper
michaelnchin Mar 18, 2025
56c5af9
Update exception handling
michaelnchin Mar 21, 2025
f3a58fa
Merge branch 'client-helper' of https://github.com/michaelnchin/langc…
michaelnchin Mar 21, 2025
eff13be
Merge branch 'main' into client-helper
michaelnchin Mar 21, 2025
4ec1141
Update for 3coins@ feedback
michaelnchin Mar 25, 2025
7ea8a01
fmt
michaelnchin Mar 25, 2025
bf2d4ff
Merge branch 'main' into client-helper
michaelnchin Mar 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 11 additions & 49 deletions libs/aws/langchain_aws/chat_models/bedrock_converse.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import base64
import json
import logging
import os
import re
import warnings
from operator import itemgetter
Expand All @@ -22,7 +21,6 @@
cast,
)

import boto3
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.exceptions import OutputParserException
from langchain_core.language_models import BaseChatModel, LanguageModelInput
Expand Down Expand Up @@ -56,6 +54,7 @@
from typing_extensions import Self

from langchain_aws.function_calling import ToolsOutputParser
from langchain_aws.utils import create_aws_client

logger = logging.getLogger(__name__)
_BM = TypeVar("_BM", bound=BaseModel)
Expand Down Expand Up @@ -568,53 +567,16 @@ def validate_environment(self) -> Self:

# Skip creating new client if passed in constructor
if self.client is None:
creds = {
"aws_access_key_id": self.aws_access_key_id,
"aws_secret_access_key": self.aws_secret_access_key,
"aws_session_token": self.aws_session_token,
}
if creds["aws_access_key_id"] and creds["aws_secret_access_key"]:
session_params = {
k: v.get_secret_value() for k, v in creds.items() if v
}
elif any(creds.values()):
raise ValueError(
f"If any of aws_access_key_id, aws_secret_access_key, or "
f"aws_session_token are specified then both aws_access_key_id and "
f"aws_secret_access_key must be specified. Only received "
f"{(k for k, v in creds.items() if v)}."
)
elif self.credentials_profile_name is not None:
session_params = {"profile_name": self.credentials_profile_name}
else:
# use default credentials
session_params = {}

try:
session = boto3.Session(**session_params)

self.region_name = (
self.region_name
or os.getenv("AWS_REGION")
or os.getenv("AWS_DEFAULT_REGION")
or session.region_name
)

client_params = {
"endpoint_url": self.endpoint_url,
"config": self.config,
"region_name": self.region_name,
}
client_params = {k: v for k, v in client_params.items() if v}
self.client = session.client("bedrock-runtime", **client_params)
except ValueError as e:
raise ValueError(f"Error raised by bedrock service:\n\n{e}") from e
except Exception as e:
raise ValueError(
"Could not load credentials to authenticate with AWS client. "
"Please check that credentials in the specified "
f"profile name are valid. Bedrock error:\n\n{e}"
) from e
self.client = create_aws_client(
region_name=self.region_name,
credentials_profile_name=self.credentials_profile_name,
aws_access_key_id=self.aws_access_key_id,
aws_secret_access_key=self.aws_secret_access_key,
aws_session_token=self.aws_session_token,
endpoint_url=self.endpoint_url,
config=self.config,
service_name="bedrock-runtime",
)

return self

Expand Down
90 changes: 53 additions & 37 deletions libs/aws/langchain_aws/embeddings/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@
import numpy as np
from langchain_core.embeddings import Embeddings
from langchain_core.runnables.config import run_in_executor
from pydantic import BaseModel, ConfigDict, Field, model_validator
from langchain_core.utils import secret_from_env
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self

from langchain_aws.utils import create_aws_client

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -47,7 +50,7 @@ class BedrockEmbeddings(BaseModel, Embeddings):
client: Any = Field(default=None, exclude=True) #: :meta private:
"""Bedrock client."""
region_name: Optional[str] = None
"""The aws region e.g., `us-west-2`. Fallsback to AWS_DEFAULT_REGION env variable
"""The aws region e.g., `us-west-2`. Falls back to AWS_REGION/AWS_DEFAULT_REGION env variable
or region specified in ~/.aws/config in case it is not provided here.
"""

Expand All @@ -59,6 +62,44 @@ class BedrockEmbeddings(BaseModel, Embeddings):
See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
"""

aws_access_key_id: Optional[SecretStr] = Field(
default_factory=secret_from_env("AWS_ACCESS_KEY_ID", default=None)
)
"""AWS access key id.

If provided, aws_secret_access_key must also be provided.
If not specified, the default credential profile or, if on an EC2 instance,
credentials from IMDS will be used.
See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html

If not provided, will be read from 'AWS_ACCESS_KEY_ID' environment variable.
"""

aws_secret_access_key: Optional[SecretStr] = Field(
default_factory=secret_from_env("AWS_SECRET_ACCESS_KEY", default=None)
)
"""AWS secret_access_key.

If provided, aws_access_key_id must also be provided.
If not specified, the default credential profile or, if on an EC2 instance,
credentials from IMDS will be used.
See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html

If not provided, will be read from 'AWS_SECRET_ACCESS_KEY' environment variable.
"""

aws_session_token: Optional[SecretStr] = Field(
default_factory=secret_from_env("AWS_SESSION_TOKEN", default=None)
)
"""AWS session token.

If provided, aws_access_key_id and aws_secret_access_key must also be provided.
Not required unless using temporary credentials.
See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html

If not provided, will be read from 'AWS_SESSION_TOKEN' environment variable.
"""

model_id: str = "amazon.titan-embed-text-v1"
"""Id of the model to call, e.g., amazon.titan-embed-text-v1, this is
equivalent to the modelId property in the list-foundation-models api"""
Expand Down Expand Up @@ -88,42 +129,17 @@ def provider(self) -> str:
@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validate that AWS credentials to and python package exists in environment."""

if self.client is not None:
return self

try:
import boto3

if self.credentials_profile_name is not None:
session = boto3.Session(profile_name=self.credentials_profile_name)
else:
# use default credentials
session = boto3.Session()

client_params = {}
if self.region_name:
client_params["region_name"] = self.region_name

if self.endpoint_url:
client_params["endpoint_url"] = self.endpoint_url

if self.config:
client_params["config"] = self.config

self.client = session.client("bedrock-runtime", **client_params)

except ImportError:
raise ModuleNotFoundError(
"Could not import boto3 python package. "
"Please install it with `pip install boto3`."
if self.client is None:
self.client = create_aws_client(
region_name=self.region_name,
credentials_profile_name=self.credentials_profile_name,
aws_access_key_id=self.aws_access_key_id,
aws_secret_access_key=self.aws_secret_access_key,
aws_session_token=self.aws_session_token,
endpoint_url=self.endpoint_url,
config=self.config,
service_name="bedrock-runtime",
)
except Exception as e:
raise ValueError(
"Could not load credentials to authenticate with AWS client. "
"Please check that credentials in the specified "
f"profile name are valid. Bedrock error: {e}"
) from e

return self

Expand Down
61 changes: 12 additions & 49 deletions libs/aws/langchain_aws/llms/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
Union,
)

import boto3
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
Expand All @@ -34,6 +33,7 @@
from langchain_aws.function_calling import _tools_in_params
from langchain_aws.utils import (
anthropic_tokens_supported,
create_aws_client,
enforce_stop_tokens,
get_num_tokens_anthropic,
get_token_ids_anthropic,
Expand Down Expand Up @@ -577,7 +577,7 @@ class BedrockBase(BaseLanguageModel, ABC):
client: Any = Field(default=None, exclude=True) #: :meta private:

region_name: Optional[str] = Field(default=None, alias="region")
"""The aws region e.g., `us-west-2`. Fallsback to AWS_REGION or AWS_DEFAULT_REGION
"""The aws region e.g., `us-west-2`. Falls back to AWS_REGION or AWS_DEFAULT_REGION
env variable or region specified in ~/.aws/config in case it is not provided here.
"""

Expand Down Expand Up @@ -742,54 +742,17 @@ def validate_environment(self) -> Self:
self.model_kwargs.pop("max_tokens")

# Skip creating new client if passed in constructor
if self.client is not None:
return self

creds = {
"aws_access_key_id": self.aws_access_key_id,
"aws_secret_access_key": self.aws_secret_access_key,
"aws_session_token": self.aws_session_token,
}
if creds["aws_access_key_id"] and creds["aws_secret_access_key"]:
session_params = {k: v.get_secret_value() for k, v in creds.items() if v}
elif any(creds.values()):
raise ValueError(
f"If any of aws_access_key_id, aws_secret_access_key, or "
f"aws_session_token are specified then both aws_access_key_id and "
f"aws_secret_access_key must be specified. Only received "
f"{(k for k, v in creds.items() if v)}."
if self.client is None:
self.client = create_aws_client(
region_name=self.region_name,
credentials_profile_name=self.credentials_profile_name,
aws_access_key_id=self.aws_access_key_id,
aws_secret_access_key=self.aws_secret_access_key,
aws_session_token=self.aws_session_token,
endpoint_url=self.endpoint_url,
config=self.config,
service_name="bedrock-runtime",
)
elif self.credentials_profile_name is not None:
session_params = {"profile_name": self.credentials_profile_name}
else:
# use default credentials
session_params = {}

try:
session = boto3.Session(**session_params)

self.region_name = (
self.region_name
or os.getenv("AWS_REGION")
or os.getenv("AWS_DEFAULT_REGION")
or session.region_name
)

client_params = {
"endpoint_url": self.endpoint_url,
"config": self.config,
"region_name": self.region_name,
}
client_params = {k: v for k, v in client_params.items() if v}
self.client = session.client("bedrock-runtime", **client_params)
except ValueError as e:
raise ValueError(f"Error raised by bedrock service:\n\n{e}") from e
except Exception as e:
raise ValueError(
"Could not load credentials to authenticate with AWS client. "
"Please check that credentials in the specified "
f"profile name are valid. Bedrock error:\n\n{e}"
) from e

return self

Expand Down
Loading