Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add cross-region inference support for Bedrock models (#535) #536

Open
wants to merge 3 commits into
base: v1
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 51 additions & 21 deletions backend/app/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@
from app.repositories.models.conversation import MessageModel
from app.repositories.models.custom_bot import GenerationParamsModel
from app.routes.schemas.conversation import type_model_name
from app.utils import convert_dict_keys_to_camel_case, get_bedrock_client
from app.utils import (
convert_dict_keys_to_camel_case,
get_bedrock_client,
is_region_supported_for_inference,
ENABLE_BEDROCK_CROSS_REGION_INFERENCE,
)
from typing_extensions import NotRequired, TypedDict, no_type_check

logger = logging.getLogger(__name__)
Expand All @@ -23,8 +28,11 @@
if ENABLE_MISTRAL
else DEFAULT_CLAUDE_GENERATION_CONFIG
)
ENABLE_BEDROCK_CROSS_REGION_INFERENCE = (
os.environ.get("ENABLE_BEDROCK_CROSS_REGION_INFERENCE", "false").lower() == "true"
)

client = get_bedrock_client()
client = get_bedrock_client(BEDROCK_REGION)


class ConverseApiToolSpec(TypedDict):
Expand Down Expand Up @@ -219,7 +227,7 @@ def compose_args_for_converse_api(


def call_converse_api(args: ConverseApiRequest) -> ConverseApiResponse:
client = get_bedrock_client()
client = get_bedrock_client(BEDROCK_REGION)
messages = args["messages"]
inference_config = args["inference_config"]
additional_model_request_fields = args["additional_model_request_fields"]
Expand Down Expand Up @@ -257,26 +265,48 @@ def calculate_price(
return input_price * input_tokens / 1000.0 + output_price * output_tokens / 1000.0


CROSS_REGION_INFERENCE_MODELS = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move it inside get_model_id since it's used only for the function

"claude-v3-sonnet": "anthropic.claude-3-sonnet-20240229-v1:0",
"claude-v3-haiku": "anthropic.claude-3-haiku-20240307-v1:0",
"claude-v3-opus": "anthropic.claude-3-opus-20240229-v1:0",
"claude-v3.5-sonnet": "anthropic.claude-3-5-sonnet-20240620-v1:0",
}


def get_model_id(model: type_model_name) -> str:
# Ref: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
if model == "claude-v2":
return "anthropic.claude-v2:1"
elif model == "claude-instant-v1":
return "anthropic.claude-instant-v1"
elif model == "claude-v3-sonnet":
return "anthropic.claude-3-sonnet-20240229-v1:0"
elif model == "claude-v3-haiku":
return "anthropic.claude-3-haiku-20240307-v1:0"
elif model == "claude-v3-opus":
return "anthropic.claude-3-opus-20240229-v1:0"
elif model == "claude-v3.5-sonnet":
return "anthropic.claude-3-5-sonnet-20240620-v1:0"
elif model == "mistral-7b-instruct":
return "mistral.mistral-7b-instruct-v0:2"
elif model == "mixtral-8x7b-instruct":
return "mistral.mixtral-8x7b-instruct-v0:1"
elif model == "mistral-large":
return "mistral.mistral-large-2402-v1:0"
base_model_id = {
"claude-v2": "anthropic.claude-v2:1",
"claude-instant-v1": "anthropic.claude-instant-v1",
"claude-v3-sonnet": "anthropic.claude-3-sonnet-20240229-v1:0",
"claude-v3-haiku": "anthropic.claude-3-haiku-20240307-v1:0",
"claude-v3-opus": "anthropic.claude-3-opus-20240229-v1:0",
"claude-v3.5-sonnet": "anthropic.claude-3-5-sonnet-20240620-v1:0",
"mistral-7b-instruct": "mistral.mistral-7b-instruct-v0:2",
"mixtral-8x7b-instruct": "mistral.mixtral-8x7b-instruct-v0:1",
"mistral-large": "mistral.mistral-large-2402-v1:0",
}[model]

if (
ENABLE_BEDROCK_CROSS_REGION_INFERENCE
and is_region_supported_for_inference(BEDROCK_REGION)
and model in CROSS_REGION_INFERENCE_MODELS
):
logger.info(
f"Using cross-region inference for model {model} in region {BEDROCK_REGION}"
)
return base_model_id
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to add prefix e.g. eu, us?

else:
if ENABLE_BEDROCK_CROSS_REGION_INFERENCE:
if not is_region_supported_for_inference(BEDROCK_REGION):
logger.warning(
f"Cross-region inference is enabled, but the region {BEDROCK_REGION} is not supported. Using local model."
)
elif model not in CROSS_REGION_INFERENCE_MODELS:
logger.warning(
f"Cross-region inference is not available for model {model}. Using local model."
)
return f"{base_model_id}-local"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the postfix i.e. -local work?



def calculate_query_embedding(question: str) -> list[float]:
Expand Down
27 changes: 25 additions & 2 deletions backend/app/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
"PUBLISH_API_CODEBUILD_PROJECT_NAME", ""
)
DB_SECRETS_ARN = os.environ.get("DB_SECRETS_ARN", "")
ENABLE_BEDROCK_CROSS_REGION_INFERENCE = (
os.environ.get("ENABLE_BEDROCK_CROSS_REGION_INFERENCE", "false").lower() == "true"
)


def snake_to_camel(snake_str):
Expand All @@ -40,9 +43,29 @@ def is_running_on_lambda():
return "AWS_EXECUTION_ENV" in os.environ


def is_region_supported_for_inference(region: str) -> bool:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To emphasize for cross region inference, rename it like:
before: is_region_supported_for_inference.
after: is_region_supported_for_cross_inference

supported_regions = [
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this variable is constant, so please replace with SUPPORTED_REGIONS

"us-east-1",
"us-west-2",
"eu-west-1",
"eu-west-3",
"eu-central-1",
] # Add more as they become available
return region in supported_regions


def get_bedrock_client(region=BEDROCK_REGION):
client = boto3.client("bedrock-runtime", region)
return client
if ENABLE_BEDROCK_CROSS_REGION_INFERENCE and is_region_supported_for_inference(
region
):
logger.info(f"Using cross-region Bedrock client for region {region}")
return boto3.client("bedrock-runtime", region_name=region)
else:
if ENABLE_BEDROCK_CROSS_REGION_INFERENCE:
logger.warning(
f"Cross-region inference is enabled, but the region {region} is not supported. Using default region."
)
return boto3.client("bedrock-runtime", region_name=REGION)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use region i.e. BEDROCK_REGION as default client. REGION is used for services except for bedrock



def get_bedrock_agent_client(region=REGION):
Expand Down
1 change: 1 addition & 0 deletions cdk/cdk.json
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
"@aws-cdk/core:includePrefixInUniqueNameGeneration": true,
"@aws-cdk/aws-opensearchservice:enableOpensearchMultiAzWithStandby": true,
"enableMistral": false,
"enableBedrockCrossRegionInference": false,
"bedrockRegion": "us-east-1",
"allowedIpV4AddressRanges": ["0.0.0.0/1", "128.0.0.0/1"],
"allowedIpV6AddressRanges": [
Expand Down
Loading