-
Notifications
You must be signed in to change notification settings - Fork 322
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Add cross-region inference support for Bedrock models (#535) #536
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,7 +11,12 @@ | |
from app.repositories.models.conversation import MessageModel | ||
from app.repositories.models.custom_bot import GenerationParamsModel | ||
from app.routes.schemas.conversation import type_model_name | ||
from app.utils import convert_dict_keys_to_camel_case, get_bedrock_client | ||
from app.utils import ( | ||
convert_dict_keys_to_camel_case, | ||
get_bedrock_client, | ||
is_region_supported_for_inference, | ||
ENABLE_BEDROCK_CROSS_REGION_INFERENCE, | ||
) | ||
from typing_extensions import NotRequired, TypedDict, no_type_check | ||
|
||
logger = logging.getLogger(__name__) | ||
|
@@ -23,8 +28,11 @@ | |
if ENABLE_MISTRAL | ||
else DEFAULT_CLAUDE_GENERATION_CONFIG | ||
) | ||
ENABLE_BEDROCK_CROSS_REGION_INFERENCE = ( | ||
os.environ.get("ENABLE_BEDROCK_CROSS_REGION_INFERENCE", "false").lower() == "true" | ||
) | ||
|
||
client = get_bedrock_client() | ||
client = get_bedrock_client(BEDROCK_REGION) | ||
|
||
|
||
class ConverseApiToolSpec(TypedDict): | ||
|
@@ -219,7 +227,7 @@ def compose_args_for_converse_api( | |
|
||
|
||
def call_converse_api(args: ConverseApiRequest) -> ConverseApiResponse: | ||
client = get_bedrock_client() | ||
client = get_bedrock_client(BEDROCK_REGION) | ||
messages = args["messages"] | ||
inference_config = args["inference_config"] | ||
additional_model_request_fields = args["additional_model_request_fields"] | ||
|
@@ -257,26 +265,48 @@ def calculate_price( | |
return input_price * input_tokens / 1000.0 + output_price * output_tokens / 1000.0 | ||
|
||
|
||
CROSS_REGION_INFERENCE_MODELS = { | ||
"claude-v3-sonnet": "anthropic.claude-3-sonnet-20240229-v1:0", | ||
"claude-v3-haiku": "anthropic.claude-3-haiku-20240307-v1:0", | ||
"claude-v3-opus": "anthropic.claude-3-opus-20240229-v1:0", | ||
"claude-v3.5-sonnet": "anthropic.claude-3-5-sonnet-20240620-v1:0", | ||
} | ||
|
||
|
||
def get_model_id(model: type_model_name) -> str: | ||
# Ref: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html | ||
if model == "claude-v2": | ||
return "anthropic.claude-v2:1" | ||
elif model == "claude-instant-v1": | ||
return "anthropic.claude-instant-v1" | ||
elif model == "claude-v3-sonnet": | ||
return "anthropic.claude-3-sonnet-20240229-v1:0" | ||
elif model == "claude-v3-haiku": | ||
return "anthropic.claude-3-haiku-20240307-v1:0" | ||
elif model == "claude-v3-opus": | ||
return "anthropic.claude-3-opus-20240229-v1:0" | ||
elif model == "claude-v3.5-sonnet": | ||
return "anthropic.claude-3-5-sonnet-20240620-v1:0" | ||
elif model == "mistral-7b-instruct": | ||
return "mistral.mistral-7b-instruct-v0:2" | ||
elif model == "mixtral-8x7b-instruct": | ||
return "mistral.mixtral-8x7b-instruct-v0:1" | ||
elif model == "mistral-large": | ||
return "mistral.mistral-large-2402-v1:0" | ||
base_model_id = { | ||
"claude-v2": "anthropic.claude-v2:1", | ||
"claude-instant-v1": "anthropic.claude-instant-v1", | ||
"claude-v3-sonnet": "anthropic.claude-3-sonnet-20240229-v1:0", | ||
"claude-v3-haiku": "anthropic.claude-3-haiku-20240307-v1:0", | ||
"claude-v3-opus": "anthropic.claude-3-opus-20240229-v1:0", | ||
"claude-v3.5-sonnet": "anthropic.claude-3-5-sonnet-20240620-v1:0", | ||
"mistral-7b-instruct": "mistral.mistral-7b-instruct-v0:2", | ||
"mixtral-8x7b-instruct": "mistral.mixtral-8x7b-instruct-v0:1", | ||
"mistral-large": "mistral.mistral-large-2402-v1:0", | ||
}[model] | ||
|
||
if ( | ||
ENABLE_BEDROCK_CROSS_REGION_INFERENCE | ||
and is_region_supported_for_inference(BEDROCK_REGION) | ||
and model in CROSS_REGION_INFERENCE_MODELS | ||
): | ||
logger.info( | ||
f"Using cross-region inference for model {model} in region {BEDROCK_REGION}" | ||
) | ||
return base_model_id | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Need to add a region prefix to the model ID for cross-region inference, e.g. `us.` or `eu.` |
||
else: | ||
if ENABLE_BEDROCK_CROSS_REGION_INFERENCE: | ||
if not is_region_supported_for_inference(BEDROCK_REGION): | ||
logger.warning( | ||
f"Cross-region inference is enabled, but the region {BEDROCK_REGION} is not supported. Using local model." | ||
) | ||
elif model not in CROSS_REGION_INFERENCE_MODELS: | ||
logger.warning( | ||
f"Cross-region inference is not available for model {model}. Using local model." | ||
) | ||
return f"{base_model_id}-local" | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Does the postfix, i.e. `-local`, produce a valid model ID? |
||
|
||
|
||
def calculate_query_embedding(question: str) -> list[float]: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,6 +19,9 @@ | |
"PUBLISH_API_CODEBUILD_PROJECT_NAME", "" | ||
) | ||
DB_SECRETS_ARN = os.environ.get("DB_SECRETS_ARN", "") | ||
ENABLE_BEDROCK_CROSS_REGION_INFERENCE = ( | ||
os.environ.get("ENABLE_BEDROCK_CROSS_REGION_INFERENCE", "false").lower() == "true" | ||
) | ||
|
||
|
||
def snake_to_camel(snake_str): | ||
|
@@ -40,9 +43,29 @@ def is_running_on_lambda(): | |
return "AWS_EXECUTION_ENV" in os.environ | ||
|
||
|
||
def is_region_supported_for_inference(region: str) -> bool: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. To emphasize that this check is for cross-region inference, rename the function so that its name says so. |
||
supported_regions = [ | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This variable is a constant, so please replace it with an UPPER_SNAKE_CASE constant. |
||
"us-east-1", | ||
"us-west-2", | ||
"eu-west-1", | ||
"eu-west-3", | ||
"eu-central-1", | ||
] # Add more as they become available | ||
return region in supported_regions | ||
|
||
|
||
def get_bedrock_client(region=BEDROCK_REGION): | ||
client = boto3.client("bedrock-runtime", region) | ||
return client | ||
if ENABLE_BEDROCK_CROSS_REGION_INFERENCE and is_region_supported_for_inference( | ||
region | ||
): | ||
logger.info(f"Using cross-region Bedrock client for region {region}") | ||
return boto3.client("bedrock-runtime", region_name=region) | ||
else: | ||
if ENABLE_BEDROCK_CROSS_REGION_INFERENCE: | ||
logger.warning( | ||
f"Cross-region inference is enabled, but the region {region} is not supported. Using default region." | ||
) | ||
return boto3.client("bedrock-runtime", region_name=REGION) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Use |
||
|
||
|
||
def get_bedrock_agent_client(region=REGION): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Move it inside `get_model_id`, since it is used only by that function.