Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support bedrock api in job and intention #525

Merged
merged 6 commits into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions source/infrastructure/bin/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,6 @@ export function getConfig(): SystemConfig {
chat: {
enabled: true,
bedrockRegion: "us-east-1",
apiInference: {
enabled: false,
apiInferenceProvider: "",
apiEndpoint: "",
apiKey: ""
},
bedrockAk: "",
bedrockSk: "",
useOpenSourceLLM: true,
Expand Down
96 changes: 0 additions & 96 deletions source/infrastructure/cli/magic-config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,30 +37,8 @@ const embeddingModels = [
commitId: "43972580a35ceacacd31b95b9f430f695d07dde9",
dimensions: 768,
},
{
provider: "OpenAI API",
name: "text-embedding-3-small",
commitId: "",
dimensions: 1536,
},
{
provider: "OpenAI API",
name: "text-embedding-3-large",
commitId: "",
dimensions: 3072,
},
];

const apiInferenceProviders = [
{
provider: "bedrock",
name: "Bedrock API",
},
{
provider: "openai",
name: "OpenAI API",
},
];

const supportedRegions = Object.values(SupportedRegion) as string[];
const supportedBedrockRegions = Object.values(SupportedBedrockRegion) as string[];
Expand Down Expand Up @@ -135,10 +113,6 @@ async function getAwsAccountAndRegion() {
options.knowledgeBaseModelEcrImageTag = config.knowledgeBase.knowledgeBaseType.intelliAgentKb.knowledgeBaseModel.ecrImageTag;
options.enableChat = config.chat.enabled;
options.bedrockRegion = config.chat.bedrockRegion;
options.enableApiInference = config.chat.apiInference.enabled;
options.apiInferenceProvider = config.chat.apiInference.apiInferenceProvider;
options.apiEndpoint = config.chat.apiInference.apiEndpoint;
options.apiKey = config.chat.apiInference.apiKey;
options.enableConnect = config.chat.amazonConnect.enabled;
options.useOpenSourceLLM = config.chat.useOpenSourceLLM;
options.defaultEmbedding = config.model.embeddingsModels && config.model.embeddingsModels.length > 0
Expand Down Expand Up @@ -363,70 +337,6 @@ async function processCreateOptions(options: any): Promise<void> {
return (!(this as any).state.answers.enableChat);
},
},
{
type: "confirm",
name: "enableApiInference",
message: "Compared to local deployment, would you prefer to access the model via API calls?",
initial: options.enableApiInference ?? false,
skip(): boolean {
if (!(this as any).state.answers.enableChat) {
return true;
}
return false;
},
},
{
type: "select",
name: "apiInferenceProvider",
message: "Select an API inference provider, it is used for LLM invocation and generating embedding",
choices: apiInferenceProviders.map((m) => ({ name: m.name, value: m })),
initial: options.apiInferenceProvider,
skip(): boolean {
if (!(this as any).state.answers.enableApiInference ||
!(this as any).state.answers.enableChat) {
return true;
}
return false;
},
},
{
type: "input",
name: "apiEndpoint",
message: "API endpoint to invoke models, e.g. https://api.example.com/v1",
initial: options.apiEndpoint,
validate(apiEndpoint: string) {
return (this as any).skipped ||
RegExp(/^https?:\/\//).test(apiEndpoint)
? true
: "Enter a valid API endpoint, e.g. https://api.example.com/v1";
},
skip(): boolean {
if (!(this as any).state.answers.enableApiInference ||
!(this as any).state.answers.enableChat) {
return true;
}
return false;
},
},
{
type: "input",
name: "apiKey",
message: "When invoking Bedrock API or OpenAI API, you need to provide an API key, which should be stored in the Secrets Manager of your current account. Please enter the ARN of the API key, for example: arn:aws:secretsmanager:<region>:<account_id>:secret:SampleAPIKey",
initial: options.apiKey,
validate(apiKey: string) {
return (this as any).skipped ||
RegExp(/^arn:aws:secretsmanager:[a-z0-9-]+:[0-9]{12}:secret:[a-zA-Z0-9-_/]+$/).test(apiKey)
? true
: "Enter a valid Secrets Manager ARN (e.g., arn:aws:secretsmanager:region:123456789012:secret:mysecret)";
},
skip(): boolean {
if (!(this as any).state.answers.enableApiInference ||
!(this as any).state.answers.enableChat) {
return true;
}
return false;
},
},
{
type: "confirm",
name: "useOpenSourceLLM",
Expand Down Expand Up @@ -558,12 +468,6 @@ async function processCreateOptions(options: any): Promise<void> {
chat: {
enabled: answers.enableChat,
bedrockRegion: answers.bedrockRegion,
apiInference: {
enabled: answers.enableApiInference,
apiInferenceProvider: answers.apiInferenceProvider,
apiEndpoint: answers.apiEndpoint,
apiKey: answers.apiKey,
},
useOpenSourceLLM: answers.useOpenSourceLLM,
amazonConnect: {
enabled: answers.enableConnect,
Expand Down
4 changes: 0 additions & 4 deletions source/infrastructure/lib/api/intention-management.ts
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,6 @@ export class IntentionApi extends Construct {
KNOWLEDGE_BASE_ENABLED: this.config.knowledgeBase.enabled.toString(),
KNOWLEDGE_BASE_TYPE: JSON.stringify(this.config.knowledgeBase.knowledgeBaseType || {}),
BEDROCK_REGION: this.config.chat.bedrockRegion,
API_INFERENCE_ENABLED: this.config.chat.apiInference.enabled.toString(),
API_INFERENCE_PROVIDER: this.config.chat.apiInference.apiInferenceProvider,
API_ENDPOINT: this.config.chat.apiInference.apiEndpoint,
API_KEY_ARN: this.config.chat.apiInference.apiKey,
},
layers: [this.sharedLayer],
});
Expand Down
11 changes: 6 additions & 5 deletions source/infrastructure/lib/knowledge-base/knowledge-base-stack.ts
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ export class KnowledgeBaseStack extends NestedStack implements KnowledgeBaseStac
glueRole.addToPolicy(this.iamHelper.logStatement);
glueRole.addToPolicy(this.iamHelper.glueStatement);
glueRole.addToPolicy(this.dynamodbStatement);
glueRole.addToPolicy(this.iamHelper.dynamodbStatement);

// Create glue job to process files specified in s3 bucket and prefix
const glueJob = new glue.Job(this, "PythonShellJob", {
Expand Down Expand Up @@ -233,7 +234,7 @@ export class KnowledgeBaseStack extends NestedStack implements KnowledgeBaseStac
"--PORTAL_BUCKET": this.uiPortalBucketName,
"--CHATBOT_TABLE": props.sharedConstructOutputs.chatbotTable.tableName,
"--additional-python-modules":
"langchain==0.3.7,beautifulsoup4==4.12.2,requests-aws4auth==1.2.3,boto3==1.35.98,openai==0.28.1,pyOpenSSL==23.3.0,tenacity==8.2.3,markdownify==0.11.6,mammoth==1.6.0,chardet==5.2.0,python-docx==1.1.0,nltk==3.9.1,pdfminer.six==20221105,smart-open==7.0.4,opensearch-py==2.2.0,lxml==5.2.2,pandas==2.1.2,openpyxl==3.1.5,xlrd==2.0.1,langchain_community==0.3.5,pillow==10.0.1",
"langchain==0.3.7,beautifulsoup4==4.12.2,requests-aws4auth==1.2.3,boto3==1.35.98,openai==0.28.1,pyOpenSSL==23.3.0,tenacity==8.2.3,markdownify==0.11.6,mammoth==1.6.0,chardet==5.2.0,python-docx==1.1.0,nltk==3.9.1,pdfminer.six==20221105,smart-open==7.0.4,opensearch-py==2.2.0,lxml==5.2.2,pandas==2.1.2,openpyxl==3.1.5,xlrd==2.0.1,langchain_community==0.3.5,pillow==10.0.1,tiktoken==0.8.0",
// Add multiple extra python files
"--extra-py-files": extraPythonFilesList
},
Expand Down Expand Up @@ -276,6 +277,7 @@ export class KnowledgeBaseStack extends NestedStack implements KnowledgeBaseStac
"s3Prefix.$": "$.Payload.s3Prefix",
"qaEnhance.$": "$.Payload.qaEnhance",
"chatbotId.$": "$.Payload.chatbotId",
"groupName.$": "$.Payload.groupName",
"indexId.$": "$.Payload.indexId",
"embeddingModelType.$": "$.Payload.embeddingModelType",
"offline.$": "$.Payload.offline",
Expand Down Expand Up @@ -313,15 +315,13 @@ export class KnowledgeBaseStack extends NestedStack implements KnowledgeBaseStac
"--QA_ENHANCEMENT.$": "$.qaEnhance",
"--REGION": process.env.CDK_DEFAULT_REGION || "-",
"--BEDROCK_REGION": props.config.chat.bedrockRegion,
"--API_INFERENCE_ENABLED": props.config.chat.apiInference.enabled.toString(),
"--API_INFERENCE_PROVIDER": props.config.chat.apiInference.apiInferenceProvider || "-",
"--API_ENDPOINT": props.config.chat.apiInference.apiEndpoint || "-",
"--API_KEY_ARN": props.config.chat.apiInference.apiKey || "-",
"--MODEL_TABLE": props.sharedConstructOutputs.modelTable.tableName,
"--RES_BUCKET": this.glueResultBucket.bucketName,
"--S3_BUCKET.$": "$.s3Bucket",
"--S3_PREFIX.$": "$.s3Prefix",
"--PORTAL_BUCKET": this.uiPortalBucketName,
"--CHATBOT_ID.$": "$.chatbotId",
"--GROUP_NAME.$": "$.groupName",
"--INDEX_ID.$": "$.indexId",
"--EMBEDDING_MODEL_TYPE.$": "$.embeddingModelType",
"--job-language": "python",
Expand All @@ -341,6 +341,7 @@ export class KnowledgeBaseStack extends NestedStack implements KnowledgeBaseStac
"s3Bucket.$": "$.s3Bucket",
"s3Prefix.$": "$.s3Prefix",
"chatbotId.$": "$.chatbotId",
"groupName.$": "$.groupName",
"indexId.$": "$.indexId",
"embeddingModelType.$": "$.embeddingModelType",
"qaEnhance.$": "$.qaEnhance",
Expand Down
6 changes: 0 additions & 6 deletions source/infrastructure/lib/shared/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,6 @@ export interface SystemConfig {
chat: {
enabled: boolean;
bedrockRegion: string;
apiInference: {
enabled: boolean;
apiInferenceProvider: string;
apiEndpoint: string;
apiKey: string;
},
bedrockAk?: string;
bedrockSk?: string;
useOpenSourceLLM: boolean;
Expand Down
3 changes: 3 additions & 0 deletions source/lambda/etl/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def lambda_handler(event, context):
prefix = event["s3Prefix"]
# fetch index from event with default value none
chatbot_id = event["chatbotId"]
group_name = event["groupName"]
index_id = event["indexId"]
embedding_model_type = event["embeddingModelType"]
index_type = event.get("indexType", "qd")
Expand Down Expand Up @@ -76,6 +77,7 @@ def lambda_handler(event, context):
"s3Prefix": prefix,
"fileCount": file_count,
"chatbotId": chatbot_id,
"groupName": group_name,
"indexId": index_id,
"embeddingModelType": embedding_model_type,
"qaEnhance": (
Expand All @@ -97,6 +99,7 @@ def lambda_handler(event, context):
"s3Prefix": prefix,
"fileCount": "1",
"chatbotId": chatbot_id,
"groupName": group_name,
"qaEnhance": (
event["qaEnhance"].lower() if "qaEnhance" in event else "false"
),
Expand Down
21 changes: 12 additions & 9 deletions source/lambda/intention/intention.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,6 @@
chatbot_table_name = os.getenv("CHATBOT_TABLE_NAME", "chatbot")
index_table_name = os.getenv("INDEX_TABLE_NAME", "index")
model_table_name = os.getenv("MODEL_TABLE_NAME", "model")
api_inference_enabled = os.getenv["API_INFERENCE_ENABLED"]
api_inference_provider = os.getenv["API_INFERENCE_PROVIDER"]
api_inference_endpoint = os.getenv["API_ENDPOINT"]
api_key_arn = os.getenv["API_KEY_ARN"]
dynamodb_client = boto3.resource("dynamodb")
intention_table = dynamodb_client.Table(intention_table_name)
index_table = dynamodb_client.Table(index_table_name)
Expand Down Expand Up @@ -419,6 +415,8 @@ def __create_execution(event, context, email, group_name):
valid_qa_list,
bucket,
prefix,
group_name,
input_body.get("chatbotId")
)

return {
Expand Down Expand Up @@ -457,7 +455,13 @@ def convert_qa_list(qa_list: list, bucket: str, prefix: str) -> List[Document]:


def __save_2_aos(
modelId: str, index: str, qaListParam: list, bucket: str, prefix: str
modelId: str,
index: str,
qaListParam: list,
bucket: str,
prefix: str,
group_name: str,
chatbot_id: str
):
qaList = __deduplicate_by_key(qaListParam, "question")
if kb_enabled:
Expand All @@ -467,10 +471,9 @@ def __save_2_aos(
region_name=region,
bedrock_region=bedrock_region,
model_type=embedding_info.get("ModelType"),
api_inference_enabled=api_inference_enabled,
api_inference_provider=api_inference_provider,
api_inference_endpoint=api_inference_endpoint,
api_key_arn=api_key_arn,
group_name=group_name,
chatbot_id=chatbot_id,
model_table=model_table_name
)
docsearch = OpenSearchVectorSearch(
index_name=index,
Expand Down
4 changes: 3 additions & 1 deletion source/lambda/intention/requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
openpyxl==3.1.3
openpyxl==3.1.3
openai==0.28.1
tiktoken==0.8.0
Binary file modified source/lambda/job/dep/dist/llm_bot_dep-0.1.0-py3-none-any.whl
Binary file not shown.
61 changes: 52 additions & 9 deletions source/lambda/job/dep/llm_bot_dep/sm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,38 @@
)


def get_model_details(group_name: str, chatbot_id: str, table_name: str):
    """Get embedding model details from a DynamoDB table.

    The model ID is derived from the chatbot ID as ``"<chatbot_id>-embedding"``.

    Args:
        group_name (str): The partition key (group name)
        chatbot_id (str): Used to construct the model ID
        table_name (str): DynamoDB table name

    Returns:
        dict: The model item stored in DynamoDB

    Raises:
        Exception: If the DynamoDB lookup fails, or if no item exists for
            the given group name and derived model ID.
    """
    dynamodb = boto3.resource("dynamodb")
    table = dynamodb.Table(table_name)
    model_id = f"{chatbot_id}-embedding"

    try:
        response = table.get_item(
            Key={
                "groupName": group_name,
                "modelId": model_id
            }
        )
    except Exception as e:
        # Only genuine DynamoDB call failures are re-wrapped here.
        logger.error(f"Error retrieving model details: {str(e)}")
        raise Exception(f"Failed to get model details: {str(e)}")

    # Checked outside the try block so the "not found" condition is reported
    # as-is instead of being caught above and double-wrapped as a generic
    # "Failed to get model details: No model found ..." message.
    if "Item" not in response:
        raise Exception(f"No model found for group {group_name} and model ID {model_id}")

    return response["Item"]


def get_secret_value(secret_arn: str):
"""Get secret value from secret manager

Expand Down Expand Up @@ -480,13 +512,23 @@ def SagemakerEndpointVectorOrCross(
)
return genericModel(prompt=prompt, stop=stop, **kwargs)


def getCustomEmbeddings(
endpoint_name: str, region_name: str, bedrock_region: str, model_type: str, api_inference_enabled: str = "false",
api_inference_provider: str = "Bedrock API", api_inference_endpoint: str = None, api_key_arn: str = None
endpoint_name: str,
region_name: str,
bedrock_region: str,
model_type: str,
group_name: str,
chatbot_id: str,
model_table: str
) -> SagemakerEndpointEmbeddings:
embeddings = None
if api_inference_enabled == "false":
model_details = get_model_details(group_name, chatbot_id, model_table)
model_provider = model_details["parameter"].get("ModelProvider", "Bedrock")
base_url = model_details["parameter"].get("BaseUrl", "")
api_key_arn = model_details["parameter"].get("ApiKeyArn", "")
logger.info(model_details)

if model_provider not in ["Bedrock API", "OpenAI API"]:
# Use local models
client = boto3.client("sagemaker-runtime", region_name=region_name)
bedrock_client = boto3.client("bedrock-runtime", region_name=bedrock_region)
Expand Down Expand Up @@ -521,18 +563,19 @@ def getCustomEmbeddings(
return embeddings

# API inference from Bedrock API or OpenAI API
if api_inference_provider == "Bedrock API":
if model_provider == "Bedrock API":
embeddings = OpenAIEmbeddings(
model=endpoint_name,
api_key=get_secret_value(api_key_arn),
base_url=api_inference_endpoint
base_url=base_url
)
elif api_inference_provider == "OpenAI API":
elif model_provider == "OpenAI API":
embeddings = OpenAIEmbeddings(
model=endpoint_name,
api_key=get_secret_value(api_key_arn)
api_key=get_secret_value(api_key_arn),
base_url=base_url
)
else:
raise ValueError(f"Unsupported API inference provider: {api_inference_provider}")
raise ValueError(f"Unsupported API inference provider: {model_provider}")

return embeddings
1 change: 1 addition & 0 deletions source/lambda/job/dep/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,6 @@
"pdfminer.six==20221105",
"smart-open==7.0.4",
"pillow==10.0.1",
"tiktoken==0.8.0"
],
)
Loading
Loading