Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AWS client generator utility #360

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 11 additions & 49 deletions libs/aws/langchain_aws/chat_models/bedrock_converse.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import base64
import json
import logging
import os
import re
from operator import itemgetter
from typing import (
Expand All @@ -21,7 +20,6 @@
cast,
)

import boto3
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel, LanguageModelInput
from langchain_core.language_models.chat_models import LangSmithParams
Expand Down Expand Up @@ -54,6 +52,7 @@
from typing_extensions import Self

from langchain_aws.function_calling import ToolsOutputParser
from langchain_aws.utils import get_aws_client

logger = logging.getLogger(__name__)
_BM = TypeVar("_BM", bound=BaseModel)
Expand Down Expand Up @@ -454,53 +453,16 @@ def validate_environment(self) -> Self:

# Skip creating new client if passed in constructor
if self.client is None:
creds = {
"aws_access_key_id": self.aws_access_key_id,
"aws_secret_access_key": self.aws_secret_access_key,
"aws_session_token": self.aws_session_token,
}
if creds["aws_access_key_id"] and creds["aws_secret_access_key"]:
session_params = {
k: v.get_secret_value() for k, v in creds.items() if v
}
elif any(creds.values()):
raise ValueError(
f"If any of aws_access_key_id, aws_secret_access_key, or "
f"aws_session_token are specified then both aws_access_key_id and "
f"aws_secret_access_key must be specified. Only received "
f"{(k for k, v in creds.items() if v)}."
)
elif self.credentials_profile_name is not None:
session_params = {"profile_name": self.credentials_profile_name}
else:
# use default credentials
session_params = {}

try:
session = boto3.Session(**session_params)

self.region_name = (
self.region_name
or os.getenv("AWS_REGION")
or os.getenv("AWS_DEFAULT_REGION")
or session.region_name
)

client_params = {
"endpoint_url": self.endpoint_url,
"config": self.config,
"region_name": self.region_name,
}
client_params = {k: v for k, v in client_params.items() if v}
self.client = session.client("bedrock-runtime", **client_params)
except ValueError as e:
raise ValueError(f"Error raised by bedrock service:\n\n{e}") from e
except Exception as e:
raise ValueError(
"Could not load credentials to authenticate with AWS client. "
"Please check that credentials in the specified "
f"profile name are valid. Bedrock error:\n\n{e}"
) from e
self.client = get_aws_client(
region_name=self.region_name,
credentials_profile_name=self.credentials_profile_name,
aws_access_key_id=self.aws_access_key_id,
aws_secret_access_key=self.aws_secret_access_key,
aws_session_token=self.aws_session_token,
endpoint_url=self.endpoint_url,
config=self.config,
service_name="bedrock-runtime",
)

return self

Expand Down
90 changes: 53 additions & 37 deletions libs/aws/langchain_aws/embeddings/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@
import numpy as np
from langchain_core.embeddings import Embeddings
from langchain_core.runnables.config import run_in_executor
from pydantic import BaseModel, ConfigDict, Field, model_validator
from langchain_core.utils import secret_from_env
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self

from langchain_aws.utils import get_aws_client


class BedrockEmbeddings(BaseModel, Embeddings):
"""Bedrock embedding models.
Expand Down Expand Up @@ -45,7 +48,7 @@ class BedrockEmbeddings(BaseModel, Embeddings):
client: Any = Field(default=None, exclude=True) #: :meta private:
"""Bedrock client."""
region_name: Optional[str] = None
"""The aws region e.g., `us-west-2`. Fallsback to AWS_DEFAULT_REGION env variable
"""The aws region e.g., `us-west-2`. Falls back to AWS_REGION/AWS_DEFAULT_REGION env variable
or region specified in ~/.aws/config in case it is not provided here.
"""

Expand All @@ -57,6 +60,44 @@ class BedrockEmbeddings(BaseModel, Embeddings):
See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
"""

aws_access_key_id: Optional[SecretStr] = Field(
default_factory=secret_from_env("AWS_ACCESS_KEY_ID", default=None)
)
"""AWS access key id.

If provided, aws_secret_access_key must also be provided.
If not specified, the default credential profile or, if on an EC2 instance,
credentials from IMDS will be used.
See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html

If not provided, will be read from 'AWS_ACCESS_KEY_ID' environment variable.
"""

aws_secret_access_key: Optional[SecretStr] = Field(
default_factory=secret_from_env("AWS_SECRET_ACCESS_KEY", default=None)
)
"""AWS secret_access_key.

If provided, aws_access_key_id must also be provided.
If not specified, the default credential profile or, if on an EC2 instance,
credentials from IMDS will be used.
See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html

If not provided, will be read from 'AWS_SECRET_ACCESS_KEY' environment variable.
"""

aws_session_token: Optional[SecretStr] = Field(
default_factory=secret_from_env("AWS_SESSION_TOKEN", default=None)
)
"""AWS session token.

If provided, aws_access_key_id and aws_secret_access_key must also be provided.
Not required unless using temporary credentials.
See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html

If not provided, will be read from 'AWS_SESSION_TOKEN' environment variable.
"""

model_id: str = "amazon.titan-embed-text-v1"
"""Id of the model to call, e.g., amazon.titan-embed-text-v1, this is
equivalent to the modelId property in the list-foundation-models api"""
Expand Down Expand Up @@ -86,42 +127,17 @@ def provider(self) -> str:
@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validate that AWS credentials to and python package exists in environment."""

if self.client is not None:
return self

try:
import boto3

if self.credentials_profile_name is not None:
session = boto3.Session(profile_name=self.credentials_profile_name)
else:
# use default credentials
session = boto3.Session()

client_params = {}
if self.region_name:
client_params["region_name"] = self.region_name

if self.endpoint_url:
client_params["endpoint_url"] = self.endpoint_url

if self.config:
client_params["config"] = self.config

self.client = session.client("bedrock-runtime", **client_params)

except ImportError:
raise ModuleNotFoundError(
"Could not import boto3 python package. "
"Please install it with `pip install boto3`."
if self.client is None:
self.client = get_aws_client(
region_name=self.region_name,
credentials_profile_name=self.credentials_profile_name,
aws_access_key_id=self.aws_access_key_id,
aws_secret_access_key=self.aws_secret_access_key,
aws_session_token=self.aws_session_token,
endpoint_url=self.endpoint_url,
config=self.config,
service_name="bedrock-runtime",
)
except Exception as e:
raise ValueError(
"Could not load credentials to authenticate with AWS client. "
"Please check that credentials in the specified "
f"profile name are valid. Bedrock error: {e}"
) from e

return self

Expand Down
62 changes: 12 additions & 50 deletions libs/aws/langchain_aws/llms/bedrock.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
import json
import logging
import os
import warnings
from abc import ABC
from typing import (
Expand All @@ -18,7 +17,6 @@
Union,
)

import boto3
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
Expand All @@ -35,6 +33,7 @@
from langchain_aws.utils import (
anthropic_tokens_supported,
enforce_stop_tokens,
get_aws_client,
get_num_tokens_anthropic,
get_token_ids_anthropic,
)
Expand Down Expand Up @@ -484,7 +483,7 @@ class BedrockBase(BaseLanguageModel, ABC):
client: Any = Field(default=None, exclude=True) #: :meta private:

region_name: Optional[str] = Field(default=None, alias="region")
"""The aws region e.g., `us-west-2`. Fallsback to AWS_REGION or AWS_DEFAULT_REGION
"""The aws region e.g., `us-west-2`. Falls back to AWS_REGION or AWS_DEFAULT_REGION
env variable or region specified in ~/.aws/config in case it is not provided here.
"""

Expand Down Expand Up @@ -649,54 +648,17 @@ def validate_environment(self) -> Self:
self.model_kwargs.pop("max_tokens")

# Skip creating new client if passed in constructor
if self.client is not None:
return self

creds = {
"aws_access_key_id": self.aws_access_key_id,
"aws_secret_access_key": self.aws_secret_access_key,
"aws_session_token": self.aws_session_token,
}
if creds["aws_access_key_id"] and creds["aws_secret_access_key"]:
session_params = {k: v.get_secret_value() for k, v in creds.items() if v}
elif any(creds.values()):
raise ValueError(
f"If any of aws_access_key_id, aws_secret_access_key, or "
f"aws_session_token are specified then both aws_access_key_id and "
f"aws_secret_access_key must be specified. Only received "
f"{(k for k, v in creds.items() if v)}."
if self.client is None:
self.client = get_aws_client(
region_name=self.region_name,
credentials_profile_name=self.credentials_profile_name,
aws_access_key_id=self.aws_access_key_id,
aws_secret_access_key=self.aws_secret_access_key,
aws_session_token=self.aws_session_token,
endpoint_url=self.endpoint_url,
config=self.config,
service_name="bedrock-runtime",
)
elif self.credentials_profile_name is not None:
session_params = {"profile_name": self.credentials_profile_name}
else:
# use default credentials
session_params = {}

try:
session = boto3.Session(**session_params)

self.region_name = (
self.region_name
or os.getenv("AWS_REGION")
or os.getenv("AWS_DEFAULT_REGION")
or session.region_name
)

client_params = {
"endpoint_url": self.endpoint_url,
"config": self.config,
"region_name": self.region_name,
}
client_params = {k: v for k, v in client_params.items() if v}
self.client = session.client("bedrock-runtime", **client_params)
except ValueError as e:
raise ValueError(f"Error raised by bedrock service:\n\n{e}") from e
except Exception as e:
raise ValueError(
"Could not load credentials to authenticate with AWS client. "
"Please check that credentials in the specified "
f"profile name are valid. Bedrock error:\n\n{e}"
) from e

return self

Expand Down
Loading
Loading