Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add bedrock api arn params #514

Merged
merged 2 commits into from
Jan 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions source/infrastructure/bin/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ export function getConfig(): SystemConfig {
chat: {
enabled: true,
bedrockRegion: "us-east-1",
crossAccountBedrockKey: "",
bedrockAk: "",
bedrockSk: "",
useOpenSourceLLM: true,
Expand Down
17 changes: 17 additions & 0 deletions source/infrastructure/cli/magic-config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ async function getAwsAccountAndRegion() {
options.knowledgeBaseModelEcrImageTag = config.knowledgeBase.knowledgeBaseType.intelliAgentKb.knowledgeBaseModel.ecrImageTag;
options.enableChat = config.chat.enabled;
options.bedrockRegion = config.chat.bedrockRegion;
options.crossAccountBedrockKey = config.chat.crossAccountBedrockKey;
options.enableConnect = config.chat.amazonConnect.enabled;
options.useOpenSourceLLM = config.chat.useOpenSourceLLM;
options.defaultEmbedding = config.model.embeddingsModels && config.model.embeddingsModels.length > 0
Expand Down Expand Up @@ -336,6 +337,21 @@ async function processCreateOptions(options: any): Promise<void> {
return (!(this as any).state.answers.enableChat);
},
},
{
type: "input",
name: "crossAccountBedrockKey",
message: "If you don't need to use cross-account Bedrock functionality, you can press Enter to skip this step. When invoking Bedrock across accounts, you need to provide an API key, which should be stored in the Secrets Manager of your current account. Please enter the ARN of the API key, for example: arn:aws:secretsmanager:us-west-2:<aws_account_id>:secret:SampleAPIKey-AqZDIw",
initial: options.crossAccountBedrockKey,
validate(crossAccountBedrockKey: string) {
if ( crossAccountBedrockKey.includes('arn:aws:secretsmanager') || crossAccountBedrockKey.trim() === '') {
return true;
}
return "Enter a valid ARN or press Enter to skip it";
},
skip(): boolean {
return (!(this as any).state.answers.enableChat);
},
},
{
type: "confirm",
name: "useOpenSourceLLM",
Expand Down Expand Up @@ -484,6 +500,7 @@ async function processCreateOptions(options: any): Promise<void> {
chat: {
enabled: answers.enableChat,
bedrockRegion: answers.bedrockRegion,
crossAccountBedrockKey: answers.crossAccountBedrockKey,
useOpenSourceLLM: answers.useOpenSourceLLM,
amazonConnect: {
enabled: answers.enableConnect,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ export class KnowledgeBaseStack extends NestedStack implements KnowledgeBaseStac
"--QA_ENHANCEMENT.$": "$.qaEnhance",
"--REGION": process.env.CDK_DEFAULT_REGION || "-",
"--BEDROCK_REGION": props.config.chat.bedrockRegion,
"--CROSS_ACCOUNT_BEDROCK_KEY": props.config.chat.crossAccountBedrockKey,
"--RES_BUCKET": this.glueResultBucket.bucketName,
"--S3_BUCKET.$": "$.s3Bucket",
"--S3_PREFIX.$": "$.s3Prefix",
Expand Down
1 change: 1 addition & 0 deletions source/infrastructure/lib/shared/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ export interface SystemConfig {
chat: {
enabled: boolean;
bedrockRegion: string;
crossAccountBedrockKey: string;
bedrockAk?: string;
bedrockSk?: string;
useOpenSourceLLM: boolean;
Expand Down
2 changes: 1 addition & 1 deletion source/lambda/job/build_whl.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
cd ./dep
pip install setuptools wheel

python setup.py bdist_wheel
python3 setup.py bdist_wheel
Binary file modified source/lambda/job/dep/dist/llm_bot_dep-0.1.0-py3-none-any.whl
Binary file not shown.
32 changes: 31 additions & 1 deletion source/lambda/job/dep/llm_bot_dep/sm_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import io
import json
import logging
import os
from typing import Any, Dict, Iterator, List, Mapping, Optional

import boto3
from botocore.exceptions import ClientError
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain_community.embeddings import (
BedrockEmbeddings,
Expand All @@ -22,6 +24,34 @@

logger = logging.getLogger()
logger.setLevel(logging.INFO)
region_name = os.environ["AWS_REGION"]
session = boto3.session.Session()
secret_manager_client = session.client(
service_name="secretsmanager", region_name=region_name
)


def get_secret_value(secret_arn: str):
"""Get secret value from secret manager

Args:
secret_arn (str): secret arn

Returns:
str: secret value
"""
try:
get_secret_value_response = secret_manager_client.get_secret_value(
SecretId=secret_arn
)
except ClientError as e:
raise Exception("Fail to retrieve the secret value: {}".format(e))
else:
if "SecretString" in get_secret_value_response:
secret = get_secret_value_response["SecretString"]
return secret
else:
raise Exception("Fail to retrieve the secret value")


class vectorContentHandler(EmbeddingsContentHandler):
Expand Down Expand Up @@ -452,7 +482,7 @@ def SagemakerEndpointVectorOrCross(


def getCustomEmbeddings(
endpoint_name: str, region_name: str, bedrock_region: str, model_type: str
endpoint_name: str, region_name: str, bedrock_region: str, model_type: str, bedrock_api_key_arn: str = None
) -> SagemakerEndpointEmbeddings:
client = boto3.client("sagemaker-runtime", region_name=region_name)
bedrock_client = boto3.client("bedrock-runtime", region_name=bedrock_region)
Expand Down
4 changes: 4 additions & 0 deletions source/lambda/job/glue-job-script.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
"OPERATION_TYPE",
"PORTAL_BUCKET",
"BEDROCK_REGION",
"CROSS_ACCOUNT_BEDROCK_KEY",
],
)
except Exception as e:
Expand Down Expand Up @@ -86,6 +87,7 @@
args["RES_BUCKET"] = os.environ["RES_BUCKET"]
args["REGION"] = os.environ["REGION"]
args["BEDROCK_REGION"] = os.environ["BEDROCK_REGION"]
args["CROSS_ACCOUNT_BEDROCK_KEY"] = os.environ["CROSS_ACCOUNT_BEDROCK_KEY"]
args["PORTAL_BUCKET"] = os.environ.get("PORTAL_BUCKET", None)

from llm_bot_dep import sm_utils
Expand Down Expand Up @@ -115,6 +117,7 @@
qa_enhancement = args["QA_ENHANCEMENT"]
region = args["REGION"]
bedrock_region = args["BEDROCK_REGION"]
bedrock_api_key_arn = args["CROSS_ACCOUNT_BEDROCK_KEY"]
res_bucket = args["RES_BUCKET"]
s3_bucket = args["S3_BUCKET"]
s3_prefix = args["S3_PREFIX"]
Expand Down Expand Up @@ -737,6 +740,7 @@ def main():
region_name=region,
bedrock_region=bedrock_region,
model_type=embedding_model_type,
bedrock_api_key_arn=bedrock_api_key_arn,
)
aws_auth = get_aws_auth()
docsearch = OpenSearchVectorSearch(
Expand Down
Loading