feat: add bedrock api

Ning committed Jan 21, 2025
1 parent 5253b53 commit 374d65a
Showing 8 changed files with 66 additions and 52 deletions.
4 changes: 2 additions & 2 deletions source/infrastructure/cli/magic-config.ts

@@ -42,11 +42,11 @@ const embeddingModels = [
 const apiInferenceProviders = [
   {
     provider: "bedrock",
-    name: "Bedrock-API",
+    name: "Bedrock API",
  },
   {
     provider: "openai",
-    name: "OpenAI-API",
+    name: "OpenAI API",
   },
 ];
@@ -313,7 +313,7 @@ export class KnowledgeBaseStack extends NestedStack implements KnowledgeBaseStac
       "--QA_ENHANCEMENT.$": "$.qaEnhance",
       "--REGION": process.env.CDK_DEFAULT_REGION || "-",
       "--BEDROCK_REGION": props.config.chat.bedrockRegion,
-      "--API_INFERENCE_ENABLED": props.config.chat.apiInference.enabled,
+      "--API_INFERENCE_ENABLED": props.config.chat.apiInference.enabled.toString(),
       "--API_INFERENCE_PROVIDER": props.config.chat.apiInference.apiInferenceProvider,
       "--API_ENDPOINT": props.config.chat.apiInference.apiEndpoint,
       "--API_KEY_ARN": props.config.chat.apiInference.apiKey,
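
The added .toString() reflects a Glue constraint: job arguments are a string-to-string map, so a boolean config value must be stringified before it is handed to the job. The same constraint seen from the boto3 side, in a minimal sketch — the job name, endpoint, and ARN below are hypothetical placeholders, not values from this commit:

import boto3

glue = boto3.client("glue")

# Glue's Arguments parameter is Dict[str, str]: every value,
# including booleans, must be passed as a string.
glue.start_job_run(
    JobName="knowledge-base-etl",  # hypothetical job name
    Arguments={
        "--API_INFERENCE_ENABLED": str(True).lower(),   # "true", not True
        "--API_INFERENCE_PROVIDER": "Bedrock API",
        "--API_ENDPOINT": "https://example.com/v1",     # hypothetical endpoint
        "--API_KEY_ARN": "arn:aws:secretsmanager:us-west-2:123456789012:secret:api-key",
    },
)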
Binary file modified source/lambda/job/dep/dist/llm_bot_dep-0.1.0-py3-none-any.whl
Binary file not shown.
78 changes: 43 additions & 35 deletions source/lambda/job/dep/llm_bot_dep/sm_utils.py

@@ -483,48 +483,56 @@ def SagemakerEndpointVectorOrCross(


 def getCustomEmbeddings(
-    endpoint_name: str, region_name: str, bedrock_region: str, model_type: str, bedrock_api_key_arn: str = None
+    endpoint_name: str, region_name: str, bedrock_region: str, model_type: str, api_inference_enabled: str = "false",
+    api_inference_provider: str = "Bedrock API", api_inference_endpoint: str = None, api_key_arn: str = None
 ) -> SagemakerEndpointEmbeddings:
-    client = boto3.client("sagemaker-runtime", region_name=region_name)
-    bedrock_client = boto3.client("bedrock-runtime", region_name=bedrock_region)
     embeddings = None
-    if model_type == "bedrock":
-        content_handler = BedrockEmbeddings()
-        embeddings = BedrockEmbeddings(
-            client=bedrock_client,
-            model_id=endpoint_name,
-            normalize=True,
-        )
-    elif model_type == "brconnector":
+    if api_inference_enabled == "false":
+        # Use local models
+        client = boto3.client("sagemaker-runtime", region_name=region_name)
+        bedrock_client = boto3.client("bedrock-runtime", region_name=bedrock_region)
+        if model_type == "bedrock":
+            content_handler = BedrockEmbeddings()
+            embeddings = BedrockEmbeddings(
+                client=bedrock_client,
+                model_id=endpoint_name,
+                normalize=True,
+            )
+        elif model_type == "bce":
+            content_handler = vectorContentHandler()
+            embeddings = SagemakerEndpointEmbeddings(
+                client=client,
+                endpoint_name=endpoint_name,
+                content_handler=content_handler,
+                endpoint_kwargs={"TargetModel": "bce_embedding_model.tar.gz"},
+            )
+        # compatible with both m3 and bce.
+        else:
+            content_handler = m3ContentHandler()
+            model_kwargs = {}
+            model_kwargs["batch_size"] = 12
+            model_kwargs["max_length"] = 512
+            model_kwargs["return_type"] = "dense"
+            embeddings = SagemakerEndpointEmbeddings(
+                client=client,
+                endpoint_name=endpoint_name,
+                model_kwargs=model_kwargs,
+                content_handler=content_handler,
+            )
+        return embeddings
+
+    # API inference from Bedrock API or OpenAI API
+    if api_inference_provider == "Bedrock API":
         embeddings = OpenAIEmbeddings(
             model=endpoint_name,
-            api_key=get_secret_value(bedrock_api_key_arn)
+            api_key=get_secret_value(api_key_arn)
         )
-    elif model_type == "openai":
+    elif api_inference_provider == "OpenAI API":
         embeddings = OpenAIEmbeddings(
             model=endpoint_name,
-            api_key=get_secret_value(openai_api_key_arn)
+            api_key=get_secret_value(api_key_arn)
         )
-
-    elif model_type == "bce":
-        content_handler = vectorContentHandler()
-        embeddings = SagemakerEndpointEmbeddings(
-            client=client,
-            endpoint_name=endpoint_name,
-            content_handler=content_handler,
-            endpoint_kwargs={"TargetModel": "bce_embedding_model.tar.gz"},
-        )
-    # compatible with both m3 and bce.
     else:
-        content_handler = m3ContentHandler()
-        model_kwargs = {}
-        model_kwargs["batch_size"] = 12
-        model_kwargs["max_length"] = 512
-        model_kwargs["return_type"] = "dense"
-        embeddings = SagemakerEndpointEmbeddings(
-            client=client,
-            endpoint_name=endpoint_name,
-            model_kwargs=model_kwargs,
-            content_handler=content_handler,
-        )
+        raise ValueError(f"Unsupported API inference provider: {api_inference_provider}")

     return embeddings
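
A minimal sketch of how the reworked helper might be invoked for API inference — the model id, regions, endpoint, and ARN are hypothetical placeholders, not values from the commit:

# Hypothetical invocation of the new signature.
embeddings = getCustomEmbeddings(
    endpoint_name="amazon.titan-embed-text-v2:0",      # embedding model id (placeholder)
    region_name="us-west-2",
    bedrock_region="us-west-2",
    model_type="bedrock",
    api_inference_enabled="true",                      # string flag, not a bool
    api_inference_provider="Bedrock API",
    api_inference_endpoint="https://example.com/v1",   # accepted but currently unused
    api_key_arn="arn:aws:secretsmanager:us-west-2:123456789012:secret:api-key",
)

Note that api_inference_endpoint is accepted by the new signature but never forwarded to OpenAIEmbeddings; a caller relying on a custom endpoint would presumably need it passed through as base_url.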
15 changes: 12 additions & 3 deletions source/lambda/job/glue-job-script.py

@@ -52,7 +52,10 @@
         "OPERATION_TYPE",
         "PORTAL_BUCKET",
         "BEDROCK_REGION",
-        "CROSS_ACCOUNT_BEDROCK_KEY",
+        "API_INFERENCE_ENABLED",
+        "API_INFERENCE_PROVIDER",
+        "API_ENDPOINT",
+        "API_KEY_ARN",
     ],
 )
 except Exception as e:
@@ -117,7 +120,10 @@
 qa_enhancement = args["QA_ENHANCEMENT"]
 region = args["REGION"]
 bedrock_region = args["BEDROCK_REGION"]
-bedrock_api_key_arn = args["CROSS_ACCOUNT_BEDROCK_KEY"]
+api_inference_enabled = args["API_INFERENCE_ENABLED"]
+api_inference_provider = args["API_INFERENCE_PROVIDER"]
+api_inference_endpoint = args["API_ENDPOINT"]
+api_key_arn = args["API_KEY_ARN"]
 res_bucket = args["RES_BUCKET"]
 s3_bucket = args["S3_BUCKET"]
 s3_prefix = args["S3_PREFIX"]
@@ -740,7 +746,10 @@ def main():
         region_name=region,
         bedrock_region=bedrock_region,
         model_type=embedding_model_type,
-        bedrock_api_key_arn=bedrock_api_key_arn,
+        api_inference_enabled=api_inference_enabled,
+        api_inference_provider=api_inference_provider,
+        api_inference_endpoint=api_inference_endpoint,
+        api_key_arn=api_key_arn,
     )
     aws_auth = get_aws_auth()
     docsearch = OpenSearchVectorSearch(
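
Because Glue delivers every argument as a string, the downstream check in sm_utils.py compares against the literal "false", so any other spelling ("False", "FALSE", an empty value) silently enables API inference. A small normalization sketch, not part of the commit, that would make the flag tolerant of casing:

# Hypothetical hardening: canonicalize the flag so only an explicit
# "true" (any casing) enables API inference.
raw = args["API_INFERENCE_ENABLED"].strip().lower()
api_inference_enabled = "true" if raw == "true" else "false"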
6 changes: 3 additions & 3 deletions source/lambda/online/common_logic/common_utils/constant.py

@@ -127,9 +127,9 @@ class ToolRuningMode(ConstantBase):


 class ModelProvider(ConstantBase):
     DMAA = "dmaa"
-    BEDROCK = "bedrock"
-    BRCONNECTOR_BEDROCK = "brconnector"
-    OPENAI = "openai"
+    BEDROCK = "Bedrock"
+    BRCONNECTOR_BEDROCK = "Bedrock API"
+    OPENAI = "OpenAI API"


 class LLMModelType(ConstantBase):
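
These constants now carry the same strings the portal submits as modelType.value, so provider checks hold by direct string comparison. One caveat: any config persisted with the old lowercase values ("bedrock", "brconnector", "openai") no longer matches the new constants. A hypothetical compatibility shim, not in the commit, sketching one way to absorb old values:

# Hypothetical shim for configs saved before this commit.
LEGACY_PROVIDERS = {
    "bedrock": "Bedrock",
    "brconnector": "Bedrock API",
    "openai": "OpenAI API",
}

def normalize_provider(value: str) -> str:
    # Map old lowercase identifiers onto the new display-style constants.
    return LEGACY_PROVIDERS.get(value, value)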
@@ -38,14 +38,11 @@ class LLMConfig(AllowBaseModel):

+    api_key_arn: Union[str,None] = None
+    api_key: Union[str,None] = None
     model_kwargs: dict = {"temperature": 0.01, "max_tokens": 4096}
-    api_key: Union[str,None] = None

     def model_post_init(self, __context: Any) -> None:
-        if self.br_api_key is None and self.br_api_key_arn is not None and self.base_url is not None:
-            self.br_api_key = get_secret_value(self.br_api_key_arn)
-        if self.openai_api_key is None and self.openai_api_key_arn is not None and self.base_url is not None:
-            self.openai_api_key = get_secret_value(self.openai_api_key_arn)
-
+        if self.provider in [ModelProvider.BRCONNECTOR_BEDROCK, ModelProvider.OPENAI] and \
+                self.api_key_arn is not None:
+            self.api_key = get_secret_value(self.api_key_arn)

 class QueryRewriteConfig(LLMConfig):
     rewrite_first_message: bool = False
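
The resolution pattern here — keep the secret's ARN on the model and hydrate the key once validation finishes via pydantic's model_post_init hook — in a self-contained sketch; the class name, provider strings, and secret stub are illustrative, not the project's actual code:

from typing import Any, Union
from pydantic import BaseModel

def get_secret_value(arn: str) -> str:
    # Illustrative stub; the project resolves the ARN via Secrets Manager.
    return "sk-placeholder"

class LLMConfigSketch(BaseModel):
    provider: str = "Bedrock API"
    api_key_arn: Union[str, None] = None
    api_key: Union[str, None] = None

    def model_post_init(self, __context: Any) -> None:
        # Hydrate api_key from its ARN only for API-inference providers.
        if self.provider in ("Bedrock API", "OpenAI API") and self.api_key_arn is not None:
            self.api_key = get_secret_value(self.api_key_arn)

cfg = LLMConfigSketch(api_key_arn="arn:aws:secretsmanager:us-west-2:123456789012:secret:key")
assert cfg.api_key is not None  # resolved during post-init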
4 changes: 2 additions & 2 deletions source/portal/src/pages/chatbot/ChatBot.tsx

@@ -571,9 +571,9 @@ const ChatBot: React.FC<ChatBotProps> = (props: ChatBotProps) => {
       default_llm_config: {
         model_id: modelOption,
         endpoint_name: modelOption === 'qwen2-72B-instruct' ? endPoint.trim() : '',
         provider: modelType.value,
-        br_api_key_arn: modelType.value === 'Bedrock API' ? apiKeyArn.trim() : '',
-        openai_api_key_arn: modelType.value === 'OpenAI API' ? apiKeyArn.trim() : '',
+        base_url: (modelType.value === 'Bedrock API' || 'OpenAI API') ? apiEndpoint.trim() : '',
+        api_key_arn: (modelType.value === 'Bedrock API' || 'OpenAI API') ? apiKeyArn.trim() : '',
         model_kwargs: {
           temperature: parseFloat(temperature),
           max_tokens: parseInt(maxToken),
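
One pitfall in the added lines: in JavaScript, `(modelType.value === 'Bedrock API' || 'OpenAI API')` is always truthy, because the right-hand operand is the non-empty string literal 'OpenAI API' rather than a comparison. As written, base_url and api_key_arn are therefore populated for every provider; the intended guard is presumably `modelType.value === 'Bedrock API' || modelType.value === 'OpenAI API'`.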
