Skip to content

Commit

Permalink
feat: Add cross-region inference support for Bedrock models (#535)
Browse files Browse the repository at this point in the history
- Bedrock and utils reformated
  • Loading branch information
chm10 committed Sep 17, 2024
1 parent 47e3eea commit 7102ab6
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 12 deletions.
33 changes: 25 additions & 8 deletions backend/app/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@
from app.repositories.models.conversation import MessageModel
from app.repositories.models.custom_bot import GenerationParamsModel
from app.routes.schemas.conversation import type_model_name
from app.utils import convert_dict_keys_to_camel_case, get_bedrock_client, is_region_supported_for_inference, ENABLE_BEDROCK_CROSS_REGION_INFERENCE
from app.utils import (
convert_dict_keys_to_camel_case,
get_bedrock_client,
is_region_supported_for_inference,
ENABLE_BEDROCK_CROSS_REGION_INFERENCE,
)
from typing_extensions import NotRequired, TypedDict, no_type_check

logger = logging.getLogger(__name__)
Expand All @@ -23,7 +28,9 @@
if ENABLE_MISTRAL
else DEFAULT_CLAUDE_GENERATION_CONFIG
)
ENABLE_BEDROCK_CROSS_REGION_INFERENCE = os.environ.get("ENABLE_BEDROCK_CROSS_REGION_INFERENCE", "false").lower() == "true"
ENABLE_BEDROCK_CROSS_REGION_INFERENCE = (
os.environ.get("ENABLE_BEDROCK_CROSS_REGION_INFERENCE", "false").lower() == "true"
)

client = get_bedrock_client(BEDROCK_REGION)

Expand Down Expand Up @@ -257,13 +264,15 @@ def calculate_price(

return input_price * input_tokens / 1000.0 + output_price * output_tokens / 1000.0


CROSS_REGION_INFERENCE_MODELS = {
"claude-v3-sonnet": "anthropic.claude-3-sonnet-20240229-v1:0",
"claude-v3-haiku": "anthropic.claude-3-haiku-20240307-v1:0",
"claude-v3-opus": "anthropic.claude-3-opus-20240229-v1:0",
"claude-v3.5-sonnet": "anthropic.claude-3-5-sonnet-20240620-v1:0",
}


def get_model_id(model: type_model_name) -> str:
# Ref: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
base_model_id = {
Expand All @@ -278,17 +287,25 @@ def get_model_id(model: type_model_name) -> str:
"mistral-large": "mistral.mistral-large-2402-v1:0",
}[model]

if (ENABLE_BEDROCK_CROSS_REGION_INFERENCE and
is_region_supported_for_inference(BEDROCK_REGION) and
model in CROSS_REGION_INFERENCE_MODELS):
logger.info(f"Using cross-region inference for model {model} in region {BEDROCK_REGION}")
if (
ENABLE_BEDROCK_CROSS_REGION_INFERENCE
and is_region_supported_for_inference(BEDROCK_REGION)
and model in CROSS_REGION_INFERENCE_MODELS
):
logger.info(
f"Using cross-region inference for model {model} in region {BEDROCK_REGION}"
)
return base_model_id
else:
if ENABLE_BEDROCK_CROSS_REGION_INFERENCE:
if not is_region_supported_for_inference(BEDROCK_REGION):
logger.warning(f"Cross-region inference is enabled, but the region {BEDROCK_REGION} is not supported. Using local model.")
logger.warning(
f"Cross-region inference is enabled, but the region {BEDROCK_REGION} is not supported. Using local model."
)
elif model not in CROSS_REGION_INFERENCE_MODELS:
logger.warning(f"Cross-region inference is not available for model {model}. Using local model.")
logger.warning(
f"Cross-region inference is not available for model {model}. Using local model."
)
return f"{base_model_id}-local"


Expand Down
22 changes: 18 additions & 4 deletions backend/app/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@
"PUBLISH_API_CODEBUILD_PROJECT_NAME", ""
)
DB_SECRETS_ARN = os.environ.get("DB_SECRETS_ARN", "")
ENABLE_BEDROCK_CROSS_REGION_INFERENCE = os.environ.get("ENABLE_BEDROCK_CROSS_REGION_INFERENCE", "false").lower() == "true"
ENABLE_BEDROCK_CROSS_REGION_INFERENCE = (
os.environ.get("ENABLE_BEDROCK_CROSS_REGION_INFERENCE", "false").lower() == "true"
)


def snake_to_camel(snake_str):
components = snake_str.split("_")
Expand All @@ -41,16 +44,27 @@ def is_running_on_lambda():


def is_region_supported_for_inference(region: str) -> bool:
supported_regions = ['us-east-1', 'us-west-2', 'eu-west-1', 'eu-west-3', 'eu-central-1'] # Add more as they become available
supported_regions = [
"us-east-1",
"us-west-2",
"eu-west-1",
"eu-west-3",
"eu-central-1",
] # Add more as they become available
return region in supported_regions


def get_bedrock_client(region=BEDROCK_REGION):
if ENABLE_BEDROCK_CROSS_REGION_INFERENCE and is_region_supported_for_inference(region):
if ENABLE_BEDROCK_CROSS_REGION_INFERENCE and is_region_supported_for_inference(
region
):
logger.info(f"Using cross-region Bedrock client for region {region}")
return boto3.client("bedrock-runtime", region_name=region)
else:
if ENABLE_BEDROCK_CROSS_REGION_INFERENCE:
logger.warning(f"Cross-region inference is enabled, but the region {region} is not supported. Using default region.")
logger.warning(
f"Cross-region inference is enabled, but the region {region} is not supported. Using default region."
)
return boto3.client("bedrock-runtime", region_name=REGION)


Expand Down

0 comments on commit 7102ab6

Please sign in to comment.