diff --git a/docs/draw/online_function_modules.drawio b/docs/draw/online_function_modules.drawio
new file mode 100644
index 00000000..2715f2cf
--- /dev/null
+++ b/docs/draw/online_function_modules.drawio
@@ -0,0 +1,61 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/source/infrastructure/bin/main.ts b/source/infrastructure/bin/main.ts
index 88b67c85..e9d2d1bd 100644
--- a/source/infrastructure/bin/main.ts
+++ b/source/infrastructure/bin/main.ts
@@ -84,6 +84,7 @@ export class RootStack extends Stack {
_embeddingEndPoints:_LLMStack._embeddingEndPoints || '',
_instructEndPoint:_LLMStack._instructEndPoint || '',
_chatSessionTable: _DynamoDBStack._chatSessionTable,
+ _workspaceTable: _EtlStack._workspaceTableName,
_sfnOutput: _EtlStack._sfnOutput,
_OpenSearchIndex: _CdkParameters._OpenSearchIndex.valueAsString,
_OpenSearchIndexDict: _CdkParameters._OpenSearchIndexDict.valueAsString,
@@ -113,7 +114,7 @@ export class RootStack extends Stack {
new CfnOutput(this, 'Cross Model Endpoint', {value:_LLMStack._rerankEndPoint || 'No Cross Endpoint Created'});
new CfnOutput(this, 'Embedding Model Endpoint', {value:_LLMStack._embeddingEndPoints[0] || 'No Embedding Endpoint Created'});
new CfnOutput(this, 'Instruct Model Endpoint', {value:_LLMStack._instructEndPoint || 'No Instruct Endpoint Created'});
- new CfnOutput(this, 'Processed Object Table', {value:_EtlStack._processedObjectsTable});
+ new CfnOutput(this, 'Processed Object Table', {value:_EtlStack._processedObjectsTableName});
new CfnOutput(this, 'Chunk Bucket', {value:_EtlStack._resBucketName});
new CfnOutput(this, '_aosIndexDict', {value:_CdkParameters._OpenSearchIndexDict.valueAsString});
}
diff --git a/source/infrastructure/lib/api/api-stack.ts b/source/infrastructure/lib/api/api-stack.ts
index b4e82c0f..98b8fa60 100644
--- a/source/infrastructure/lib/api/api-stack.ts
+++ b/source/infrastructure/lib/api/api-stack.ts
@@ -26,6 +26,7 @@ interface apiStackProps extends StackProps {
_embeddingEndPoints: string[];
_instructEndPoint: string;
_chatSessionTable: string;
+ _workspaceTable: string;
// type of StepFunctions
_sfnOutput: sfn.StateMachine;
_OpenSearchIndex: string;
@@ -51,6 +52,7 @@ export class LLMApiStack extends NestedStack {
const _aosIndex = props._OpenSearchIndex
const _aosIndexDict = props._OpenSearchIndexDict
const _chatSessionTable = props._chatSessionTable
+ const _workspaceTable = props._workspaceTable
const _jobQueueArn = props._jobQueueArn
const _jobDefinitionArn = props._jobDefinitionArn
const _etlEndpoint = props._etlEndpoint
@@ -291,7 +293,7 @@ export class LLMApiStack extends NestedStack {
const lambdaDdbIntegration = new apigw.LambdaIntegration(lambdaDdb, { proxy: true, });
// All AOS wrapper should be within such lambda
- const apiResourceDdb = api.root.addResource('ddb');
+ const apiResourceDdb = api.root.addResource('feedback');
apiResourceDdb.addMethod('POST', lambdaDdbIntegration);
const apiResourceStepFunction = api.root.addResource('etl');
@@ -323,7 +325,8 @@ export class LLMApiStack extends NestedStack {
architecture: Architecture.X86_64,
environment: {
aos_endpoint: _domainEndpoint,
- llm_endpoint: props._instructEndPoint,
+ llm_model_endpoint_name: props._instructEndPoint,
+ llm_model_id: "internlm2-chat-7b",
embedding_endpoint: props._embeddingEndPoints[0],
zh_embedding_endpoint: props._embeddingEndPoints[0],
en_embedding_endpoint: props._embeddingEndPoints[1],
@@ -331,6 +334,7 @@ export class LLMApiStack extends NestedStack {
aos_index: _aosIndex,
aos_index_dict: _aosIndexDict,
chat_session_table: _chatSessionTable,
+ workspace_table: _workspaceTable,
},
layers: [_ApiLambdaExecutorLayer]
});
diff --git a/source/infrastructure/lib/etl/etl-stack.ts b/source/infrastructure/lib/etl/etl-stack.ts
index 7dac3703..d6ec95fb 100644
--- a/source/infrastructure/lib/etl/etl-stack.ts
+++ b/source/infrastructure/lib/etl/etl-stack.ts
@@ -39,7 +39,8 @@ export class EtlStack extends NestedStack {
_sfnOutput;
_jobName;
_jobArn;
- _processedObjectsTable;
+ _processedObjectsTableName;
+ _workspaceTableName;
_etlEndpoint: string;
_resBucketName: string;
@@ -121,6 +122,33 @@ export class EtlStack extends NestedStack {
// No sort key for this index
});
+ const workspaceTable = new dynamodb.Table(this, "WorkspaceTable", {
+ partitionKey: {
+ name: "workspace_id",
+ type: dynamodb.AttributeType.STRING,
+ },
+ sortKey: {
+ name: "object_type",
+ type: dynamodb.AttributeType.STRING,
+ },
+ billingMode: dynamodb.BillingMode.PAY_PER_REQUEST,
+ encryption: dynamodb.TableEncryption.AWS_MANAGED,
+ pointInTimeRecovery: true,
+ removalPolicy: RemovalPolicy.DESTROY,
+ });
+
+ workspaceTable.addGlobalSecondaryIndex({
+ indexName: "by_object_type_idx",
+ partitionKey: {
+ name: "object_type",
+ type: dynamodb.AttributeType.STRING,
+ },
+ sortKey: {
+ name: "created_at",
+ type: dynamodb.AttributeType.STRING,
+ },
+ });
+
const _S3Bucket = new s3.Bucket(this, 'llm-bot-glue-res-bucket', {
// bucketName: `llm-bot-glue-lib-${Aws.ACCOUNT_ID}-${Aws.REGION}`,
blockPublicAccess: s3.BlockPublicAccess.BLOCK_ALL,
@@ -181,19 +209,20 @@ export class EtlStack extends NestedStack {
'--QA_ENHANCEMENT.$': sfn.JsonPath.stringAt('$.qaEnhance'),
'--AOS_ENDPOINT': props._domainEndpoint,
'--REGION': props._region,
- '--EMBEDDING_MODEL_ENDPOINT': props._embeddingEndpoint.join(','),
+ '--EMBEDDING_MODEL_ENDPOINT': props._embeddingEndpoint[0],
'--ETL_MODEL_ENDPOINT': this._etlEndpoint,
'--DOC_INDEX_TABLE': props._OpenSearchIndex,
'--RES_BUCKET': _S3Bucket.bucketName,
'--ProcessedObjectsTable': table.tableName,
- '--additional-python-modules': 'langchain==0.0.312,beautifulsoup4==4.12.2,requests-aws4auth==1.2.3,boto3==1.28.84,openai==0.28.1,pyOpenSSL==23.3.0,tenacity==8.2.3,markdownify==0.11.6,mammoth==1.6.0,chardet==5.2.0,python-docx==1.1.0,nltk==3.8.1,pdfminer.six==20221105',
+ '--WORKSPACE_TABLE': workspaceTable.tableName,
+ '--WORKSPACE_ID.$': sfn.JsonPath.stringAt('$.workspaceId'),
+ '--additional-python-modules': 'langchain==0.1.0,beautifulsoup4==4.12.2,requests-aws4auth==1.2.3,boto3==1.28.84,openai==0.28.1,pyOpenSSL==23.3.0,tenacity==8.2.3,markdownify==0.11.6,mammoth==1.6.0,chardet==5.2.0,python-docx==1.1.0,nltk==3.8.1,pdfminer.six==20221105',
'--python-modules-installer-option': BuildConfig.JOB_PIP_OPTION,
// add multiple extra python files
'--extra-py-files': extraPythonFilesList,
'--CONTENT_TYPE': 'ug',
'--EMBEDDING_LANG': 'zh,zh,en,en',
'--EMBEDDING_TYPE': 'similarity,relevance,similarity,relevance',
- '--AOS_INDEX.$': sfn.JsonPath.stringAt('$.aosIndex'),
}
});
@@ -238,7 +267,7 @@ export class EtlStack extends NestedStack {
's3Prefix.$': '$.Payload.s3Prefix',
'qaEnhance.$': '$.Payload.qaEnhance',
'offline.$': '$.Payload.offline',
- 'aosIndex.$': '$.Payload.aosIndex',
+ 'workspaceId.$': '$.Payload.workspaceId',
}
},
// we need the original input
@@ -259,16 +288,16 @@ export class EtlStack extends NestedStack {
'--S3_BUCKET.$': '$.s3Bucket',
'--S3_PREFIX.$': '$.s3Prefix',
'--AOS_ENDPOINT': props._domainEndpoint,
- '--EMBEDDING_MODEL_ENDPOINT': props._embeddingEndpoint.join(','),
+ '--EMBEDDING_MODEL_ENDPOINT': props._embeddingEndpoint[0],
'--ETL_MODEL_ENDPOINT': this._etlEndpoint,
'--DOC_INDEX_TABLE': props._OpenSearchIndex,
'--REGION': props._region,
'--RES_BUCKET': _S3Bucket.bucketName,
'--OFFLINE': 'true',
'--QA_ENHANCEMENT.$': '$.qaEnhance',
+ '--WORKSPACE_ID.$': '$.workspaceId',
// Convert the numeric index to a string
'--BATCH_INDICE.$': 'States.Format(\'{}\', $.batchIndices)',
- '--AOS_INDEX.$': '$.aosIndex',
'--ProcessedObjectsTable': table.tableName,
'--CONTENT_TYPE': 'ug',
'--EMBEDDING_LANG': 'zh,zh,en,en',
@@ -289,7 +318,7 @@ export class EtlStack extends NestedStack {
's3Bucket.$': '$.s3Bucket',
's3Prefix.$': '$.s3Prefix',
'qaEnhance.$': '$.qaEnhance',
- 'aosIndex.$': '$.aosIndex',
+ 'workspaceId.$': '$.workspaceId',
// 'index' is a special variable within the Map state that represents the current index
'batchIndices.$': '$$.Map.Item.Index' // Add this if you need to know the index of the current item in the map state
},
@@ -308,17 +337,17 @@ export class EtlStack extends NestedStack {
'--S3_BUCKET.$': '$.s3Bucket',
'--S3_PREFIX.$': '$.s3Prefix',
'--AOS_ENDPOINT': props._domainEndpoint,
- '--EMBEDDING_MODEL_ENDPOINT': props._embeddingEndpoint.join(','),
+ '--EMBEDDING_MODEL_ENDPOINT': props._embeddingEndpoint[0],
'--ETL_MODEL_ENDPOINT': this._etlEndpoint,
'--DOC_INDEX_TABLE': props._OpenSearchIndex,
'--REGION': props._region,
'--RES_BUCKET': _S3Bucket.bucketName,
'--OFFLINE': 'false',
'--QA_ENHANCEMENT.$': '$.qaEnhance',
+ '--WORKSPACE_ID.$': '$.workspaceId',
// set the batch indice to 0 since we are running online
'--BATCH_INDICE': '0',
'--ProcessedObjectsTable': table.tableName,
- '--AOS_INDEX.$': '$.aosIndex',
}),
});
@@ -348,7 +377,8 @@ export class EtlStack extends NestedStack {
this._sfnOutput = sfnStateMachine;
this._jobName = glueJob.jobName;
this._jobArn = glueJob.jobArn;
- this._processedObjectsTable = table.tableName
+ this._processedObjectsTableName = table.tableName;
+ this._workspaceTableName = workspaceTable.tableName;
this._resBucketName = _S3Bucket.bucketName
}
}
\ No newline at end of file
diff --git a/source/infrastructure/lib/model/llm-stack.ts b/source/infrastructure/lib/model/llm-stack.ts
index e2b44e97..b2fc73e8 100644
--- a/source/infrastructure/lib/model/llm-stack.ts
+++ b/source/infrastructure/lib/model/llm-stack.ts
@@ -200,7 +200,7 @@ export class LLMStack extends NestedStack {
tags: instruct_tag_array,
});
- this._instructEndPoint = InstructEndpoint.endpointName as string;
+ this._instructEndPoint = InstructEndpointName;
}
diff --git a/source/lambda/custom/index.js b/source/lambda/custom/index.js
deleted file mode 100644
index 7bfa61b7..00000000
--- a/source/lambda/custom/index.js
+++ /dev/null
@@ -1,31 +0,0 @@
-const AWS = require('aws-sdk');
-const fs = require('fs');
-const tar = require('tar');
-
-// obsolete for now, use script to upload model.tar.gz to s3 instead
-exports.handler = async (event) => {
- const s3 = new AWS.S3();
- const bucketName = process.env.BUCKET_NAME;
- const key = 'model.tar.gz';
-
- // Create files A and B
- fs.writeFileSync('/tmp/fileA.txt', 'Content of file A');
- fs.writeFileSync('/tmp/fileB.txt', 'Content of file B');
-
- // Package the files into model.tar.gz
- await tar.c({
- gzip: true,
- file: '/tmp/model.tar.gz',
- cwd: '/tmp',
- }, ['fileA.txt', 'fileB.txt']);
-
- // Upload model.tar.gz to the S3 bucket
- const fileStream = fs.createReadStream('/tmp/model.tar.gz');
- await s3.upload({
- Bucket: bucketName,
- Key: key,
- Body: fileStream,
- }).promise();
-
- console.log(`Uploaded model.tar.gz to s3://${bucketName}/${key}`);
-};
diff --git a/source/lambda/ddb/rating.py b/source/lambda/ddb/rating.py
index 3a2f78ea..070a93ca 100644
--- a/source/lambda/ddb/rating.py
+++ b/source/lambda/ddb/rating.py
@@ -74,6 +74,34 @@ def get_session(table, session_id, user_id):
return response.get("Item", {})
+def get_feedback(table, session_id, user_id, message_id):
+ session = get_session(table, session_id, user_id)
+ messages = session.get("History", [])
+ feedback = None
+
+ if not messages:
+ return {
+ "added": False,
+ "error": "Failed to add feedback. No messages found in session.",
+ }
+ elif not message_id:
+ return {
+ "added": False,
+ "error": "Failed to add feedback. Please specify the message_id in the request to add feedback.",
+ }
+
+ for message in messages:
+
+ ddb_message_id = (
+ message.get("data", {}).get("additional_kwargs", {}).get("message_id", "")
+ )
+ if message_id == ddb_message_id:
+ feedback = message["data"]["additional_kwargs"].get("feedback", {})
+ return feedback
+
+ return feedback
+
+
# SESSIONS_BY_USER_ID_INDEX_NAME = "byUserId"
def list_sessions_by_user_id(table, user_id, SESSIONS_BY_USER_ID_INDEX_NAME):
response = {}
@@ -196,13 +224,14 @@ def add_feedback(table, session_id, user_id, message_id, feedback) -> None:
break
try:
- table.put_item(
- Item={
+ table.update_item(
+ Key={
"SessionId": session_id,
"UserId": user_id,
- "StartTime": start_time,
- "History": messages,
- }
+ },
+ UpdateExpression="SET History = :msg",
+ ExpressionAttributeValues={":msg": messages},
+ ReturnValues="ALL_NEW",
)
response = {"added": True}
@@ -270,6 +299,9 @@ def lambda_handler(event, context):
operations_mapping = {
"POST": {
"get_session": lambda: get_session(session_table, session_id, user_id),
+ "get_feedback": lambda: get_feedback(
+ session_table, session_id, user_id, message_id
+ ),
"list_sessions_by_user_id": lambda: list_sessions_by_user_id(
session_table, user_id, SESSIONS_BY_USER_ID_INDEX_NAME
),
diff --git a/source/lambda/etl/main.py b/source/lambda/etl/main.py
index 25ad3375..cbf40938 100644
--- a/source/lambda/etl/main.py
+++ b/source/lambda/etl/main.py
@@ -1,46 +1,49 @@
import json
-import boto3
import logging
+import boto3
+
logger = logging.getLogger()
logger.setLevel(logging.INFO)
-s3_client = boto3.client('s3')
+s3_client = boto3.client("s3")
+
# Offline lambda function to count the number of files in the S3 bucket
def lambda_handler(event, context):
logger.info(f"event:{event}")
# Retrieve bucket name and prefix from the event object passed by Step Function
- bucket_name = event['s3Bucket']
- prefix = event['s3Prefix']
+ bucket_name = event["s3Bucket"]
+ prefix = event["s3Prefix"]
# fetch index from event with default value none
- aos_index = event['aosIndex'] if 'aosIndex' in event else None
-
+ workspace_id = event["workspaceId"]
+
# Initialize the file count
file_count = 0
-
+
# Paginate through the list of objects in the bucket with the specified prefix
- paginator = s3_client.get_paginator('list_objects_v2')
+ paginator = s3_client.get_paginator("list_objects_v2")
page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=prefix)
-
+
# Count the files, note skip the prefix with slash, which is the folder name
for page in page_iterator:
- for obj in page.get('Contents', []):
- if obj['Key'].endswith('/'):
+ for obj in page.get("Contents", []):
+ if obj["Key"].endswith("/"):
continue
file_count += 1
-
+
# convert the fileCount into an array of numbers "fileIndices": [0, 1, 2, ..., 10], an array from 0 to fileCount-1
- batch_indices = list(range(file_count))
+ indice_count = file_count // 100 + 1
+ batch_indices = list(range(indice_count))
# This response should match the expected input schema of the downstream tasks in the Step Functions workflow
return {
- 'fileCount': file_count,
- 's3Bucket': bucket_name,
- 's3Prefix': prefix,
- 'qaEnhance': event['qaEnhance'].lower() if 'qaEnhance' in event else 'false',
+ "fileCount": file_count,
+ "s3Bucket": bucket_name,
+ "s3Prefix": prefix,
+ "qaEnhance": event["qaEnhance"].lower() if "qaEnhance" in event else "false",
# boolean value to indicate if the lambda function is running in offline mode
- 'offline': event['offline'].lower(),
- 'batchIndices': batch_indices,
- 'aosIndex': aos_index
+ "offline": event["offline"].lower(),
+ "batchIndices": batch_indices,
+ "workspaceId": workspace_id,
}
diff --git a/source/lambda/etl/sfn_handler.py b/source/lambda/etl/sfn_handler.py
index 4820f9f1..e9012bd7 100644
--- a/source/lambda/etl/sfn_handler.py
+++ b/source/lambda/etl/sfn_handler.py
@@ -21,6 +21,7 @@ def handler(event, context):
"s3Prefix": key,
"offline": "false",
"qaEnhance": "false",
+ "worksapceId": "default-workspace-id",
}
)
else:
diff --git a/source/lambda/executor/main.py b/source/lambda/executor/main.py
index 60bca652..d542765e 100644
--- a/source/lambda/executor/main.py
+++ b/source/lambda/executor/main.py
@@ -1,76 +1,33 @@
-import copy
import json
-import logging
import os
-os.environ["PYTHONUNBUFFERED"]="1"
-import sys
+
+os.environ["PYTHONUNBUFFERED"] = "1"
import time
-import traceback
import uuid
-import asyncio
-import math
-
import boto3
-from langchain.callbacks.base import BaseCallbackHandler
-from langchain.globals import set_verbose
-# from langchain.llms import OpenAI
-from langchain.output_parsers import PydanticOutputParser
-from langchain.prompts import ChatPromptTemplate
-from langchain.pydantic_v1 import BaseModel, Field, validator
-from langchain.retrievers import ContextualCompressionRetriever
-from langchain.retrievers.document_compressors import CohereRerank
-
-# from llm_utils import CustomLLM
-from langchain.retrievers.merger_retriever import MergerRetriever
-from langchain.retrievers.web_research import WebResearchRetriever
-from langchain.schema.runnable import (
- RunnableBranch,
- RunnableLambda,
- RunnableParallel,
- RunnablePassthrough,
+import logging
+# from langchain.retrievers.multi_query import MultiQueryRetriever
+from utils.logger_utils import logger
+from utils.constant import Type
+from utils.ddb_utils import DynamoDBChatMessageHistory
+from utils.executor_entries import (
+ get_retriever_response,
+ main_chain_entry,
+ main_qd_retriever_entry,
+ main_qq_retriever_entry,
+ market_chain_entry,
+ market_chain_entry_core,
+ market_conversation_summary_entry,
+ market_chain_knowledge_entry,
+ market_chain_knowledge_entry_langgraph
)
-from langchain.schema.messages import (
- HumanMessage,AIMessage,SystemMessage
-)
-# from langchain.memory import ConversationSummaryMemory, ChatMessageHistory
-# from langchain.utilities import GoogleSearchAPIWrapper
-from dateutil import parser
-from utils.reranker import BGEReranker, MergeReranker
-from utils.retriever import (
- QueryDocumentRetriever,
- QueryQuestionRetriever,
- index_results_format,
-)
-from langchain.retrievers.multi_query import MultiQueryRetriever
-from utils.logger_utils import logger,opensearch_logger,boto3_logger
-
-
-from utils.aos_utils import LLMBotOpenSearchClient
-from utils.constant import IntentType, Type
-from utils.ddb_utils import DynamoDBChatMessageHistory,filter_chat_history_by_time
-from utils.intent_utils import auto_intention_recoginition_chain
-# from langchain_utils import create_identity_lambda
-
-# from llm_utils import generate as llm_generate
-from utils.llm_utils import get_llm_chain,get_llm_model
-from utils.llmbot_utils import (
- # QueryType,
- combine_recalls,
- # concat_recall_knowledge,
- # process_input_messages,
-)
-from utils.time_utils import timeit
-from utils.preprocess_utils import run_preprocess
+# from langchain.retrievers.multi_query import MultiQueryRetriever
+from utils.logger_utils import logger
from utils.response_utils import process_response
-from utils.sm_utils import SagemakerEndpointVectorOrCross
-from utils.constant import Type,IntentType
-from utils.intent_utils import auto_intention_recoginition_chain
-from utils.langchain_utils import add_key_to_debug,chain_logger
-from utils.query_process_utils import get_query_process_chain
-import utils.parse_config as parse_config
from utils.serialization_utils import JSONEncoder
-from utils.constant import MKT_CONVERSATION_SUMMARY_TYPE
+
+# from utils.constant import MKT_CONVERSATION_SUMMARY_TYPE
region = os.environ["AWS_REGION"]
embedding_endpoint = os.environ.get("embedding_endpoint", "")
@@ -85,11 +42,10 @@
llm_endpoint = os.environ.get("llm_endpoint", "")
chat_session_table = os.environ.get("chat_session_table", "")
websocket_url = os.environ.get("websocket_url", "")
-sm_client = boto3.client("sagemaker-runtime")
-aos_client = LLMBotOpenSearchClient(aos_endpoint)
+# sm_client = boto3.client("sagemaker-runtime")
+# aos_client = LLMBotOpenSearchClient(aos_endpoint)
ws_client = None
-# get aos_index_dict
class APIException(Exception):
def __init__(self, message, code: str = None):
@@ -108,6 +64,7 @@ def load_ws_client():
def handle_error(func):
"""Decorator for exception handling"""
+
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
@@ -123,962 +80,6 @@ def wrapper(*args, **kwargs):
return wrapper
-def get_faq_answer(source, index_name):
- opensearch_query_response = aos_client.search(
- index_name=index_name,
- query_type="basic",
- query_term=source,
- field="metadata.source",
- )
- for r in opensearch_query_response["hits"]["hits"]:
- if r["_source"]["metadata"]["field"] == "answer":
- return r["_source"]["content"]
- return ""
-
-
-def get_faq_content(source, index_name):
- opensearch_query_response = aos_client.search(
- index_name=index_name,
- query_type="basic",
- query_term=source,
- field="metadata.source",
- )
- for r in opensearch_query_response["hits"]["hits"]:
- if r["_source"]["metadata"]["field"] == "all_text":
- return r["_source"]["content"]
- return ""
-
-
-def organize_faq_results(response, index_name):
- """
- Organize results from aos response
-
- :param query_type: query type
- :param response: aos response json
- """
- results = []
- if not response:
- return results
- aos_hits = response["hits"]["hits"]
- for aos_hit in aos_hits:
- result = {}
- try:
- result["source"] = aos_hit["_source"]["metadata"]["source"]
- result["score"] = aos_hit["_score"]
- result["detail"] = aos_hit["_source"]
- result["content"] = aos_hit["_source"]["content"]
- result["answer"] = get_faq_answer(result["source"], index_name)
- result["doc"] = get_faq_content(result["source"], index_name)
- except:
- logger.info("index_error")
- logger.info(aos_hit["_source"])
- continue
- # result.update(aos_hit["_source"])
- results.append(result)
- return results
-
-
-def get_ug_content(source, index_name):
- opensearch_query_response = aos_client.search(
- index_name=index_name,
- query_type="basic",
- query_term=source,
- field="metadata.source",
- size=100,
- )
- for r in opensearch_query_response["hits"]["hits"]:
- if r["_source"]["metadata"]["field"] == "all_text":
- return r["_source"]["content"]
- return ""
-
-
-def organize_ug_results(response, index_name):
- """
- Organize results from aos response
-
- :param query_type: query type
- :param response: aos response json
- """
- results = []
- aos_hits = response["hits"]["hits"]
- for aos_hit in aos_hits:
- result = {}
- result["source"] = aos_hit["_source"]["metadata"]["source"]
- result["score"] = aos_hit["_score"]
- result["detail"] = aos_hit["_source"]
- result["content"] = aos_hit["_source"]["content"]
- result["doc"] = get_ug_content(result["source"], index_name)
- # result.update(aos_hit["_source"])
- results.append(result)
- return results
-
-def remove_redundancy_debug_info(results):
- filtered_results = copy.deepcopy(results)
- for result in filtered_results:
- for field in list(result["detail"].keys()):
- if field.endswith("embedding") or field.startswith("vector"):
- del result["detail"][field]
- return filtered_results
-
-def parse_query(
- query_input: str,
- history: list,
- zh_embedding_model_endpoint: str,
- en_embedding_model_endpoint: str,
- debug_info: dict,
-):
- start = time.time()
- # concatenate query_input and history to unified prompt
- query_knowledge = "".join([query_input] + [row[0] for row in history][::-1])
-
- # get query embedding
- parsed_query = run_preprocess(query_knowledge)
- debug_info["query_parser_info"] = parsed_query
- if parsed_query["query_lang"] == "zh":
- parsed_query["zh_query"] = query_knowledge
- parsed_query["en_query"] = parsed_query["translated_text"]
- elif parsed_query["query_lang"] == "en":
- parsed_query["zh_query"] = parsed_query["translated_text"]
- parsed_query["en_query"] = query_knowledge
- zh_query_similarity_embedding_prompt = parsed_query["zh_query"]
- en_query_similarity_embedding_prompt = parsed_query["en_query"]
- zh_query_relevance_embedding_prompt = (
- "为这个句子生成表示以用于检索相关文章:" + parsed_query["zh_query"]
- )
- en_query_relevance_embedding_prompt = (
- "Represent this sentence for searching relevant passages: "
- + parsed_query["en_query"]
- )
- parsed_query["zh_query_similarity_embedding"] = SagemakerEndpointVectorOrCross(
- prompt=zh_query_similarity_embedding_prompt,
- endpoint_name=zh_embedding_model_endpoint,
- region_name=region,
- model_type="vector",
- stop=None,
- )
- parsed_query["zh_query_relevance_embedding"] = SagemakerEndpointVectorOrCross(
- prompt=zh_query_relevance_embedding_prompt,
- endpoint_name=zh_embedding_model_endpoint,
- region_name=region,
- model_type="vector",
- stop=None,
- )
- parsed_query["en_query_similarity_embedding"] = SagemakerEndpointVectorOrCross(
- prompt=en_query_similarity_embedding_prompt,
- endpoint_name=en_embedding_model_endpoint,
- region_name=region,
- model_type="vector",
- stop=None,
- )
- parsed_query["en_query_relevance_embedding"] = SagemakerEndpointVectorOrCross(
- prompt=en_query_relevance_embedding_prompt,
- endpoint_name=en_embedding_model_endpoint,
- region_name=region,
- model_type="vector",
- stop=None,
- )
- elpase_time = time.time() - start
- logger.info(f"runing time of parse query: {elpase_time}s seconds")
- return parsed_query
-
-
-def q_q_match(parsed_query, debug_info):
- start = time.time()
- opensearch_knn_results = []
- opensearch_knn_response = aos_client.search(
- index_name=aos_faq_index,
- query_type="knn",
- query_term=parsed_query["zh_query_similarity_embedding"],
- field="embedding",
- size=2,
- )
- opensearch_knn_results.extend(
- organize_faq_results(opensearch_knn_response, aos_faq_index)
- )
- opensearch_knn_response = aos_client.search(
- index_name=aos_faq_index,
- query_type="knn",
- query_term=parsed_query["en_query_similarity_embedding"],
- field="embedding",
- size=2,
- )
- opensearch_knn_results.extend(
- organize_faq_results(opensearch_knn_response, aos_faq_index)
- )
- # logger.info(json.dumps(opensearch_knn_response, ensure_ascii=False))
- elpase_time = time.time() - start
- logger.info(f"runing time of opensearch_knn : {elpase_time}s seconds")
- answer = None
- sources = None
- if len(opensearch_knn_results) > 0:
- debug_info["q_q_match_info"] = remove_redundancy_debug_info(
- opensearch_knn_results[:3]
- )
- if opensearch_knn_results[0]["score"] >= 0.9:
- source = opensearch_knn_results[0]["source"]
- answer = opensearch_knn_results[0]["answer"]
- sources = [source]
- return answer, sources
- return answer, sources
-
-
-def get_relevant_documents_dgr(
- parsed_query,
- rerank_model_endpoint: str,
- aos_faq_index: str,
- aos_ug_index: str,
- debug_info,
-):
- # 1. get AOS knn recall
- faq_result_num = 2
- ug_result_num = 20
- start = time.time()
- opensearch_knn_results = []
- opensearch_knn_response = aos_client.search(
- index_name=aos_faq_index,
- query_type="knn",
- query_term=parsed_query["zh_query_relevance_embedding"],
- field="embedding",
- size=faq_result_num,
- )
- opensearch_knn_results.extend(
- organize_faq_results(opensearch_knn_response, aos_faq_index)[:faq_result_num]
- )
- opensearch_knn_response = aos_client.search(
- index_name=aos_faq_index,
- query_type="knn",
- query_term=parsed_query["en_query_relevance_embedding"],
- field="embedding",
- size=faq_result_num,
- )
- opensearch_knn_results.extend(
- organize_faq_results(opensearch_knn_response, aos_faq_index)[:faq_result_num]
- )
- # logger.info(json.dumps(opensearch_knn_response, ensure_ascii=False))
- faq_recall_end_time = time.time()
- elpase_time = faq_recall_end_time - start
- logger.info(f"runing time of faq recall : {elpase_time}s seconds")
- filter = None
- if parsed_query["is_api_query"]:
- filter = [{"term": {"metadata.is_api": True}}]
-
- opensearch_knn_response = aos_client.search(
- index_name=aos_ug_index,
- query_type="knn",
- query_term=parsed_query["zh_query_relevance_embedding"],
- field="embedding",
- filter=filter,
- size=ug_result_num,
- )
- opensearch_knn_results.extend(
- organize_ug_results(opensearch_knn_response, aos_ug_index)[:ug_result_num]
- )
- opensearch_knn_response = aos_client.search(
- index_name=aos_ug_index,
- query_type="knn",
- query_term=parsed_query["en_query_relevance_embedding"],
- field="embedding",
- filter=filter,
- size=ug_result_num,
- )
- opensearch_knn_results.extend(
- organize_ug_results(opensearch_knn_response, aos_ug_index)[:ug_result_num]
- )
-
- debug_info["knowledge_qa_knn_recall"] = remove_redundancy_debug_info(
- opensearch_knn_results
- )
- ug_recall_end_time = time.time()
- elpase_time = ug_recall_end_time - faq_recall_end_time
- logger.info(f"runing time of ug recall: {elpase_time}s seconds")
-
- # 2. get AOS invertedIndex recall
- opensearch_query_results = []
-
- # 3. combine these two opensearch_knn_response and opensearch_query_response
- recall_knowledge = combine_recalls(opensearch_knn_results, opensearch_query_results)
-
- rerank_pair = []
- for knowledge in recall_knowledge:
- # rerank_pair.append([parsed_query["query"], knowledge["content"]][:1024])
- rerank_pair.append(
- [parsed_query["en_query"], knowledge["content"]][: 1024 * 10]
- )
- en_score_list = json.loads(
- SagemakerEndpointVectorOrCross(
- prompt=json.dumps(rerank_pair),
- endpoint_name=rerank_model_endpoint,
- region_name=region,
- model_type="rerank",
- stop=None,
- )
- )
- rerank_pair = []
- for knowledge in recall_knowledge:
- # rerank_pair.append([parsed_query["query"], knowledge["content"]][:1024])
- rerank_pair.append(
- [parsed_query["zh_query"], knowledge["content"]][: 1024 * 10]
- )
- zh_score_list = json.loads(
- SagemakerEndpointVectorOrCross(
- prompt=json.dumps(rerank_pair),
- endpoint_name=rerank_model_endpoint,
- region_name=region,
- model_type="rerank",
- stop=None,
- )
- )
- rerank_knowledge = []
- for knowledge, score in zip(recall_knowledge, zh_score_list):
- # if score > 0:
- new_knowledge = knowledge.copy()
- new_knowledge["rerank_score"] = score
- rerank_knowledge.append(new_knowledge)
- for knowledge, score in zip(recall_knowledge, en_score_list):
- # if score > 0:
- new_knowledge = knowledge.copy()
- new_knowledge["rerank_score"] = score
- rerank_knowledge.append(new_knowledge)
- rerank_knowledge.sort(key=lambda x: x["rerank_score"], reverse=True)
- debug_info["knowledge_qa_rerank"] = rerank_knowledge
-
- rerank_end_time = time.time()
- elpase_time = rerank_end_time - ug_recall_end_time
- logger.info(f"runing time of rerank: {elpase_time}s seconds")
-
- return rerank_knowledge
-
-
-def dgr_entry(
- session_id: str,
- query_input: str,
- history: list,
- zh_embedding_model_endpoint: str,
- en_embedding_model_endpoint: str,
- cross_model_endpoint: str,
- rerank_model_endpoint: str,
- llm_model_endpoint: str,
- aos_faq_index: str,
- aos_ug_index: str,
- enable_knowledge_qa: bool,
- temperature: float,
- enable_q_q_match: bool,
- llm_model_id=None,
- stream=False,
-):
- """
- Entry point for the Lambda function.
-
- :param session_id: The ID of the session.
- :param query_input: The query input.
- :param history: The history of the conversation.
- :param embedding_model_endpoint: The endpoint of the embedding model.
- :param cross_model_endpoint: The endpoint of the cross model.
- :param llm_model_endpoint: The endpoint of the language model.
- :param llm_model_name: The name of the language model.
- :param aos_faq_index: The faq index of the AOS engine.
- :param aos_ug_index: The ug index of the AOS engine.
- :param enable_knowledge_qa: Whether to enable knowledge QA.
- :param temperature: The temperature of the language model.
- :param stream(Bool): Whether to use llm stream decoding output.
-
- return: answer(str)
- """
- debug_info = {
- "query": query_input,
- "query_parser_info": {},
- "q_q_match_info": {},
- "knowledge_qa_knn_recall": {},
- "knowledge_qa_boolean_recall": {},
- "knowledge_qa_combined_recall": {},
- "knowledge_qa_cross_model_sort": {},
- "knowledge_qa_llm": {},
- "knowledge_qa_rerank": {},
- }
- contexts = []
- sources = []
- answer = ""
- try:
- # 1. parse query
- parsed_query = parse_query(
- query_input,
- history,
- zh_embedding_model_endpoint,
- en_embedding_model_endpoint,
- debug_info,
- )
- # 2. query question match
- if enable_q_q_match:
- answer, sources = q_q_match(parsed_query, debug_info)
- if answer and sources:
- return answer, sources, contexts, debug_info
- # 3. recall and rerank
- knowledges = get_relevant_documents_dgr(
- parsed_query,
- rerank_model_endpoint,
- aos_faq_index,
- aos_ug_index,
- debug_info,
- )
- context_num = 6
- sources = list(set([item["source"] for item in knowledges[:context_num]]))
- contexts = knowledges[:context_num]
- # 4. generate answer using question and recall_knowledge
- parameters = {"temperature": temperature}
- generate_input = dict(
- model_id=llm_model_id,
- query=query_input,
- contexts=knowledges[:context_num],
- history=history,
- region_name=region,
- model_kwargs=parameters,
- context_num=context_num,
- model_type="answer",
- llm_model_endpoint=llm_model_endpoint,
- stream=stream,
- )
-
- llm_start_time = time.time()
- llm_chain = get_rag_llm_chain(**generate_input)
- llm_chain.invoke()
-
- answer = llm_generate(**generate_input)
- llm_end_time = time.time()
- elpase_time = llm_end_time - llm_start_time
- logger.info(f"runing time of llm: {elpase_time}s seconds")
- # answer = ret["answer"]
- debug_info["knowledge_qa_llm"] = answer
- except Exception as e:
- logger.info(f"Exception Query: {query_input}")
- logger.info(f"{traceback.format_exc()}")
- answer = ""
-
- # 5. update_session
- # start = time.time()
- # update_session(session_id=session_id, chat_session_table=chat_session_table,
- # question=query_input, answer=answer, knowledge_sources=sources)
- # elpase_time = time.time() - start
- # logger.info(f'runing time of update_session : {elpase_time}s seconds')
-
- return answer, sources, contexts, debug_info
-
-
-def get_strict_qq_chain(strict_q_q_index):
- def get_strict_qq_result(docs, threshold=0.7):
- results = []
- for doc in docs:
- if doc.metadata["score"] < threshold:
- break
- results.append(
- {
- "score": doc.metadata["score"],
- "source": doc.metadata["source"],
- "answer": doc.metadata["answer"],
- "question": doc.metadata["question"],
- }
- )
- return results
-
- mkt_q_q_retriever = QueryQuestionRetriever(
- index=strict_q_q_index,
- vector_field="vector_field",
- source_field="file_path",
- size=5,
- )
- strict_q_q_chain = mkt_q_q_retriever | RunnableLambda(get_strict_qq_result)
- return strict_q_q_chain
-
-
-def return_strict_qq_result(x):
- # def get_strict_qq_result(docs, threshold=0.7):
- # results = []
- # for doc in docs:
- # results.append({"score": doc.metadata["score"],
- # "source": doc.metadata["source"],
- # "answer": doc.metadata["answer"],
- # "question": doc.metadata["question"]})
- # output = {"answer": json.dumps(results, ensure_ascii=False), "sources": [], "contexts": []}
- # return output
- # return get_strict_qq_result(x["intent_info"]["strict_qq_intent_result"])
- return {
- "answer": json.dumps(
- x["intent_info"]["strict_qq_intent_result"], ensure_ascii=False
- ),
- "sources": [],
- "contexts": [],
- "context_docs": [],
- "context_sources": [],
- }
-
-
-def get_rag_llm_chain(rag_config, stream):
- def contexts_trunc(docs: list, context_num=2):
- # print('docs len',len(docs))
- docs = [doc for doc in docs[:context_num]]
- # the most related doc will be placed last
- docs.sort(key=lambda x: x.metadata["score"])
- # filter same docs
- s = set()
- context_strs = []
- context_docs = []
- context_sources = []
- for doc in docs:
- content = doc.page_content
- if content not in s:
- context_strs.append(content)
- s.add(content)
- context_docs.append({
- "doc": content,
- "source": doc.metadata["source"],
- "score": doc.metadata["score"]
- })
- context_sources.append(doc.metadata["source"])
- # print(len(context_docs))
- # print(sg)
- return {
- "contexts": context_strs,
- "context_docs": context_docs,
- "context_sources":context_sources
- }
-
- generator_llm_config = rag_config['generator_llm_config']
- # TODO opt with efficiency
- context_num = generator_llm_config['context_num']
- contexts_trunc_stage = RunnablePassthrough.assign(
- contexts=lambda x: contexts_trunc(x["docs"], context_num=context_num)['contexts'],
- context_docs=lambda x: contexts_trunc(x["docs"], context_num=context_num)['context_docs'],
- context_sources=lambda x: contexts_trunc(x["docs"], context_num=context_num)['context_sources'],
- )
- other_llm_config = copy.deepcopy(generator_llm_config)
- other_llm_config.pop('model_id')
- other_llm_config.pop('model_kwargs')
- llm_chain = get_llm_chain(
- model_id=generator_llm_config['model_id'],
- intent_type=IntentType.KNOWLEDGE_QA.value,
- model_kwargs=generator_llm_config['model_kwargs'], # TODO
- stream=stream,
- # chat_history=rag_config['chat_history'],
- **other_llm_config
- )
- llm_chain = contexts_trunc_stage |\
- RunnablePassthrough.assign(chat_history=lambda x:rag_config['chat_history']) |\
- RunnablePassthrough.assign(answer=llm_chain)
- return llm_chain
-
-def get_qd_chain(
- aos_index_list, retriever_top_k=10, reranker_top_k=5, using_whole_doc=True, chunk_num=0, enable_reranker=True
-):
- retriever_list = [
- QueryDocumentRetriever(
- index, "vector_field", "text", "file_path", using_whole_doc, chunk_num, retriever_top_k, "zh", zh_embedding_endpoint
- )
- for index in aos_index_list
- ] + [
- QueryDocumentRetriever(
- index, "vector_field", "text", "file_path", using_whole_doc, chunk_num, retriever_top_k, "en", en_embedding_endpoint
- )
- for index in aos_index_list
- ]
- lotr = MergerRetriever(retrievers=retriever_list)
- if enable_reranker:
- compressor = BGEReranker(top_n=reranker_top_k)
- else:
- compressor = MergeReranker(top_n=reranker_top_k)
- compression_retriever = ContextualCompressionRetriever(
- base_compressor=compressor, base_retriever=lotr
- )
- qd_chain = RunnablePassthrough.assign(docs=compression_retriever)
- return qd_chain
-
-def get_qd_llm_chain(
- aos_index_list,
- rag_config,
- stream=False,
- # top_n=5
-):
- using_whole_doc = rag_config['retriever_config']['using_whole_doc']
- chunk_num = rag_config['retriever_config']['chunk_num']
- retriever_top_k = rag_config['retriever_config']['retriever_top_k']
- reranker_top_k = rag_config['retriever_config']['reranker_top_k']
- enable_reranker = rag_config['retriever_config']['enable_reranker']
-
- llm_chain = get_rag_llm_chain(rag_config, stream)
- qd_chain = get_qd_chain(aos_index_list, using_whole_doc=using_whole_doc,
- chunk_num=chunk_num, retriever_top_k=retriever_top_k,
- reranker_top_k=reranker_top_k, enable_reranker=enable_reranker)
- qd_llm_chain = chain_logger(qd_chain, 'qd_retriever') | chain_logger(llm_chain,'llm_chain')
- return qd_llm_chain
-
-
-def get_chat_llm_chain(
- rag_config,
- stream=False
- ):
- generator_llm_config = rag_config['generator_llm_config']
- other_llm_config = copy.deepcopy(generator_llm_config)
- other_llm_config.pop('model_id')
- other_llm_config.pop('model_kwargs')
-
- chat_llm_chain = get_llm_chain(
- model_id=generator_llm_config['model_id'],
- intent_type=IntentType.CHAT.value,
- model_kwargs=generator_llm_config['model_kwargs'], # TODO
- stream=stream,
- # chat_history=rag_config['chat_history'],
- **other_llm_config
- ) | {
- "answer": lambda x: x,
- "sources": lambda x: [],
- "contexts": lambda x: [],
- "intent_type": lambda x: IntentType.CHAT.value,
- "context_docs": lambda x: [],
- "context_sources": lambda x: [],
- }
- return chat_llm_chain
-
-def market_chain_entry(
- query_input: str,
- stream=False,
- manual_input_intent=None,
- rag_config=None
-):
- """
- Entry point for the Lambda function.
-
- :param query_input: The query input.
- :param aos_index: The index of the AOS engine.
- :param stream(Bool): Whether to use llm stream decoding output.
- return: answer(str)
- """
- assert rag_config is not None
- generator_llm_config = rag_config['generator_llm_config']
- intent_type = rag_config['intent_config']['intent_type']
- aos_index_dict = json.loads(
- os.environ.get(
- "aos_index_dict",
- '{"aos_index_mkt_qd":"aws-cn-mkt-knowledge","aos_index_mkt_qq":"gcr-mkt-qq","aos_index_dgr_qd":"ug-index","aos_index_dgr_qq":"faq-index-2"}',
- )
- )
- aos_index_mkt_qd = aos_index_dict["aos_index_mkt_qd"]
- aos_index_mkt_qq = aos_index_dict["aos_index_mkt_qq"]
- aos_index_dgr_qd = aos_index_dict["aos_index_dgr_qd"]
- aos_index_dgr_faq_qd = aos_index_dict["aos_index_dgr_faq_qd"]
- aos_index_dgr_qq = aos_index_dict["aos_index_dgr_qq"]
-
- # debug_info = {
- # "query": query_input,
- # "query_parser_info": {},
- # "q_q_match_info": {},
- # "knowledge_qa_knn_recall": {},
- # "knowledge_qa_boolean_recall": {},
- # "knowledge_qa_combined_recall": {},
- # "knowledge_qa_cross_model_sort": {},
- # "knowledge_qa_llm": {},
- # "knowledge_qa_rerank": {},
- # }
- debug_info = {}
- contexts = []
- sources = []
- answer = ""
- intent_info = {
- "manual_input_intent": manual_input_intent,
- "strict_qq_intent_result": {},
- }
-
- # 1. Strict Query Question Intent
- # 1.1. strict query question retrieval.
- # strict_q_q_chain = get_strict_qq_chain(aos_index_mkt_qq)
-
- # 2. Knowledge QA Intent
- # 2.1 query question retrieval.
- dgr_q_q_retriever = QueryQuestionRetriever(
- index=aos_index_dgr_qq,
- vector_field="vector_field",
- source_field="source",
- size=5,
- lang="zh",
- embedding_model_endpoint=zh_embedding_endpoint
- )
- # 2.2 query document retrieval + LLM.
- # qd_llm_chain = get_qd_llm_chain(
- # [aos_index_dgr_qd, aos_index_dgr_faq_qd, aos_index_mkt_qd],
- # rag_config,
- # stream,
- # # top_n=5,
- # # chunk_num=0
- # )
-
- # 2.3 query question router.
- def qq_route(info, threshold=0.9):
- for doc in info["qq_result"]:
- if doc.metadata["score"] > threshold:
- output = {
- "answer": doc.metadata["answer"],
- "sources": doc.metadata["source"],
- "contexts": [],
- "context_docs": [],
- "context_sources": [],
- # "debug_info": lambda x: x["debug_info"],
- }
- logger.info('qq matched...')
- info.update(output)
- return info
- qd_llm_chain = get_qd_llm_chain(
- [aos_index_dgr_qd, aos_index_dgr_faq_qd, aos_index_mkt_qd],
- rag_config,
- stream,
- # top_n=5,
- # chunk_num=0
- )
- return qd_llm_chain
-
- qq_chain = RunnablePassthrough.assign(qq_result=dgr_q_q_retriever)
- qq_chain = chain_logger(qq_chain,'qq_chain')
- qq_qd_llm_chain = qq_chain | RunnableLambda(qq_route)
-
- # TODO design chat chain
- # other_llm_config = copy.deepcopy(generator_llm_config)
- # other_llm_config.pop('model_id')
- # other_llm_config.pop('model_kwargs')
-
-
- # chat_llm_chain = get_llm_chain(
- # model_id=generator_llm_config['model_id'],
- # intent_type=IntentType.CHAT.value,
- # model_kwargs=generator_llm_config['model_kwargs'], # TODO
- # stream=stream,
- # # chat_history=rag_config['chat_history'],
- # **other_llm_config
- # ) | {
- # "answer": lambda x: x,
- # "sources": lambda x: [],
- # "contexts": lambda x: [],
- # "intent_type": lambda x: IntentType.CHAT.value,
- # "context_docs": lambda x: [],
- # "context_sources": lambda x: [],
- # }
-
- # query process chain
- query_process_chain = get_query_process_chain(
- rag_config['chat_history'],
- rag_config['query_process_config']
- )
- # | add_key_to_debug(add_key='conversation_query_rewrite',debug_key="debug_info")
- # | add_key_to_debug(add_key='query_rewrite',debug_key="debug_info")
-
- # query_rewrite_chain = chain_logger(
- # query_rewrite_chain,
- # "query rewrite module"
- # )
- # intent recognition
- intent_recognition_chain = auto_intention_recoginition_chain(
- q_q_retriever_config={
- "index_q_q":aos_index_mkt_qq,
- 'lang':'zh',
- 'embedding_endpoint':zh_embedding_endpoint,
- "q_q_match_threshold": rag_config['retriever_config']['q_q_match_threshold']
- },
- intent_config=rag_config['intent_config']
- )
-
- intent_recognition_chain = chain_logger(
- intent_recognition_chain,
- 'intention module',
- log_output_template='intent chain output: {intent_type}'
- )
-
- full_chain = query_process_chain | intent_recognition_chain | RunnableBranch(
- (lambda x:x['intent_type'] == IntentType.KNOWLEDGE_QA.value, qq_qd_llm_chain),
- (lambda x:x['intent_type'] == IntentType.STRICT_QQ.value, return_strict_qq_result),
- # (lambda x:x['intent_type'] == IntentType.STRICT_QQ.value, strict_q_q_chain),
- get_chat_llm_chain(rag_config=rag_config,stream=stream), # chat
- )
- # full_chain = intent_recognition_chain
- # full_chain = RunnableLambda(route)
- response = asyncio.run(full_chain.ainvoke(
- {
- "query": query_input,
- "debug_info": debug_info,
- "intent_type": intent_type,
- "intent_info": intent_info,
- "chat_history": rag_config['chat_history']
- }
- ))
-
- answer = response["answer"]
- sources = response["context_sources"]
- contexts = response["context_docs"]
-
- return answer, sources, contexts, debug_info
-
-def market_conversation_summary_entry(
- messages:list[dict],
- rag_config=None,
- stream=False
- ):
-
- if not rag_config['chat_history']:
- assert messages,messages
- chat_history = []
- for message in messages:
- role = message['role']
- content = message['content']
- assert role in ['user','ai']
- if role == 'user':
- chat_history.append(HumanMessage(content=content))
- else:
- chat_history.append(AIMessage(content=content))
- rag_config['chat_history'] = chat_history
-
- else:
- # filter by the window time
- time_window = rag_config.get('time_window',{})
- start_time = time_window.get('start_time',-math.inf)
- end_time = time_window.get('end_time',math.inf)
- assert isinstance(start_time, float) and isinstance(end_time, float), (start_time, end_time)
- chat_history = rag_config['chat_history']
- chat_history = filter_chat_history_by_time(chat_history,start_time=start_time,end_time=end_time)
- rag_config['chat_history'] = chat_history
- # rag_config['intent_config']['intent_type'] = IntentType.CHAT.value
-
- # query_input = """请简要总结上述对话中的内容,每一个对话单独一个总结,并用 '- '开头。 每一个总结要先说明问题。\n"""
- mkt_conversation_summary_config = rag_config["mkt_conversation_summary_config"]
- llm_chain = get_llm_chain(
- intent_type=MKT_CONVERSATION_SUMMARY_TYPE,
- stream=stream,
- **mkt_conversation_summary_config,
- )
- response = llm_chain.invoke({
- "chat_history": rag_config['chat_history'],
- })
- return response, [], {}, {}
-
-@timeit
-def main_qd_retriever_entry(
- query_input: str,
- aos_index: str,
- rag_config=None,
- manual_input_intent=None
-):
- """
- Entry point for the Lambda function.
-
- :param query_input: The query input.
- :param aos_index: The index of the AOS engine.
-
- return: answer(str)
- """
- debug_info = {
- "query": query_input,
- "query_parser_info": {},
- "q_q_match_info": {},
- "knowledge_qa_knn_recall": {},
- "knowledge_qa_boolean_recall": {},
- "knowledge_qa_combined_recall": {},
- "knowledge_qa_cross_model_sort": {},
- "knowledge_qa_llm": {},
- "knowledge_qa_rerank": {},
- }
- retriever_top_k = rag_config['retriever_config']['retriever_top_k']
- using_whole_doc = rag_config['retriever_config']['using_whole_doc']
- chunk_num = rag_config['retriever_config']['chunk_num']
- query_process_chain = get_query_process_chain(
- rag_config['chat_history'],
- rag_config['query_process_config']['query_rewrite_config'],
- rag_config['query_process_config']['conversation_query_rewrite_config'],
- rag_config['query_process_config']['hyde_config']
- )
- intent_type = rag_config['intent_config']['intent_type']
- intent_info = {
- "manual_input_intent": manual_input_intent,
- "strict_qq_intent_result": {},
- }
- intent_recognition_chain = auto_intention_recoginition_chain("aos_index_mkt_qq")
- intent_recognition_chain = chain_logger(
- intent_recognition_chain,
- 'intention module',
- log_output_template='intent chain output: {intent_type}'
-
- )
- qd_chain = get_qd_chain(
- [aos_index], using_whole_doc=using_whole_doc, chunk_num=chunk_num, retriever_top_k=retriever_top_k, reranker_top_k=10
- )
- full_chain = query_process_chain | intent_recognition_chain | qd_chain
- response = asyncio.run(full_chain.ainvoke({
- "query": query_input,
- "debug_info": debug_info,
- "intent_type": intent_type,
- "intent_info": intent_info,
- }))
- doc_list = []
- for doc in response["docs"]:
- doc_list.append({"page_content": doc.page_content, "metadata": doc.metadata})
- return doc_list, debug_info
-
-def main_qq_retriever_entry(
- query_input: str,
- aos_index: str,
-):
- """
- Entry point for the Lambda function.
-
- :param query_input: The query input.
- :param aos_index: The index of the AOS engine.
-
- return: answer(str)
- """
- debug_info = {
- "query": query_input,
- "query_parser_info": {},
- "q_q_match_info": {},
- "knowledge_qa_knn_recall": {},
- "knowledge_qa_boolean_recall": {},
- "knowledge_qa_combined_recall": {},
- "knowledge_qa_cross_model_sort": {},
- "knowledge_qa_llm": {},
- "knowledge_qa_rerank": {},
- }
- full_chain = get_strict_qq_chain(aos_index)
- response = full_chain.invoke({"query": query_input, "debug_info": debug_info})
- return response
-
-def main_chain_entry(
- query_input: str,
- aos_index: str,
- stream=False,
- rag_config=None
-):
- """
- Entry point for the Lambda function.
-
- :param query_input: The query input.
- :param aos_index: The index of the AOS engine.
-
- return: answer(str)
- """
- debug_info = {
- "query": query_input,
- "query_parser_info": {},
- "q_q_match_info": {},
- "knowledge_qa_knn_recall": {},
- "knowledge_qa_boolean_recall": {},
- "knowledge_qa_combined_recall": {},
- "knowledge_qa_cross_model_sort": {},
- "knowledge_qa_llm": {},
- "knowledge_qa_rerank": {},
- }
- contexts = []
- sources = []
- answer = ""
- full_chain = get_qd_llm_chain(
- [aos_index], rag_config, stream
- )
- response = full_chain.invoke({"query": query_input, "debug_info": debug_info})
- answer = response["answer"]
- sources = response["context_sources"]
- contexts = response["context_docs"]
- return answer, sources, contexts, debug_info
-
def _is_websocket_request(event):
"""Check if the request is WebSocket or Restful
@@ -1094,17 +95,6 @@ def _is_websocket_request(event):
else:
return False
-def get_retriever_response(docs, debug_info):
- response = {"statusCode": 200, "headers": {"Content-Type": "application/json"}}
- resp_header = {
- "Content-Type": "application/json",
- "Access-Control-Allow-Headers": "Content-Type,X-Amz-Date,Authorization,X-Api-Key,X-Amz-Security-Token",
- "Access-Control-Allow-Origin": "*",
- "Access-Control-Allow-Methods": "*",
- }
- response["body"] = json.dumps({"docs": docs, "debug_info": debug_info})
- response["headers"] = resp_header
- return response
# @handle_error
def lambda_handler(event, context):
@@ -1120,83 +110,96 @@ def lambda_handler(event, context):
# Get request body
event_body = json.loads(record_event["body"])
# model = event_body['model']
- session_id = event_body.get("session_id",None) or "N/A"
- messages = event_body.get("messages",[])
-
+ # session_id = event_body.get("session_id", None) or "N/A"
+ messages = event_body.get("messages", [])
+
# deal with stream parameter
stream = _is_websocket_request(record_event)
if stream:
load_ws_client()
logger.info(f"stream decode: {stream}")
- biz_type = event_body.get("type", Type.COMMON.value)
+ # biz_type = event_body.get("type", Type.COMMON.value)
+ client_type = event_body.get("client_type", "default_client_type")
enable_q_q_match = event_body.get("enable_q_q_match", False)
+ entry_type = event_body.get("type", Type.COMMON.value).lower()
+ # enable_q_q_match = event_body.get("enable_q_q_match", False)
enable_debug = event_body.get("enable_debug", False)
-
get_contexts = event_body.get("get_contexts", False)
+ session_id = event_body.get("session_id", None)
+ ws_connection_id = None
# all rag related params can be found in rag_config
- rag_config = parse_config.parse_rag_config(event_body)
+ # rag_config = parse_config.parse_rag_config(event_body)
- debug_level = int(rag_config['debug_level'])
+ debug_level = event_body.get('debug_level',logging.INFO)
logger.setLevel(debug_level)
- if messages and biz_type.lower() != Type.MARKET_CONVERSATION_SUMMARY.value:
+ if messages and entry_type != Type.MARKET_CONVERSATION_SUMMARY.value:
assert len(messages) == 1
- question = messages[-1]['content']
+ question = messages[-1]["content"]
+ custom_message_id = messages[-1].get("custom_message_id", None)
else:
- question = "" # MARKET_CONVERSATION_SUMMARY
+ question = "" # MARKET_CONVERSATION_SUMMARY
+ custom_message_id = event.get("custom_message_id", None)
# _, question = process_input_messages(messages)
# role = "user"
-
- if session_id == 'N/A':
- rag_config['session_id'] = f"session_{int(request_timestamp)}"
+
+ if not session_id:
+ session_id = f"session_{int(request_timestamp)}"
if stream:
- rag_config['ws_connection_id'] = record_event["requestContext"]["connectionId"]
+ ws_connection_id = record_event["requestContext"]["connectionId"]
+ # get chat history
user_id = event_body.get("user_id", "default_user_id")
message_id = str(uuid.uuid4())
- chat_history = DynamoDBChatMessageHistory(
+ ddb_history_obj = DynamoDBChatMessageHistory(
table_name=chat_session_table,
- session_id=rag_config['session_id'],
+ session_id=session_id,
user_id=user_id,
+ client_type=client_type,
)
- history_messages = chat_history.message_as_langchain
- rag_config['chat_history'] = history_messages
- logger.info(f'rag configs:\n {json.dumps(rag_config,indent=2,ensure_ascii=False,cls=JSONEncoder)}')
+ # print(chat_session_table,session_id,DynamoDBChatMessageHistory)
+ chat_history = ddb_history_obj.message_as_langchain
+
+ event_body['chat_history'] = chat_history
+ event_body['ws_connection_id'] = ws_connection_id
+ event_body['session_id'] = session_id
+ event_body['debug_level'] = debug_level
+
+ # logger.info(f'rag configs:\n {json.dumps(rag_config,indent=2,ensure_ascii=False,cls=JSONEncoder)}')
#
# knowledge_qa_flag = True if model == "knowledge_qa" else False
-
+
main_entry_start = time.time()
contexts = []
- if biz_type.lower() == Type.COMMON.value:
+ # entry_type = biz_type.lower()
+ if entry_type == Type.COMMON.value:
answer, sources, contexts, debug_info = main_chain_entry(
question,
aos_index,
stream=stream,
- rag_config=rag_config
+ event_body=event_body,
)
- elif biz_type.lower() == Type.QD_RETRIEVER.value:
+ elif entry_type == Type.QD_RETRIEVER.value:
retriever_index = event_body.get("retriever_index", aos_index)
docs, debug_info = main_qd_retriever_entry(
question,
retriever_index,
- rag_config=rag_config
+ event_body=event_body,
+ message_id=custom_message_id
)
return get_retriever_response(docs, debug_info)
- elif biz_type.lower() == Type.QQ_RETRIEVER.value:
+ elif entry_type == Type.QQ_RETRIEVER.value:
retriever_index = event_body.get("retriever_index", aos_index)
- docs = main_qq_retriever_entry(
- question,
- retriever_index
- )
+ docs = main_qq_retriever_entry(question, retriever_index)
return get_retriever_response(docs)
- elif biz_type.lower() == Type.DGR.value:
+ elif entry_type == Type.DGR.value:
history = []
- model = event_body.get('model','chat')
- temperature = event_body.get('temperature',0.5)
+ model = event_body.get("model", "chat")
+ temperature = event_body.get("temperature", 0.5)
knowledge_qa_flag = True if model == "knowledge_qa" else False
answer, sources, contexts, debug_info = dgr_entry(
session_id,
@@ -1214,28 +217,59 @@ def lambda_handler(event, context):
enable_q_q_match,
stream=stream,
)
- elif biz_type.lower() == Type.MARKET_CHAIN.value:
- answer, sources, contexts, debug_info = market_chain_entry(
+ elif entry_type == Type.MARKET_CHAIN_CORE.value:
+ answer, sources, contexts, debug_info = market_chain_entry_core(
+ question,
+ stream=stream,
+ event_body=event_body,
+ message_id=custom_message_id
+ )
+ elif entry_type == Type.MARKET_CHAIN.value:
+ answer, sources, contexts, debug_info = market_chain_knowledge_entry(
+ question,
+ stream=stream,
+ event_body=event_body,
+ message_id=custom_message_id
+ )
+ # answer, sources, contexts, debug_info = market_chain_entry(
+ # question,
+ # stream=stream,
+ # event_body=event_body,
+ # message_id=custom_message_id
+ # )
+
+ elif entry_type == Type.MARKET_CHAIN_KNOWLEDGE.value:
+ answer, sources, contexts, debug_info = market_chain_knowledge_entry(
+ question,
+ stream=stream,
+ event_body=event_body,
+ message_id=custom_message_id
+ )
+ elif entry_type == "market_chain_knowledge_langgraph":
+ answer, sources, contexts, debug_info = market_chain_knowledge_entry_langgraph(
question,
stream=stream,
- rag_config=rag_config
+ event_body=event_body,
+ message_id=custom_message_id
)
- elif biz_type.lower() == Type.MARKET_CONVERSATION_SUMMARY.value:
+
+ elif entry_type == Type.MARKET_CONVERSATION_SUMMARY.value:
answer, sources, contexts, debug_info = market_conversation_summary_entry(
messages=messages,
- rag_config=rag_config,
+ event_body=event_body,
stream=stream
)
-
- main_entry_elpase = time.time() - main_entry_start
- logger.info(f"runing time of {biz_type} entry : {main_entry_elpase}s seconds")
+
+ main_entry_end = time.time()
+ main_entry_elpase = main_entry_end - main_entry_start
+ logger.info(f"{custom_message_id} running time of main entry {entry_type} : {main_entry_elpase}s")
response_kwargs = dict(
stream=stream,
- session_id=rag_config['session_id'],
- ws_connection_id=rag_config['ws_connection_id'],
+ session_id=event_body['session_id'],
+ ws_connection_id=event_body['ws_connection_id'],
# model=model,
- entry_type=biz_type.lower(),
+ entry_type=entry_type,
question=question,
request_timestamp=request_timestamp,
answer=answer,
@@ -1245,12 +279,13 @@ def lambda_handler(event, context):
enable_debug=enable_debug,
debug_info=debug_info,
ws_client=ws_client,
- chat_history=chat_history,
+ ddb_history_obj=ddb_history_obj,
message_id=message_id,
+ client_type=client_type,
+ custom_message_id=custom_message_id,
+ main_entry_end=main_entry_end
)
- r = process_response(
- **response_kwargs
- )
+ r = process_response(**response_kwargs)
if not stream:
return r
- return {"statusCode": 200, "body": "All records have been processed"}
\ No newline at end of file
+ return {"statusCode": 200, "body": "All records have been processed"}
diff --git a/source/lambda/executor/requirements.txt b/source/lambda/executor/requirements.txt
index f61144b4..8b92aa03 100644
--- a/source/lambda/executor/requirements.txt
+++ b/source/lambda/executor/requirements.txt
@@ -1,4 +1,5 @@
-langchain==0.1.0
+langchain==0.1.11
+langgraph==0.0.26
langchainhub==0.1.14
opensearch-py==2.2.0
requests_aws4auth==1.2.2
@@ -7,3 +8,4 @@ requests_aws4auth==1.2.2
boto3==1.28.57
botocore==1.31.57
python-dateutil==2.8.2
+numpy==1.26.2
diff --git a/source/lambda/executor/test/app.py b/source/lambda/executor/test/app.py
index 8ee67ca9..4dadaa3f 100644
--- a/source/lambda/executor/test/app.py
+++ b/source/lambda/executor/test/app.py
@@ -2,6 +2,7 @@
import json
import requests
import boto3
+from websocket import create_connection
from executor_local_test import generate_answer
from langchain_community.document_loaders import UnstructuredPDFLoader
@@ -72,10 +73,26 @@ def check_data(url):
],
]
-
-def get_answer(query_input, entry_type):
+def generate_answer_from_api(url, query_input, type, rag_parameters):
+ data = {
+ "model": "knowledge_qa",
+ "messages": [{"role": "user", "content": query_input}],
+ "type": type,
+ }
+ data.update(rag_parameters)
+ response = requests.post(url, json.dumps(data))
+ return response
+
+def generate_answer_from_ws(url, query_input, type):
+ for i in range(max_debug_block):
+ with gr.Tab(visible=False) as tab:
+ tab_list.append(tab)
+ json_block = gr.JSON(visible=False)
+ json_list.append(json_block)
model_id = "internlm2-chat-7b"
- endpoint_name = "instruct-internlm2-chat-7b-f7dc2"
+ # endpoint_name = "instruct-internlm2-chat-7b-f7dc2"
+ # endpoint_name = "internlm2-chat-7b-2024-02-23-07-29-02-632"
+ endpoint_name = "internlm2-chat-20b-4bits-2024-02-29-05-37-42-885"
rag_parameters=dict(
query_process_config = {
"conversation_query_rewrite_config":{
@@ -91,7 +108,80 @@ def get_answer(query_input, entry_type):
"chunk_num": 0,
"using_whole_doc": False,
"enable_reranker": True,
- "retriever_top_k": 2
+ "retriever_top_k": 2,
+ "workspace_ids": ["aos_index_mkt_faq_qq", "aos_index_acts_qd"]
+ },
+ generator_llm_config ={
+ "model_kwargs":{
+ "max_new_tokens": 2000,
+ "temperature": 0.1,
+ "top_p": 0.9
+ },
+ "model_id": model_id,
+ "endpoint_name": endpoint_name,
+ "context_num": 2
+ })
+ sources = []
+ debug_info = []
+ ws = create_connection(url)
+ data = {
+ "model": "knowledge_qa",
+ "messages": [{"role": "user", "content": query_input}],
+ "type": type,
+ "get_contexts": True,
+ "enable_debug": True,
+ }
+ data.update(rag_parameters)
+ ws.send(json.dumps(data))
+ answer = ""
+ while True:
+ ret = json.loads(ws.recv())
+ try:
+ message_type = ret['choices'][0]['message_type']
+ except:
+ print(ret)
+ raise
+ if message_type == "START":
+ continue
+ elif message_type == "CHUNK":
+ print(ret['choices'][0]['message']['content'],end="",flush=True)
+ answer += ret['choices'][0]['message']['content']
+ yield answer, "", *tab_list, *json_list
+ elif message_type == "END":
+ break
+ elif message_type == "ERROR":
+ print(ret['choices'][0]['message']['content'])
+ break
+ elif message_type == "CONTEXT":
+ print()
+ print('contexts',ret)
+ sources = ret['choices'][0]["knowledge_sources"]
+ debug_info = ret['choices'][0]["debug_info"]
+ yield answer, sources, *render_debug_info(debug_info)
+ ws.close()
+
+def generate_answer_from_local(query_input, entry_type):
+ model_id = "internlm2-chat-7b"
+ endpoint_name = "instruct-internlm2-chat-7b-f7dc2"
+ # endpoint_name = "internlm2-chat-7b-2024-02-23-07-29-02-632"
+ # endpoint_name = "internlm2-chat-7b-4bits-2024-02-28-07-08-57-839"
+ rag_parameters=dict(
+ query_process_config = {
+ "conversation_query_rewrite_config":{
+ "model_id":model_id,
+ "endpoint_name":endpoint_name
+ },
+ "translate_config":{
+ "model_id":model_id,
+ "endpoint_name": endpoint_name
+ }
+ },
+ retriever_config = {
+ "chunk_num": 2,
+ "using_whole_doc": False,
+ "enable_reranker": True,
+ "retriever_top_k": 5,
+ "workspace_ids": ["aos_index_mkt_faq_qq_m3", "aos_index_acts_qd_m3"]
},
generator_llm_config ={
"model_kwargs":{
@@ -103,9 +193,19 @@ def get_answer(query_input, entry_type):
"endpoint_name": endpoint_name,
"context_num": 1
})
- answer, source, debug_info = generate_answer(
- query_input, type=entry_type, rag_parameters=rag_parameters
- )
+ sources = []
+ debug_info = []
+ answer, sources, debug_info = generate_answer(
+ query=query_input, type=entry_type, rag_parameters=rag_parameters)
+ return answer, sources, *render_debug_info(debug_info)
+
+def generate_func(api_type, url_input, query_input, entry_type):
+ if api_type == "local":
+ return generate_answer_from_local(query_input, entry_type)
+ elif api_type == "cloud":
+ return generate_answer_from_ws(url_input, query_input, entry_type)
+
+def render_debug_info(debug_info):
tab_list = []
json_list = []
json_count = 0
@@ -123,13 +223,7 @@ def get_answer(query_input, entry_type):
tab_list.append(gr.Tab(visible=False))
for i in range(max_debug_block-json_count):
json_list.append(gr.JSON(value=["dummy"], visible=False))
- return (
- answer,
- source,
- *tab_list,
- *json_list,
- )
-
+ return *tab_list, *json_list
def invoke_etl_online(
url_input, s3_bucket_chunk_input, s3_prefix_chunk_input, need_split_dropdown
@@ -246,9 +340,13 @@ def load_by_langchain(s3_bucket_dropdown, s3_prefix_compare):
url_input = gr.Text(
label="Url, eg. https://f2zgexpo47.execute-api.us-east-1.amazonaws.com/v1/"
)
+ websocket_input = gr.Text(
+ label="Websocket, eg. wss://5nnxrqr4ya.execute-api.cn-north-1.amazonaws.com.cn/prod/"
+ )
with gr.Tab("Chat"):
+ api_type = gr.Dropdown(label="API", choices=["local", "cloud"], value="local")
+ entry_input = gr.Dropdown(label="Entry", choices=["common", "market_chain_core"], value="market_chain_core")
query_input = gr.Text(label="Query")
- entry_input = gr.Dropdown(label="Entry", choices=["common", "market_chain"], value="common")
answer_output = gr.Text(label="Anwser", show_label=True)
sources_output = gr.Text(label="Sources", show_label=True)
tab_list = []
@@ -278,8 +376,8 @@ def load_by_langchain(s3_bucket_dropdown, s3_prefix_compare):
answer_btn = gr.Button(value="Answer")
context = None
answer_btn.click(
- get_answer,
- inputs=[query_input, entry_input],
+ generate_func,
+ inputs=[api_type, websocket_input, query_input, entry_input],
outputs=[
answer_output,
sources_output,
@@ -288,14 +386,14 @@ def load_by_langchain(s3_bucket_dropdown, s3_prefix_compare):
],
)
- with gr.Accordion("RawDataDebugInfo", open=False):
- raw_data = gr.JSON()
- check_btn = gr.Button(value="Check")
- check_btn.click(check_data, inputs=[url_input], outputs=[raw_data])
+ # with gr.Accordion("RawDataDebugInfo", open=False):
+ # raw_data = gr.JSON()
+ # check_btn = gr.Button(value="Check")
+ # check_btn.click(check_data, inputs=[url_input], outputs=[raw_data])
gr.Examples(
examples=text,
- inputs=[query_input],
- fn=generate_answer,
+ inputs=[websocket_input, query_input, entry_input],
+ fn=generate_func,
cache_examples=False,
)
with gr.Tab("Data Process Offline"):
@@ -432,8 +530,7 @@ def update_s3_prefix_dropdown(s3_bucket, s3_prefix):
outputs=[solution_md],
)
-
# load_raw_data()
if __name__ == "__main__":
demo.queue()
- demo.launch(server_name="0.0.0.0", share=False, server_port=3309)
+ demo.launch(server_name="0.0.0.0", share=False, server_port=3309)
\ No newline at end of file
diff --git a/source/lambda/executor/test/executor_local_test.py b/source/lambda/executor/test/executor_local_test.py
index 8104d69d..70d77746 100644
--- a/source/lambda/executor/test/executor_local_test.py
+++ b/source/lambda/executor/test/executor_local_test.py
@@ -33,8 +33,6 @@
# print(region)
import main
import os
-aos_index_dict = json.loads(os.environ.get("aos_index_dict", ""))
-print(f"aos index {aos_index_dict}")
class DummyWebSocket:
def post_to_connection(self,ConnectionId,Data):
@@ -51,12 +49,16 @@ def post_to_connection(self,ConnectionId,Data):
print(ret['choices'][0]['message']['content'])
return
elif message_type == "CONTEXT":
- print('sources: ',ret['choices'][0]['knowledge_sources'])
+ print('knowledge_sources num',ret['choices'][0]['knowledge_sources'])
+ if ret['choices'][0].get('contexts'):
+ print('contexts num',len(ret['choices'][0].get('contexts')))
+ print('contexts avg len: ', sum(len(i) for i in ret['choices'][0]['contexts'])/len(ret['choices'][0]['contexts']))
+ # print('sources: ',ret['choices'][0]['contexts'])
main.ws_client = DummyWebSocket()
def generate_answer(query,
- temperature=0.7,
+ # temperature=0.7,
enable_debug=True,
retrieval_only=False,
type="market_chain",
@@ -74,7 +76,7 @@ def generate_answer(query,
"content": query
}
],
- "temperature": temperature,
+ # "temperature": temperature,
# "enable_debug": enable_debug,
# "retrieval_only": retrieval_only,
# "retriever_index": retriever_index,
@@ -346,15 +348,20 @@ def test_baichuan_model():
def test_internlm_model():
session_id=f'test_{time.time()}'
- endpoint_name = 'internlm2-chat-7b-4bits-2024-02-28-07-08-57-839'
- model_id = "internlm2-chat-7b"
+ # endpoint_name = 'internlm2-chat-7b-4bits-2024-02-28-07-08-57-839'
+ # model_id = "internlm2-chat-7b"
+
+ endpoint_name = 'internlm2-chat-20b-4bits-2024-02-29-05-37-42-885'
+ model_id = "internlm2-chat-20b"
+
rag_parameters = {
+ "get_contexts":True,
"retriever_config":{
- "retriever_top_k": 20,
- "chunk_num": 2,
- "using_whole_doc": True,
- "reranker_top_k": 10,
- "q_q_match_threshold": 0.9
+ "retriever_top_k": 1,
+ "chunk_num": 2,
+ "using_whole_doc": True,
+ "reranker_top_k": 10,
+ "enable_reranker": True
},
"query_process_config":{
"conversation_query_rewrite_config":{
@@ -364,6 +371,10 @@ def test_internlm_model():
"translate_config":{
"model_id":model_id,
"endpoint_name": endpoint_name
+ },
+ "stepback_config":{
+ "model_id":model_id,
+ "endpoint_name": endpoint_name
}
},
"intent_config": {
@@ -375,7 +386,7 @@ def test_internlm_model():
"max_new_tokens": 2000,
"temperature": 0.1,
"top_p": 0.9,
- 'repetition_penalty':1.2
+ # 'repetition_penalty':1.1
},
"model_id": model_id,
"endpoint_name": endpoint_name,
@@ -384,6 +395,22 @@ def test_internlm_model():
}
qq_match_test()
+ generate_answer(
+ "AWS支持上海region吗?",
+ model="auto",
+ type="market_chain",
+ stream=True,
+ rag_parameters=rag_parameters
+ )
+ # print(sfg)
+ generate_answer(
+ "介绍一下Amazon EC2",
+ model="auto",
+ type="market_chain",
+ stream=True,
+ rag_parameters=rag_parameters
+ )
+ # print(xfg)
generate_answer(
"什么是Amazon bedrock?",
model="auto",
@@ -392,6 +419,56 @@ def test_internlm_model():
rag_parameters=rag_parameters
)
+ generate_answer(
+ "《夜曲》是谁演唱的?",
+ session_id=session_id,
+ model="chat",
+ type="market_chain",
+ stream=True,
+ rag_parameters=rag_parameters
+ )
+ generate_answer(
+ "他还有哪些其他歌曲?",
+ session_id=session_id,
+ model="chat",
+ type="market_chain",
+ stream=True,
+ rag_parameters=rag_parameters
+ )
+
+ r = generate_answer(
+ "解释一下“温故而知新”",
+ model="auto",
+ type="market_chain",
+ stream=False,
+ rag_parameters=rag_parameters
+ )
+ print(r[0])
+
+
+def test_internlm_model_mkt():
+ session_id=f'test_{time.time()}'
+ # endpoint_name = 'internlm2-chat-7b-4bits-2024-02-28-07-08-57-839'
+ # model_id = "internlm2-chat-7b"
+ endpoint_name = 'internlm2-chat-20b-4bits-2024-03-04-06-32-53-653'
+ model_id = "internlm2-chat-20b"
+
+ os.environ['llm_model_id'] = model_id
+ os.environ['llm_model_endpoint_name'] = endpoint_name
+
+ rag_parameters = {
+ "get_contexts":True,
+ }
+
+ qq_match_test()
+ generate_answer(
+ "AWS支持上海region吗?",
+ model="auto",
+ type="market_chain",
+ stream=True,
+ rag_parameters=rag_parameters
+ )
+ # print(sfg)
generate_answer(
"介绍一下Amazon EC2",
model="auto",
@@ -399,7 +476,14 @@ def test_internlm_model():
stream=True,
rag_parameters=rag_parameters
)
- print(sdg)
+ # print(xfg)
+ generate_answer(
+ "什么是Amazon bedrock?",
+ model="auto",
+ type="market_chain",
+ stream=True,
+ rag_parameters=rag_parameters
+ )
generate_answer(
"《夜曲》是谁演唱的?",
@@ -428,6 +512,357 @@ def test_internlm_model():
print(r[0])
+
+def test_internlm_model_mkt_knowledge_entry_qq_match():
+ session_id=f'test_{time.time()}'
+ # endpoint_name = 'internlm2-chat-7b-4bits-2024-02-28-07-08-57-839'
+ # model_id = "internlm2-chat-7b"
+ endpoint_name = 'internlm2-chat-20b-4bits-2024-03-04-06-32-53-653'
+ model_id = "internlm2-chat-20b"
+ entry_type = "market_chain_knowledge"
+
+ os.environ['llm_model_id'] = model_id
+ os.environ['llm_model_endpoint_name'] = endpoint_name
+ # workspace_ids = ["aos_index_mkt_faq_qq_m3", "aos_index_acts_qd_m3", "aos_index_mkt_faq_qd_m3"]
+
+ questions = [
+ "能否通过JDBC连接到RDS for PostgreSQL? 有相关的指导吗?",
+ "如何解决切换RI后网速变慢?",
+ "如何升级EC2配置不改变IP",
+ "如何/怎么关停账号",
+ "请问怎么关闭账号?",
+ "个人能否注册账号?",
+ "怎么开具发票?",
+ "怎么开发票?",
+ "使用CDN服务要备案吗?"
+ ]
+ for question in questions:
+ generate_answer(
+ question,
+ model="knowledge_qa",
+ type=entry_type,
+ stream=True,
+ rag_parameters={
+ "get_contexts":True,
+ "retriever_config":{
+ # "qq_config": {
+ # "qq_match_threshold": 0.8,
+ # },
+ # "qd_config":{
+ # "qd_match_threshold": 2,
+ # # "using_whole_doc": True
+ # },
+ # "workspace_ids": workspace_ids
+ }
+ }
+ )
+
+
+def test_internlm_model_mkt_knowledge_entry():
+ session_id=f'test_{time.time()}'
+ # endpoint_name = 'internlm2-chat-7b-4bits-2024-02-28-07-08-57-839'
+ # model_id = "internlm2-chat-7b"
+ endpoint_name = 'internlm2-chat-20b-4bits-2024-03-04-06-32-53-653'
+ model_id = "internlm2-chat-20b"
+ entry_type = "market_chain"
+
+ os.environ['llm_model_id'] = model_id
+ os.environ['llm_model_endpoint_name'] = endpoint_name
+ # workspace_ids = ["aos_index_mkt_faq_qq","aos_index_acts_qd"]
+ # workspace_ids = ["aos_index_mkt_faq_qq_m3", "aos_index_acts_qd_m3", "aos_index_mkt_faq_qd_m3"]
+
+ rag_parameters={
+ "get_contexts":True,
+ "session_id":session_id,
+ "retriever_config":{
+ # "qq_config": {
+ # "q_q_match_threshold": 0.8,
+ # },
+ # "qd_config":{
+ # "qd_match_threshold": 2,
+ # # "reranker_type": "bge_reranker",
+ # # "qd_match_threshold": 0.5,
+ # # "reranker_type": "no_reranker"
+
+ # # "using_whole_doc": True
+ # },
+ # "workspace_ids": workspace_ids
+ }
+ }
+
+ generate_answer(
+ "AWS支持上海region吗?",
+ # "什么是日志通",
+ model="knowledge_qa",
+ type=entry_type,
+ stream=True,
+ rag_parameters=rag_parameters
+ )
+
+ generate_answer(
+ "亚马逊云科技有上海区域吗?",
+ # "AWS支持上海region吗?",
+ # "什么是日志通",
+ model="knowledge_qa",
+ type=entry_type,
+ stream=True,
+ rag_parameters=rag_parameters
+ )
+
+ print(sfg)
+
+ generate_answer(
+ "亚马逊云科技中国区域免费套餐有哪几种不同类型的优惠?",
+ model="knowledge_qa",
+ type=entry_type,
+ stream=True,
+ rag_parameters=rag_parameters
+ )
+
+ generate_answer(
+ "怎么开发票?",
+ model="knowledge_qa",
+ type=entry_type,
+ stream=True,
+ rag_parameters=rag_parameters
+ )
+
+
+ generate_answer(
+ "日志通是什么?",
+ model="knowledge_qa",
+ type=entry_type,
+ stream=True,
+ rag_parameters=rag_parameters
+ )
+
+
+
+ generate_answer(
+ "Amazon Lambda函数是什么?",
+ model="knowledge_qa",
+ type=entry_type,
+ stream=True,
+ rag_parameters={
+ "get_contexts":True,
+ "retriever_config":{
+ "qq_config": {
+ "q_q_match_threshold": 0.9,
+ },
+ "qd_config":{
+ "qd_match_threshold": 2,
+ "using_whole_doc": True
+ },
+ "workspace_ids": ["aos_index_mkt_faq_qq","aos_index_acts_qd"]
+
+ }
+ }
+ )
+
+ generate_answer(
+ "今天是几月几号?",
+ model="knowledge_qa",
+ type=entry_type,
+ stream=True,
+ rag_parameters={
+ "get_contexts":True,
+ "retriever_config":{
+ "qq_config": {
+ "q_q_match_threshold": 0.9,
+ },
+ "qd_config":{
+ "qd_match_threshold": 2,
+ "using_whole_doc": True
+ },
+ "workspace_ids": workspace_ids
+ }
+ }
+ )
+
+ generate_answer(
+ "我上一个问题是什么?",
+ model="knowledge_qa",
+ type=entry_type,
+ stream=True,
+ rag_parameters= {
+ "get_contexts":True,
+ "retriever_config":{
+ "qq_config": {
+ "q_q_match_threshold": 0.9,
+ },
+ "qd_config":{
+ "qd_match_threshold": 2,
+ "using_whole_doc": True
+ },
+ "workspace_ids": workspace_ids
+
+ }
+ }
+ )
+ # qq_match_test()
+ generate_answer(
+ "AWS支持上海region吗?",
+ model="knowledge_qa",
+ type=entry_type,
+ stream=True,
+ rag_parameters = {
+ "get_contexts":True,
+ "retriever_config":{
+ "qq_config": {
+ "q_q_match_threshold": 0.9,
+ },
+ "qd_config":{
+ "qd_match_threshold": 2,
+ "using_whole_doc": True
+ },
+ "workspace_ids": workspace_ids
+
+ }
+ }
+ )
+
+ print(f)
+
+ generate_answer(
+ "AWS支持上海region吗?",
+ model="knowledge_qa",
+ type=entry_type,
+ stream=True,
+
+ rag_parameters={
+ "get_contexts":True,
+ "retriever_config":{
+ "qq_config": {
+ "q_q_match_threshold": 0.9,
+ },
+ "qd_config":{
+ "qd_match_threshold": 2,
+ "using_whole_doc": True
+ },
+ "workspace_ids": ["aos_index_mkt_faq_qq","aos_index_acts_qd"]
+
+ }
+ }
+ )
+ print(sfg)
+ generate_answer(
+ "介绍一下Amazon EC2",
+ model="auto",
+ type=entry_type,
+ stream=True,
+ rag_parameters=rag_parameters
+ )
+ # print(xfg)
+ generate_answer(
+ "什么是Amazon bedrock?",
+ model="auto",
+ type=entry_type,
+ stream=True,
+ rag_parameters=rag_parameters
+ )
+
+ generate_answer(
+ "《夜曲》是谁演唱的?",
+ session_id=session_id,
+ model="chat",
+ type=entry_type,
+ stream=True,
+ rag_parameters=rag_parameters
+ )
+ generate_answer(
+ "他还有哪些其他歌曲?",
+ session_id=session_id,
+ model="chat",
+ type=entry_type,
+ stream=True,
+ rag_parameters=rag_parameters
+ )
+
+ r = generate_answer(
+ "解释一下“温故而知新”",
+ model="auto",
+ type=entry_type,
+ stream=False,
+ rag_parameters=rag_parameters
+ )
+ print(r[0])
+
+
+
+def test_internlm_model_mkt_knowledge_entry_langgraph():
+ session_id=f'test_{time.time()}'
+ # endpoint_name = 'internlm2-chat-7b-4bits-2024-02-28-07-08-57-839'
+ # model_id = "internlm2-chat-7b"
+ endpoint_name = 'internlm2-chat-20b-4bits-2024-03-04-06-32-53-653'
+ model_id = "internlm2-chat-20b"
+ entry_type = "market_chain_knowledge_langgraph"
+
+ os.environ['llm_model_id'] = model_id
+ os.environ['llm_model_endpoint_name'] = endpoint_name
+ generate_answer(
+ "今天是几月几号?",
+ model="knowledge_qa",
+ type=entry_type,
+ stream=True,
+ rag_parameters={
+ "session_id":session_id,
+ "get_contexts":True,
+ "retriever_config":{
+ "qq_config": {
+ "q_q_match_threshold": 0.9,
+ },
+ "qd_config":{
+ "qd_match_threshold": 2,
+ "using_whole_doc": True
+ },
+ "workspace_ids": ["aos_index_mkt_faq_qq","aos_index_acts_qd"]
+ }
+ }
+ )
+
+ generate_answer(
+ "日志通是什么?",
+ model="knowledge_qa",
+ type=entry_type,
+ stream=True,
+ rag_parameters={
+ "session_id":session_id,
+ "get_contexts":True,
+ "retriever_config":{
+ "qq_config": {
+ "q_q_match_threshold": 0.9,
+ },
+ "qd_config":{
+ "qd_match_threshold": 2,
+ "using_whole_doc": True
+ },
+ "workspace_ids": ["aos_index_mkt_faq_qq","aos_index_acts_qd"]
+ }
+ }
+ )
+
+ generate_answer(
+ "AWS支持上海region吗?",
+ model="knowledge_qa",
+ type=entry_type,
+ stream=True,
+ rag_parameters={
+ "session_id":session_id,
+ "get_contexts":True,
+ "retriever_config":{
+ "qq_config": {
+ "q_q_match_threshold": 0.9,
+ },
+ "qd_config":{
+ "qd_match_threshold": 2,
+ "using_whole_doc": True
+ },
+ "workspace_ids": ["aos_index_mkt_faq_qq","aos_index_acts_qd"]
+ }
+ }
+ )
+
+
def market_summary_test():
session_id = f'test_{int(time.time())}'
generate_answer(
@@ -544,62 +979,6 @@ def market_deploy_test():
market_summary_test2()
-# def market_deploy_cn_test():
-# model_id = "internlm2-chat-7b"
-# endpoint_name = "internlm2-chat-7b-2024-02-23-07-29-02-632"
-# rag_parameters = {
-# "query_process_config":{
-# "conversation_query_rewrite_config":{
-# "model_id":model_id,
-# "endpoint_name":endpoint_name
-# },
-# "translate_config":{
-# "model_id":model_id,
-# "endpoint_name": endpoint_name
-# }
-# },
-# "intent_config": {
-# "model_id": model_id,
-# "endpoint_name": endpoint_name
-# },
-# "generator_llm_config":{
-# "model_kwargs":{
-# "max_new_tokens": 2000,
-# "temperature": 0.1,
-# "top_p": 0.9
-# },
-# "model_id": model_id,
-# "endpoint_name": endpoint_name,
-# "context_num": 1
-# }
-# }
-# generate_answer(
-# "什么是Amazon Bedrock",
-# model="auto",
-# stream=True,
-# type="market_chain",
-# rag_parameters=rag_parameters
-# )
-
-
-# session_id = f'test_{int(time.time())}'
-# generate_answer(
-# "《夜曲》是谁演唱的?",
-# session_id=session_id,
-# model="chat",
-# type="market_chain",
-# stream=True,
-# rag_parameters=rag_parameters
-# )
-# generate_answer(
-# "他还有哪些其他歌曲?",
-# session_id=session_id,
-# model="chat",
-# type="market_chain",
-# stream=True,
-# rag_parameters=rag_parameters
-# )
-
if __name__ == "__main__":
# market_summary_test()
@@ -629,7 +1008,11 @@ def market_deploy_test():
# market_deploy_test()
# test_baichuan_model()
- test_internlm_model()
+ # market_summary_test2()
+ # test_internlm_model()
+ test_internlm_model_mkt_knowledge_entry()
+ # test_internlm_model_mkt_knowledge_entry_qq_match()
+ # test_internlm_model_mkt_knowledge_entry_langgraph()
# test_baichuan_model()
# market_deploy_test()
@@ -637,15 +1020,16 @@ def market_deploy_test():
# generate_answer(
# # "如何将Kinesis Data Streams配置为AWS Lambda的事件源?",
# # "Amazon EC2 提供了哪些功能来支持不同区域之间的数据恢复?",
- # "live chat",
+ # "什么是Amazon bedrock?",
# model="knowledge_qa",
# stream=True,
# type="market_chain",
# rag_parameters=dict(
- # retriver_config={
+ # get_contexts = True,
+ # retriever_config={
# "retriever_top_k": 1,
# "chunk_num": 2,
- # "using_whole_doc": True,
+ # "using_whole_doc": False,
# "reranker_top_k": 10,
# "enable_reranker": True
# },
diff --git a/source/lambda/executor/test/executor_local_test_cn.py b/source/lambda/executor/test/executor_local_test_cn.py
index 1df6de09..819c51d3 100644
--- a/source/lambda/executor/test/executor_local_test_cn.py
+++ b/source/lambda/executor/test/executor_local_test_cn.py
@@ -58,7 +58,7 @@ def post_to_connection(self,ConnectionId,Data):
main.ws_client = DummyWebSocket()
def generate_answer(query,
- temperature=0.7,
+ # temperature=0.7,
enable_debug=True,
retrieval_only=False,
type="market_chain",
@@ -77,7 +77,7 @@ def generate_answer(query,
"content": query
}
],
- "temperature": temperature,
+ # "temperature": temperature,
# "enable_debug": enable_debug,
# "retrieval_only": retrieval_only,
# "retriever_index": retriever_index,
@@ -546,8 +546,17 @@ def market_deploy_test():
def market_deploy_cn_test():
model_id = "internlm2-chat-7b"
- endpoint_name = "internlm2-chat-7b-2024-02-23-07-29-02-632"
+ # endpoint_name = "instruct-internlm2-chat-7b-f7dc2"
+ endpoint_name = "internlm2-chat-7b-llama-exl2-2024-03-04-01-49-46-235"
rag_parameters = {
+ "get_contexts":True,
+ "retriever_config":{
+ "retriever_top_k": 1,
+ "chunk_num": 2,
+ "using_whole_doc": True,
+ "reranker_top_k": 10,
+ "enable_reranker": True
+ },
"query_process_config":{
"conversation_query_rewrite_config":{
"model_id":model_id,
@@ -565,7 +574,7 @@ def market_deploy_cn_test():
"generator_llm_config":{
"model_kwargs":{
"max_new_tokens": 2000,
- "temperature": 0.1,
+ "temperature": 0.05,
"top_p": 0.9
},
"model_id": model_id,
@@ -573,17 +582,42 @@ def market_deploy_cn_test():
"context_num": 1
}
}
+
+ # generate_answer(
+ # "什么是Amazon Bedrock",
+ # model="auto",
+ # stream=True,
+ # type="market_chain",
+ # get_contexts=True,
+ # rag_parameters=rag_parameters
+ # )
+ # #
+
generate_answer(
- "什么是Amazon Bedrock",
+ "AWS支持上海region吗?",
model="auto",
+ type="market_chain",
stream=True,
+ rag_parameters=rag_parameters
+ )
+ print(sfg)
+ generate_answer(
+ "介绍一下Amazon EC2",
+ model="auto",
type="market_chain",
- get_contexts=True,
+ stream=True,
rag_parameters=rag_parameters
)
- print(xfg)
+ generate_answer(
+ "什么是Amazon Bedrock",
+ model="auto",
+ stream=True,
+ type="market_chain",
+ get_contexts=True,
+ rag_parameters=rag_parameters
+ )
session_id = f'test_{int(time.time())}'
generate_answer(
diff --git a/source/lambda/executor/utils/aos_utils.py b/source/lambda/executor/utils/aos_utils.py
index 19fedb66..b6a554b1 100644
--- a/source/lambda/executor/utils/aos_utils.py
+++ b/source/lambda/executor/utils/aos_utils.py
@@ -2,9 +2,12 @@
import boto3
import requests
import os
+import threading
from requests_aws4auth import AWS4Auth
from opensearchpy import OpenSearch, RequestsHttpConnection
+open_search_client_lock = threading.Lock()
+
credentials = boto3.Session().get_credentials()
region = boto3.Session().region_name
@@ -22,10 +25,19 @@ def _import_not_found_error():
return NotFoundError
class LLMBotOpenSearchClient:
+ instance = None
+ def __new__(cls,host):
+ with open_search_client_lock:
+ if cls.instance is not None and cls.instance.host == host:
+ return cls.instance
+ obj = object.__new__(cls)
+ cls.instance = obj
+ return obj
def __init__(self, host):
"""
Initialize OpenSearch client using OpenSearch Endpoint
"""
+ self.host = host
self.client = OpenSearch(
hosts = [{
'host': host.replace("https://", ""),
@@ -185,4 +197,6 @@ def search(self, index_name, query_type, query_term, field: str = "text", size:
body=query,
index=index_name
)
- return response
\ No newline at end of file
+ return response
+
+
diff --git a/source/lambda/executor/utils/constant.py b/source/lambda/executor/utils/constant.py
index a616038f..62205f84 100644
--- a/source/lambda/executor/utils/constant.py
+++ b/source/lambda/executor/utils/constant.py
@@ -5,6 +5,8 @@ class EntryType(Enum):
DGR = "dgr"
MARKET = "market"
MARKET_CHAIN = "market_chain"
+ MARKET_CHAIN_CORE = "market_chain_core"
+ MARKET_CHAIN_KNOWLEDGE = "market_chain_knowledge"
QQ_RETRIEVER = "qq_retriever"
QD_RETRIEVER = "qd_retriever"
MARKET_CONVERSATION_SUMMARY = "market_conversation_summary"
@@ -16,21 +18,42 @@ def has_value(cls, value):
class IntentType(Enum):
KNOWLEDGE_QA = "knowledge_qa"
CHAT = "chat"
+ MARKET_EVENT = 'market_event'
STRICT_QQ = "strict_q_q"
AUTO = "auto"
@classmethod
def has_value(cls, value):
return value in cls._value2member_map_
+class RerankerType(Enum):
+ BGE_RERANKER = "bge_reranker"
+ BGE_M3_RERANKER = "bge_m3_colbert"
+ BYPASS = "no_reranker"
+ @classmethod
+ def has_value(cls, value):
+ return value in cls._value2member_map_
+
+
+# LLM chain typs
QUERY_TRANSLATE_TYPE = "query_translate" # for query translate purpose
INTENT_RECOGNITION_TYPE = "intent_recognition" # for intent recognition
AWS_TRANSLATE_SERVICE_MODEL_ID = "Amazon Translate"
QUERY_TRANSLATE_IDENTITY_TYPE = "identity"
QUERY_REWRITE_TYPE = "query_rewrite"
+HYDE_TYPE = "hyde"
CONVERSATION_SUMMARY_TYPE = "conversation_summary"
MKT_CONVERSATION_SUMMARY_TYPE = "mkt_conversation_summary"
-
+STEPBACK_PROMPTING_TYPE = "stepback_prompting"
HUMAN_MESSAGE_TYPE = 'human'
AI_MESSAGE_TYPE = 'ai'
-SYSTEM_MESSAGE_TYPE = 'system'
\ No newline at end of file
+SYSTEM_MESSAGE_TYPE = 'system'
+
+
+
+class StreamMessageType:
+ START = "START"
+ END = "END"
+ ERROR = "ERROR"
+ CHUNK = "CHUNK"
+ CONTEXT = "CONTEXT"
\ No newline at end of file
diff --git a/source/lambda/executor/utils/content_filter_utils/content_filters.py b/source/lambda/executor/utils/content_filter_utils/content_filters.py
new file mode 100644
index 00000000..65644599
--- /dev/null
+++ b/source/lambda/executor/utils/content_filter_utils/content_filters.py
@@ -0,0 +1,60 @@
+import os
+import csv
+from typing import Iterable,Union
+abs_dir = os.path.dirname(__file__)
+
+
+class ContentFilterBase:
+ def filter_sentence(self,sentence:str):
+ raise NotImplementedError
+
+
+class MarketContentFilter(ContentFilterBase):
+ def __init__(
+ self,
+ sensitive_words_path=os.path.join(abs_dir,'sensitive_word.csv')
+ ) -> None:
+ self.sensitive_words = self.create_sensitive_words(sensitive_words_path)
+
+ def create_sensitive_words(self,sensitive_words_path):
+ sensitive_words = set()
+ with open(sensitive_words_path, mode='r') as file:
+ csv_reader = csv.reader(file)
+ for row in csv_reader:
+ sensitive_words.add(row[0])
+ return sensitive_words
+
+ def filter_sensitive_words(self,sentence):
+ for sensitive_word in self.sensitive_words:
+ length = len(sensitive_word)
+ sentence = sentence.replace(sensitive_word, '*' * length)
+ return sentence
+
+ def rebranding_words(self,sentence:str):
+ rebranding_dict = {'AWS': 'Amazon Web Services'}
+ for key, value in rebranding_dict.items():
+ sentence = sentence.replace(key, value)
+ return sentence
+
+ def filter_sentence(self,sentence):
+ sentence = self.filter_sensitive_words(sentence)
+ sentence = self.rebranding_words(sentence)
+ return sentence
+
+
+def token_to_sentence_gen(answer:Iterable[str],stop_signals: Union[list[str],set[str]]):
+ accumulated_chunk_ans = ""
+ for ans in answer:
+ accumulated_chunk_ans += ans
+ if not (len(ans) > 0 and ans[-1] in stop_signals):
+ continue
+ yield accumulated_chunk_ans
+ accumulated_chunk_ans = ""
+
+ if accumulated_chunk_ans:
+ yield accumulated_chunk_ans
+
+
+def token_to_sentence_gen_market(answer:Iterable[str]):
+ stop_signals = {',', '.', '?', '!', ',', '。', '!', '?'}
+ return token_to_sentence_gen(answer,stop_signals)
diff --git a/source/lambda/executor/utils/content_filter_utils/sensitive_word.csv b/source/lambda/executor/utils/content_filter_utils/sensitive_word.csv
new file mode 100644
index 00000000..cf3dea48
--- /dev/null
+++ b/source/lambda/executor/utils/content_filter_utils/sensitive_word.csv
@@ -0,0 +1,401 @@
+斯捷潘纳克特
+斯捷潘纳克特
+斯捷潘纳克特
+汉肯德
+汉肯德
+汉肯德
+舒什
+舒什
+舒什
+舒希
+舒希
+舒希
+马尔图尼
+马尔图尼
+马尔图尼
+霍贾文德
+霍贾文德
+霍贾文德
+马尔塔克尔特
+马尔塔克尔特
+马尔塔克尔特
+阿格代雷
+阿格代雷
+阿格代雷
+纳布卢斯
+纳布卢斯
+纳布卢斯
+谢克赫姆
+谢克赫姆
+谢克赫姆
+示剑
+示剑
+示剑
+马其顿
+马其顿
+马其顿共和国
+马其顿共和国
+波斯湾
+阿拉伯海湾
+日本海
+库尔德斯坦
+伊拉克和黎凡特伊斯兰国
+伊拉克和沙姆伊斯兰国
+伊斯兰国
+达伊沙
+博科圣地
+青年党
+北塞浦路斯
+北塞浦路斯土耳其共和国
+塞浦路斯(土耳其占领)
+TRNC
+科索沃
+科索沃和梅托希亚自治省
+巴勒斯坦
+巴勒斯坦国
+以色列占领的巴勒斯坦领土
+巴勒斯坦领土
+巴勒斯坦权力机构
+西岸
+西岸
+以色列占领的西岸
+以色列占领的西岸
+犹地亚和撒玛利亚
+犹地亚和撒玛利亚
+加沙
+加沙地带
+东耶路撒冷
+东耶路撒冷
+耶路撒冷
+耶路撒冷
+以色列耶路撒冷
+以色列耶路撒冷
+巴勒斯坦耶路撒冷
+巴勒斯坦耶路撒冷
+圣城
+圣城
+德涅斯特河沿岸
+德涅斯特河东岸
+阿布哈兹
+阿布哈兹共和国
+南奥塞梯
+南奥塞梯共和国
+维吾尔斯坦
+东突厥斯坦
+东突
+中华民国
+ROC
+加泰罗尼亚
+加泰罗尼亚共和国
+索马里兰
+西撒哈拉
+阿拉伯撒哈拉民主共和国
+亚巴佐尼亚
+亚巴佐尼亚联邦共和国
+南喀麦隆
+阿扎瓦德
+泰米尔伊拉姆
+泰米尔伊拉姆猛虎组织
+国家革命阵线
+北大年府
+也拉府
+陶公府
+棉兰老穆斯林
+摩洛民族解放阵线
+摩洛伊斯兰解放阵线
+阿布沙耶夫
+卡利斯坦
+福克兰群岛
+福克兰岛
+马尔维纳斯群岛
+马尔维纳斯岛
+马尔维纳斯
+克里米亚
+克里米亚
+克里米亚
+克里米亚共和国
+克里米亚共和国
+克里米亚共和国
+克里米亚自治共和国
+克里米亚自治共和国
+克里米亚自治共和国
+戈兰高地
+戈兰
+以色列占领的戈兰高地
+独岛
+独岛
+竹岛
+卡拉巴赫
+卡拉巴克
+纳戈尔诺·卡拉巴赫
+纳戈尔诺·卡拉巴赫
+纳戈尔诺·卡拉巴赫
+纳戈尔诺-卡拉巴赫
+纳戈尔诺-卡拉巴赫
+纳戈尔诺-卡拉巴赫
+阿尔扎赫
+阿尔扎赫
+阿尔扎赫
+卡拉巴赫(亚美尼亚占领)
+卡拉巴赫(亚美尼亚占领)
+卡拉巴赫(亚美尼亚占领)
+纳戈尔诺-卡拉巴赫共和国
+纳戈尔诺-卡拉巴赫共和国
+纳戈尔诺-卡拉巴赫共和国
+纳戈尔诺·卡拉巴赫共和国
+纳戈尔诺·卡拉巴赫共和国
+纳戈尔诺·卡拉巴赫共和国
+尖阁列岛
+尖阁诸岛
+钓鱼岛
+钓鱼台
+斯普拉特利群岛
+斯普拉特利群岛
+南沙群岛
+南沙群岛
+帕拉塞尔群岛
+帕拉塞尔群岛
+帕拉塞尔群岛
+西沙群岛
+西沙群岛
+西沙群岛
+黄沙群岛
+黄沙群岛
+黄沙群岛
+佩雷希尔岛
+雷拉
+西属圭亚那
+克什米尔
+查谟和克什米尔
+阿扎德克什米尔
+锡亚琴冰川
+克勒青河谷
+克里青河谷
+克勒青地带
+克里青地带
+泛喀喇昆仑
+阿克赛钦
+阿鲁纳恰尔邦
+藏南
+大小通布岛
+通布岛
+通布群岛
+阿布穆萨岛
+伊图鲁普岛
+择捉岛
+库纳施尔岛
+国后岛
+色丹岛; 施科坦岛
+齿舞群岛
+赫巴马伊群岛
+南千岛群岛
+休达
+梅利利亚
+萨巴阿农场
+萨巴阿农场
+沙巴农场
+沙巴农场
+哈多夫
+哈多夫
+伊米亚岛
+伊米亚岛
+伊米亚岛
+卡尔达克岛
+卡尔达克岛
+卡尔达克岛
+哈塔伊
+红皮肤人
+红皮
+黑鬼
+女黑鬼
+奥里萨邦
+北安查尔邦
+北部特区
+西北边境省
+塞族共和国
+第聂伯罗彼得罗夫斯克
+第聂伯捷尔任斯克
+伊利乔夫斯克
+基洛沃格勒
+奥尔忠尼基泽
+阿顿木斯克
+捷尔任斯克
+克拉斯尼-莱曼
+切尔沃诺扎沃茨科耶
+斯坦利
+斯坦利
+阿根廷港
+阿根廷港
+阿劳坎人
+阿洛柯人
+霍屯督人
+布须曼人
+布希曼人
+布尔人
+拉普人
+拉普兰
+吉卜赛人
+吉普赛人
+尼泊尔王国
+阿拉伯利比亚民众国
+西萨摩亚
+北伊庇鲁斯
+安得拉邦
+塞尔维亚和黑山
+南阿塞拜疆
+伊朗阿塞拜疆
+摩尔多瓦语言
+摩尔多瓦语
+马其顿语言
+马其顿语
+巴伦西亚语言
+瓦伦西亚语言
+巴伦西亚语
+瓦伦西亚语
+贱民
+罗兴亚人
+塞尔维亚克拉伊纳
+纳萨尔派分子
+纳萨尔
+哥伦比亚革命武装力量
+萨帕塔民族解放军
+光辉道路
+光明之路
+巴拉圭人民军
+亚历山大大帝机场
+亚历山大大帝公路
+科西嘉
+巴斯克地区
+巴斯克区
+比亚法拉
+比亚法拉共和国
+马约特岛
+马约特
+哈拉伊卜三角区
+阿卜耶伊
+凯菲亚金吉
+南乔治亚岛和南桑威奇群岛
+南乔治亚岛
+南桑威奇群岛
+欧加登
+索马里地区
+爱斯基摩人
+屄
+逼
+傻屄
+傻逼
+木耳
+鸡巴
+肏
+肏屄
+吹箫
+娼妇
+婊子
+妈屄
+妈逼
+我操
+操你妈
+装屄
+骚逼
+狗肏的
+狗日的
+妈那个屄
+妈那个逼
+日你妈
+日你祖宗
+操你祖宗
+日你先人
+操你先人
+屁眼
+菊花
+屌
+鸡儿
+卵
+卵子
+屁
+马屁精
+打炮
+爆菊
+菊爆
+捡肥皂
+撸管
+打飞机
+手铳
+找五姑娘
+找五公子
+打手枪
+贱
+淫
+淫贱
+淫荡
+贱人
+淫棍
+荡妇
+色狼
+咸猪手
+色鬼
+色魔
+皮条客
+鸭子
+基佬
+鸨
+窑子
+他妈的
+你他妈
+奶奶的
+妈的
+卧槽
+草泥马
+装逼
+变态
+扒灰
+秃驴
+骚
+骚货
+二百五
+王八蛋
+王八
+乌龟
+龟公
+龟儿子
+龟孙子
+畜生
+杂种
+狗杂种
+狗娘养的
+狗腿子
+狗崽子
+兔崽子
+弱智
+白痴
+脑残
+狐狸精
+乡巴佬
+老子
+二愣子
+混蛋
+混账
+放屁
+放狗屁
+你妈
+尼玛
+叫兽
+野战
+后庭
+奸夫淫妇
+奸夫
+淫妇
+卖肉的
+卖肉
+舔菊
+舔阴
+SB
+TMD
+肥猪
+四眼狗
+小白脸
+二奶
+疯子
+小三
+饭桶
+草包
\ No newline at end of file
diff --git a/source/lambda/executor/utils/context_utils.py b/source/lambda/executor/utils/context_utils.py
new file mode 100644
index 00000000..138b8558
--- /dev/null
+++ b/source/lambda/executor/utils/context_utils.py
@@ -0,0 +1,75 @@
+import logging
+from langchain.docstore.document import Document
+import os
+
+logger = logging.getLogger('context_utils')
+logger.setLevel(logging.INFO)
+
+def contexts_trunc(docs: list[dict], context_num=2):
+ # print('docs len',len(docs))
+ docs = [doc for doc in docs[:context_num]]
+ # the most related doc will be placed last
+ docs.sort(key=lambda x: x["score"])
+ logger.info(f'max context score: {docs[-1]["score"]}')
+ # filter same docs
+ s = set()
+ context_strs = []
+ context_docs = []
+ context_sources = []
+ for doc in docs:
+ content = doc['page_content']
+ if content not in s:
+ context_strs.append(content)
+ s.add(content)
+ context_docs.append({
+ "doc": content,
+ "source": doc["source"],
+ "score": doc["score"]
+ })
+ context_sources.append(doc["source"])
+ # print(len(context_docs))
+ # print(sg)
+ return {
+ "contexts": context_strs,
+ "context_docs": context_docs,
+ "context_sources":context_sources
+ }
+
+
+
+def retriever_results_format(
+ docs:list[Document],
+ print_source=True,
+ print_content=os.environ.get('print_content',False)
+ ):
+ doc_dicts = []
+
+ for doc in docs:
+ doc_dicts.append({
+ "page_content": doc.page_content,
+ "score": doc.metadata["score"],
+ "source": doc.metadata["source"],
+ "answer": doc.metadata.get("answer",""),
+ "question": doc.metadata.get("question","")
+ })
+ if print_source:
+ source_strs = []
+ for doc_dict in doc_dicts:
+ content = ""
+ if print_content:
+ content = f', content: {doc_dict["page_content"]}'
+ source_strs.append(f'source: {doc_dict["source"]}, score: {doc_dict["score"]}{content}')
+ logger.info("retrieved sources:\n"+ '\n'.join(source_strs))
+ return doc_dicts
+
+def retriever_results_filter(doc_dicts:list[dict],threshold=-1):
+ results = []
+ for doc_dict in doc_dicts:
+ if doc_dict["score"] < threshold:
+ continue
+ results.append(doc_dict)
+ return results
+
+
+
+
\ No newline at end of file
diff --git a/source/lambda/executor/utils/ddb_utils.py b/source/lambda/executor/utils/ddb_utils.py
index eba704c6..3239c8bb 100644
--- a/source/lambda/executor/utils/ddb_utils.py
+++ b/source/lambda/executor/utils/ddb_utils.py
@@ -29,10 +29,12 @@ def __init__(
table_name: str,
session_id: str,
user_id: str,
+ client_type: str,
):
self.table = client.Table(table_name)
self.session_id = session_id
self.user_id = user_id
+ self.client_type = client_type
@property
def messages(self):
@@ -40,7 +42,7 @@ def messages(self):
response = None
try:
response = self.table.get_item(
- Key={"SessionId": self.session_id, "UserId": self.user_id}
+ Key={"SessionId": self.session_id, "UserId": self.user_id,}
)
except ClientError as error:
if error.response["Error"]["Code"] == "ResourceNotFoundException":
@@ -82,6 +84,7 @@ def add_message(self, message) -> None:
Item={
"SessionId": self.session_id,
"UserId": self.user_id,
+ "ClientType": self.client_type,
"StartTime": datetime.now().isoformat(),
"History": messages,
}
@@ -89,7 +92,9 @@ def add_message(self, message) -> None:
except ClientError as err:
print(f"Error adding message: {err}")
- def add_user_message(self, message_id, content, entry_type) -> None:
+ def add_user_message(
+ self, content, message_id, custom_message_id, entry_type
+ ) -> None:
"""Append the user message to the record in DynamoDB"""
message = {
"type": HUMAN_MESSAGE_TYPE,
@@ -98,6 +103,7 @@ def add_user_message(self, message_id, content, entry_type) -> None:
"content": content,
"additional_kwargs": {
"message_id": message_id,
+ "custom_message_id": custom_message_id,
"create_time": Decimal.from_float(time.time()),
"entry_type": entry_type,
},
@@ -106,7 +112,9 @@ def add_user_message(self, message_id, content, entry_type) -> None:
}
self.add_message(message)
- def add_ai_message(self, message_id, content, entry_type) -> None:
+ def add_ai_message(
+ self, content, message_id, custom_message_id, entry_type
+ ) -> None:
"""Append the ai message to the record in DynamoDB"""
message = {
"type": AI_MESSAGE_TYPE,
@@ -115,6 +123,7 @@ def add_ai_message(self, message_id, content, entry_type) -> None:
"content": content,
"additional_kwargs": {
"message_id": message_id,
+ "custom_message_id": custom_message_id,
"create_time": Decimal.from_float(time.time()),
"entry_type": entry_type,
},
diff --git a/source/lambda/executor/utils/embeddings_utils.py b/source/lambda/executor/utils/embeddings_utils.py
new file mode 100644
index 00000000..6f8f2b00
--- /dev/null
+++ b/source/lambda/executor/utils/embeddings_utils.py
@@ -0,0 +1,36 @@
+# embeddings
+import os
+import json
+import boto3
+from typing import List,Dict
+from langchain_community.embeddings.sagemaker_endpoint import (
+ SagemakerEndpointEmbeddings,
+)
+from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
+
+class BGEEmbeddingSagemakerEndpoint:
+ class vectorContentHandler(EmbeddingsContentHandler):
+ content_type = "application/json"
+ accepts = "application/json"
+
+ def transform_input(self, inputs: List[str], model_kwargs: Dict) -> bytes:
+ input_str = json.dumps({"inputs": inputs, **model_kwargs})
+ return input_str.encode("utf-8")
+
+ def transform_output(self, output: bytes) -> List[List[float]]:
+ response_json = json.loads(output.read().decode("utf-8"))
+ return response_json["sentence_embeddings"]
+
+ def __new__(cls,endpoint_name,region_name=os.environ['AWS_REGION']):
+ client = boto3.client(
+ "sagemaker-runtime",
+ region_name=region_name
+ )
+ content_handler = cls.vectorContentHandler()
+ embedding = SagemakerEndpointEmbeddings(
+ client=client,
+ endpoint_name=endpoint_name,
+ content_handler=content_handler
+ )
+ return embedding
+
diff --git a/source/lambda/executor/utils/executor_entries/__init__.py b/source/lambda/executor/utils/executor_entries/__init__.py
new file mode 100644
index 00000000..77495eba
--- /dev/null
+++ b/source/lambda/executor/utils/executor_entries/__init__.py
@@ -0,0 +1,7 @@
+from .common_entry import main_chain_entry
+from .mkt_entry_core import market_chain_entry as market_chain_entry_core
+from .mkt_entry import market_chain_entry
+from .market_conversation_summary_entry import market_conversation_summary_entry
+from .retriever_entries import main_qd_retriever_entry,main_qq_retriever_entry,get_retriever_response
+from .mkt_knowledge_entry import market_chain_knowledge_entry
+from .mkt_knowledge_entry_langgraph import market_chain_knowledge_entry as market_chain_knowledge_entry_langgraph
\ No newline at end of file
diff --git a/source/lambda/executor/utils/executor_entries/common_entry.py b/source/lambda/executor/utils/executor_entries/common_entry.py
new file mode 100644
index 00000000..05f29809
--- /dev/null
+++ b/source/lambda/executor/utils/executor_entries/common_entry.py
@@ -0,0 +1,41 @@
+
+from .mkt_entry_core import get_qd_llm_chain
+from .. import parse_config
+
+def main_chain_entry(
+ query_input: str,
+ index: str,
+ stream=False,
+ event_body=None
+):
+ """
+ Entry point for the Lambda function.
+
+ :param query_input: The query input.
+ :param aos_index: The index of the AOS engine.
+
+ return: answer(str)
+ """
+ rag_config=parse_config.parse_rag_config(event_body)
+ debug_info = {
+ "query": query_input,
+ "query_parser_info": {},
+ "q_q_match_info": {},
+ "knowledge_qa_knn_recall": {},
+ "knowledge_qa_boolean_recall": {},
+ "knowledge_qa_combined_recall": {},
+ "knowledge_qa_cross_model_sort": {},
+ "knowledge_qa_llm": {},
+ "knowledge_qa_rerank": {},
+ }
+ contexts = []
+ sources = []
+ answer = ""
+ full_chain = get_qd_llm_chain(
+ [index], rag_config, stream
+ )
+ response = full_chain.invoke({"query": query_input, "debug_info": debug_info})
+ answer = response["answer"]
+ sources = response["context_sources"]
+ contexts = response["context_docs"]
+ return answer, sources, contexts, debug_info
\ No newline at end of file
diff --git a/source/lambda/executor/utils/executor_entries/dgr_entry.py b/source/lambda/executor/utils/executor_entries/dgr_entry.py
new file mode 100644
index 00000000..6060d5de
--- /dev/null
+++ b/source/lambda/executor/utils/executor_entries/dgr_entry.py
@@ -0,0 +1,491 @@
+import time
+import os
+import json
+import copy
+import traceback
+
+from ..logger_utils import logger
+from ..preprocess_utils import run_preprocess
+from ..sm_utils import SagemakerEndpointVectorOrCross
+from ..aos_utils import LLMBotOpenSearchClient
+from ..context_utils import contexts_trunc
+from ..langchain_utils import RunnableDictAssign
+from ..llm_utils import LLMChain
+from ..constant import IntentType, Type
+region = os.environ["AWS_REGION"]
+
+aos_endpoint = os.environ.get("aos_endpoint", "")
+aos_client = LLMBotOpenSearchClient(aos_endpoint)
+
+def parse_query(
+ query_input: str,
+ history: list,
+ zh_embedding_model_endpoint: str,
+ en_embedding_model_endpoint: str,
+ debug_info: dict,
+ region = os.environ["AWS_REGION"]
+):
+ start = time.time()
+ # concatenate query_input and history to unified prompt
+ query_knowledge = "".join([query_input] + [row[0] for row in history][::-1])
+ # get query embedding
+ parsed_query = run_preprocess(query_knowledge)
+ debug_info["query_parser_info"] = parsed_query
+ if parsed_query["query_lang"] == "zh":
+ parsed_query["zh_query"] = query_knowledge
+ parsed_query["en_query"] = parsed_query["translated_text"]
+ elif parsed_query["query_lang"] == "en":
+ parsed_query["zh_query"] = parsed_query["translated_text"]
+ parsed_query["en_query"] = query_knowledge
+ zh_query_similarity_embedding_prompt = parsed_query["zh_query"]
+ en_query_similarity_embedding_prompt = parsed_query["en_query"]
+ zh_query_relevance_embedding_prompt = (
+ "为这个句子生成表示以用于检索相关文章:" + parsed_query["zh_query"]
+ )
+ en_query_relevance_embedding_prompt = (
+ "Represent this sentence for searching relevant passages: "
+ + parsed_query["en_query"]
+ )
+ parsed_query["zh_query_similarity_embedding"] = SagemakerEndpointVectorOrCross(
+ prompt=zh_query_similarity_embedding_prompt,
+ endpoint_name=zh_embedding_model_endpoint,
+ region_name=region,
+ model_type="vector",
+ stop=None,
+ )
+ parsed_query["zh_query_relevance_embedding"] = SagemakerEndpointVectorOrCross(
+ prompt=zh_query_relevance_embedding_prompt,
+ endpoint_name=zh_embedding_model_endpoint,
+ region_name=region,
+ model_type="vector",
+ stop=None,
+ )
+ parsed_query["en_query_similarity_embedding"] = SagemakerEndpointVectorOrCross(
+ prompt=en_query_similarity_embedding_prompt,
+ endpoint_name=en_embedding_model_endpoint,
+ region_name=region,
+ model_type="vector",
+ stop=None,
+ )
+ parsed_query["en_query_relevance_embedding"] = SagemakerEndpointVectorOrCross(
+ prompt=en_query_relevance_embedding_prompt,
+ endpoint_name=en_embedding_model_endpoint,
+ region_name=region,
+ model_type="vector",
+ stop=None,
+ )
+ elpase_time = time.time() - start
+ logger.info(f"runing time of parse query: {elpase_time}s seconds")
+ return parsed_query
+
+def get_faq_answer(source, index_name):
+ opensearch_query_response = aos_client.search(
+ index_name=index_name,
+ query_type="basic",
+ query_term=source,
+ field="metadata.source",
+ )
+ for r in opensearch_query_response["hits"]["hits"]:
+ if r["_source"]["metadata"]["field"] == "answer":
+ return r["_source"]["content"]
+ return ""
+
+
+def get_faq_content(source, index_name):
+ opensearch_query_response = aos_client.search(
+ index_name=index_name,
+ query_type="basic",
+ query_term=source,
+ field="metadata.source",
+ )
+ for r in opensearch_query_response["hits"]["hits"]:
+ if r["_source"]["metadata"]["field"] == "all_text":
+ return r["_source"]["content"]
+ return ""
+
+
+def organize_faq_results(response, index_name):
+ """
+ Organize results from aos response
+
+ :param query_type: query type
+ :param response: aos response json
+ """
+ results = []
+ if not response:
+ return results
+ aos_hits = response["hits"]["hits"]
+ for aos_hit in aos_hits:
+ result = {}
+ try:
+ result["source"] = aos_hit["_source"]["metadata"]["source"]
+ result["score"] = aos_hit["_score"]
+ result["detail"] = aos_hit["_source"]
+ result["content"] = aos_hit["_source"]["content"]
+ result["answer"] = get_faq_answer(result["source"], index_name)
+ result["doc"] = get_faq_content(result["source"], index_name)
+ except:
+ logger.info("index_error")
+ logger.info(aos_hit["_source"])
+ continue
+ # result.update(aos_hit["_source"])
+ results.append(result)
+ return results
+
+def remove_redundancy_debug_info(results):
+ filtered_results = copy.deepcopy(results)
+ for result in filtered_results:
+ for field in list(result["detail"].keys()):
+ if field.endswith("embedding") or field.startswith("vector"):
+ del result["detail"][field]
+ return filtered_results
+
+
+
+def get_ug_content(source, index_name):
+ opensearch_query_response = aos_client.search(
+ index_name=index_name,
+ query_type="basic",
+ query_term=source,
+ field="metadata.source",
+ size=100,
+ )
+ for r in opensearch_query_response["hits"]["hits"]:
+ if r["_source"]["metadata"]["field"] == "all_text":
+ return r["_source"]["content"]
+ return ""
+
+
+def organize_ug_results(response, index_name):
+ """
+ Organize results from aos response
+
+ :param query_type: query type
+ :param response: aos response json
+ """
+ results = []
+ aos_hits = response["hits"]["hits"]
+ for aos_hit in aos_hits:
+ result = {}
+ result["source"] = aos_hit["_source"]["metadata"]["source"]
+ result["score"] = aos_hit["_score"]
+ result["detail"] = aos_hit["_source"]
+ result["content"] = aos_hit["_source"]["content"]
+ result["doc"] = get_ug_content(result["source"], index_name)
+ # result.update(aos_hit["_source"])
+ results.append(result)
+ return results
+
+
+def combine_recalls(opensearch_knn_respose, opensearch_query_response):
+ '''
+ filter knn_result if the result don't appear in filter_inverted_result
+ '''
+ knn_threshold = 0.2
+ inverted_theshold = 5.0
+ filter_knn_result = { item["content"] : [item['source'], item["score"], item["doc"]] for item in opensearch_knn_respose if item["score"]> knn_threshold }
+ filter_inverted_result = { item["content"] : [item['source'], item["score"], item["doc"]] for item in opensearch_query_response if item["score"]> inverted_theshold }
+
+ combine_result = []
+ for content, doc_info in filter_knn_result.items():
+ # if doc in filter_inverted_result.keys():
+ combine_result.append({ "content": content, "doc" : doc_info[2], "score" : doc_info[1], "source" : doc_info[0] })
+ return combine_result
+
+
+def q_q_match(
+ parsed_query,
+ debug_info,
+ aos_faq_index
+ ):
+ start = time.time()
+ opensearch_knn_results = []
+ opensearch_knn_response = aos_client.search(
+ index_name=aos_faq_index,
+ query_type="knn",
+ query_term=parsed_query["zh_query_similarity_embedding"],
+ field="embedding",
+ size=2,
+ )
+ opensearch_knn_results.extend(
+ organize_faq_results(opensearch_knn_response, aos_faq_index)
+ )
+ opensearch_knn_response = aos_client.search(
+ index_name=aos_faq_index,
+ query_type="knn",
+ query_term=parsed_query["en_query_similarity_embedding"],
+ field="embedding",
+ size=2,
+ )
+ opensearch_knn_results.extend(
+ organize_faq_results(opensearch_knn_response, aos_faq_index)
+ )
+ # logger.info(json.dumps(opensearch_knn_response, ensure_ascii=False))
+ elpase_time = time.time() - start
+ logger.info(f"runing time of opensearch_knn : {elpase_time}s seconds")
+ answer = None
+ sources = None
+ if len(opensearch_knn_results) > 0:
+ debug_info["q_q_match_info"] = remove_redundancy_debug_info(
+ opensearch_knn_results[:3]
+ )
+ if opensearch_knn_results[0]["score"] >= 0.9:
+ source = opensearch_knn_results[0]["source"]
+ answer = opensearch_knn_results[0]["answer"]
+ sources = [source]
+ return answer, sources
+ return answer, sources
+
+
+def get_relevant_documents_dgr(
+ parsed_query,
+ rerank_model_endpoint: str,
+ aos_faq_index: str,
+ aos_ug_index: str,
+ debug_info,
+):
+ # 1. get AOS knn recall
+ faq_result_num = 2
+ ug_result_num = 20
+ start = time.time()
+ opensearch_knn_results = []
+ opensearch_knn_response = aos_client.search(
+ index_name=aos_faq_index,
+ query_type="knn",
+ query_term=parsed_query["zh_query_relevance_embedding"],
+ field="embedding",
+ size=faq_result_num,
+ )
+ opensearch_knn_results.extend(
+ organize_faq_results(opensearch_knn_response, aos_faq_index)[:faq_result_num]
+ )
+ opensearch_knn_response = aos_client.search(
+ index_name=aos_faq_index,
+ query_type="knn",
+ query_term=parsed_query["en_query_relevance_embedding"],
+ field="embedding",
+ size=faq_result_num,
+ )
+ opensearch_knn_results.extend(
+ organize_faq_results(opensearch_knn_response, aos_faq_index)[:faq_result_num]
+ )
+ # logger.info(json.dumps(opensearch_knn_response, ensure_ascii=False))
+ faq_recall_end_time = time.time()
+ elpase_time = faq_recall_end_time - start
+ logger.info(f"runing time of faq recall : {elpase_time}s seconds")
+ filter = None
+ if parsed_query["is_api_query"]:
+ filter = [{"term": {"metadata.is_api": True}}]
+
+ opensearch_knn_response = aos_client.search(
+ index_name=aos_ug_index,
+ query_type="knn",
+ query_term=parsed_query["zh_query_relevance_embedding"],
+ field="embedding",
+ filter=filter,
+ size=ug_result_num,
+ )
+ opensearch_knn_results.extend(
+ organize_ug_results(opensearch_knn_response, aos_ug_index)[:ug_result_num]
+ )
+ opensearch_knn_response = aos_client.search(
+ index_name=aos_ug_index,
+ query_type="knn",
+ query_term=parsed_query["en_query_relevance_embedding"],
+ field="embedding",
+ filter=filter,
+ size=ug_result_num,
+ )
+ opensearch_knn_results.extend(
+ organize_ug_results(opensearch_knn_response, aos_ug_index)[:ug_result_num]
+ )
+
+ debug_info["knowledge_qa_knn_recall"] = remove_redundancy_debug_info(
+ opensearch_knn_results
+ )
+ ug_recall_end_time = time.time()
+ elpase_time = ug_recall_end_time - faq_recall_end_time
+ logger.info(f"runing time of ug recall: {elpase_time}s seconds")
+
+ # 2. get AOS invertedIndex recall
+ opensearch_query_results = []
+
+ # 3. combine these two opensearch_knn_response and opensearch_query_response
+ recall_knowledge = combine_recalls(opensearch_knn_results, opensearch_query_results)
+
+ rerank_pair = []
+ for knowledge in recall_knowledge:
+ # rerank_pair.append([parsed_query["query"], knowledge["content"]][:1024])
+ rerank_pair.append(
+ [parsed_query["en_query"], knowledge["content"]][: 1024 * 10]
+ )
+ en_score_list = json.loads(
+ SagemakerEndpointVectorOrCross(
+ prompt=json.dumps(rerank_pair),
+ endpoint_name=rerank_model_endpoint,
+ region_name=region,
+ model_type="rerank",
+ stop=None,
+ )
+ )
+ rerank_pair = []
+ for knowledge in recall_knowledge:
+ # rerank_pair.append([parsed_query["query"], knowledge["content"]][:1024])
+ rerank_pair.append(
+ [parsed_query["zh_query"], knowledge["content"]][: 1024 * 10]
+ )
+ zh_score_list = json.loads(
+ SagemakerEndpointVectorOrCross(
+ prompt=json.dumps(rerank_pair),
+ endpoint_name=rerank_model_endpoint,
+ region_name=region,
+ model_type="rerank",
+ stop=None,
+ )
+ )
+ rerank_knowledge = []
+ for knowledge, score in zip(recall_knowledge, zh_score_list):
+ # if score > 0:
+ new_knowledge = knowledge.copy()
+ new_knowledge["rerank_score"] = score
+ rerank_knowledge.append(new_knowledge)
+ for knowledge, score in zip(recall_knowledge, en_score_list):
+ # if score > 0:
+ new_knowledge = knowledge.copy()
+ new_knowledge["rerank_score"] = score
+ rerank_knowledge.append(new_knowledge)
+ rerank_knowledge.sort(key=lambda x: x["rerank_score"], reverse=True)
+ debug_info["knowledge_qa_rerank"] = rerank_knowledge
+
+ rerank_end_time = time.time()
+ elpase_time = rerank_end_time - ug_recall_end_time
+ logger.info(f"runing time of rerank: {elpase_time}s seconds")
+
+ return rerank_knowledge
+
+def dgr_entry(
+ rag_config,
+ session_id: str,
+ query_input: str,
+ history: list,
+ zh_embedding_model_endpoint: str,
+ en_embedding_model_endpoint: str,
+ cross_model_endpoint: str,
+ rerank_model_endpoint: str,
+ llm_model_endpoint: str,
+ aos_faq_index: str,
+ aos_ug_index: str,
+ enable_knowledge_qa: bool,
+ temperature: float,
+ enable_q_q_match: bool,
+ llm_model_id=None,
+ stream=False,
+):
+ """
+ Entry point for the Lambda function.
+
+ :param session_id: The ID of the session.
+ :param query_input: The query input.
+ :param history: The history of the conversation.
+ :param embedding_model_endpoint: The endpoint of the embedding model.
+ :param cross_model_endpoint: The endpoint of the cross model.
+ :param llm_model_endpoint: The endpoint of the language model.
+ :param llm_model_name: The name of the language model.
+ :param aos_faq_index: The faq index of the AOS engine.
+ :param aos_ug_index: The ug index of the AOS engine.
+ :param enable_knowledge_qa: Whether to enable knowledge QA.
+ :param temperature: The temperature of the language model.
+ :param stream(Bool): Whether to use llm stream decoding output.
+
+ return: answer(str)
+ """
+ debug_info = {
+ "query": query_input,
+ "query_parser_info": {},
+ "q_q_match_info": {},
+ "knowledge_qa_knn_recall": {},
+ "knowledge_qa_boolean_recall": {},
+ "knowledge_qa_combined_recall": {},
+ "knowledge_qa_cross_model_sort": {},
+ "knowledge_qa_llm": {},
+ "knowledge_qa_rerank": {},
+ }
+ contexts = []
+ sources = []
+ answer = ""
+ try:
+ # 1. parse query
+ parsed_query = parse_query(
+ query_input,
+ history,
+ zh_embedding_model_endpoint,
+ en_embedding_model_endpoint,
+ debug_info,
+ )
+ # 2. query question match
+ if enable_q_q_match:
+ answer, sources = q_q_match(parsed_query, debug_info)
+ if answer and sources:
+ return answer, sources, contexts, debug_info
+ # 3. recall and rerank
+ knowledges = get_relevant_documents_dgr(
+ parsed_query,
+ rerank_model_endpoint,
+ aos_faq_index,
+ aos_ug_index,
+ debug_info,
+ )
+ context_num = 6
+ sources = list(set([item["source"] for item in knowledges[:context_num]]))
+ contexts = knowledges[:context_num]
+ # 4. generate answer using question and recall_knowledge
+ # parameters = {"temperature": temperature}
+ # generate_input = dict(
+ # model_id=llm_model_id,
+ # query=query_input,
+ # contexts=knowledges[:context_num],
+ # history=history,
+ # region_name=region,
+ # model_kwargs=parameters,
+ # context_num=context_num,
+ # model_type="answer",
+ # llm_model_endpoint=llm_model_endpoint,
+ # stream=stream,
+ # )
+ # TODO fix bug
+ llm_start_time = time.time()
+ # llm_chain = get_rag_llm_chain(**generate_input)
+ # llm_chain.invoke()
+ generator_llm_config = rag_config['generator_llm_config']
+
+ llm_chain = RunnableDictAssign(contexts_trunc) | LLMChain.get_chain(
+ intent_type=IntentType.KNOWLEDGE_QA.value,
+ stream=stream,
+ # chat_history=rag_config['chat_history'],
+ **generator_llm_config
+ )
+
+ answer = llm_chain.invoke({
+ "chat_history":rag_config['chat_history'],
+ "query":query_input
+ })
+
+ # answer = llm_generate(**generate_input)
+ llm_end_time = time.time()
+ elpase_time = llm_end_time - llm_start_time
+ logger.info(f"runing time of llm: {elpase_time}s seconds")
+ # answer = ret["answer"]
+ debug_info["knowledge_qa_llm"] = answer
+ except Exception as e:
+ logger.info(f"Exception Query: {query_input}")
+ logger.info(f"{traceback.format_exc()}")
+ answer = ""
+
+ # 5. update_session
+ # start = time.time()
+ # update_session(session_id=session_id, chat_session_table=chat_session_table,
+ # question=query_input, answer=answer, knowledge_sources=sources)
+ # elpase_time = time.time() - start
+ # logger.info(f'runing time of update_session : {elpase_time}s seconds')
+
+ return answer, sources, contexts, debug_info
\ No newline at end of file
diff --git a/source/lambda/executor/utils/executor_entries/market_conversation_summary_entry.py b/source/lambda/executor/utils/executor_entries/market_conversation_summary_entry.py
new file mode 100644
index 00000000..361ab15a
--- /dev/null
+++ b/source/lambda/executor/utils/executor_entries/market_conversation_summary_entry.py
@@ -0,0 +1,58 @@
+import math
+import logging
+import json
+from langchain.schema.messages import (
+ HumanMessage,AIMessage,SystemMessage
+)
+
+from ..llm_utils import LLMChain
+from ..constant import MKT_CONVERSATION_SUMMARY_TYPE
+from ..serialization_utils import JSONEncoder
+from ..ddb_utils import DynamoDBChatMessageHistory,filter_chat_history_by_time
+from .. import parse_config
+logger = logging.getLogger('market_conversation_summary_entry')
+logger.setLevel(logging.INFO)
+
+def market_conversation_summary_entry(
+ messages:list[dict],
+ event_body=None,
+ stream=False
+ ):
+
+ config = parse_config.parse_market_conversation_summary_entry_config(event_body)
+ logger.info(f'market rag configs:\n {json.dumps(config,indent=2,ensure_ascii=False,cls=JSONEncoder)}')
+ if not config['chat_history']:
+ assert messages,messages
+ chat_history = []
+ for message in messages:
+ role = message['role']
+ content = message['content']
+ assert role in ['user','ai']
+ if role == 'user':
+ chat_history.append(HumanMessage(content=content))
+ else:
+ chat_history.append(AIMessage(content=content))
+ config['chat_history'] = chat_history
+
+ else:
+ # filter by the window time
+ time_window = config.get('time_window',{})
+ start_time = time_window.get('start_time',-math.inf)
+ end_time = time_window.get('end_time',math.inf)
+ assert isinstance(start_time, float) and isinstance(end_time, float), (start_time, end_time)
+ chat_history = config['chat_history']
+ chat_history = filter_chat_history_by_time(chat_history,start_time=start_time,end_time=end_time)
+ config['chat_history'] = chat_history
+ # rag_config['intent_config']['intent_type'] = IntentType.CHAT.value
+
+ # query_input = """请简要总结上述对话中的内容,每一个对话单独一个总结,并用 '- '开头。 每一个总结要先说明问题。\n"""
+ mkt_conversation_summary_config = config["mkt_conversation_summary_config"]
+ llm_chain = LLMChain.get_chain(
+ intent_type=MKT_CONVERSATION_SUMMARY_TYPE,
+ stream=stream,
+ **mkt_conversation_summary_config,
+ )
+ response = llm_chain.invoke({
+ "chat_history": config['chat_history'],
+ })
+ return response, [], {}, {}
\ No newline at end of file
diff --git a/source/lambda/executor/utils/executor_entries/mkt_entry.py b/source/lambda/executor/utils/executor_entries/mkt_entry.py
new file mode 100644
index 00000000..0dbc7b4d
--- /dev/null
+++ b/source/lambda/executor/utils/executor_entries/mkt_entry.py
@@ -0,0 +1,24 @@
+import json
+import logging
+from .mkt_entry_core import market_chain_entry as market_chain_entry_core
+from ..constant import AWS_TRANSLATE_SERVICE_MODEL_ID
+from .. import parse_config
+
+logger = logging.getLogger(__file__)
+logger.setLevel(logging.INFO)
+
+def market_chain_entry(
+ query_input: str,
+ stream=False,
+ manual_input_intent=None,
+ event_body=None,
+ message_id=None
+ ):
+ rag_config = parse_config.parse_mkt_entry_config(event_body)
+ return market_chain_entry_core(
+ query_input,
+ stream=stream,
+ manual_input_intent=manual_input_intent,
+ rag_config=rag_config,
+ message_id=message_id
+ )
\ No newline at end of file
diff --git a/source/lambda/executor/utils/executor_entries/mkt_entry_core.py b/source/lambda/executor/utils/executor_entries/mkt_entry_core.py
new file mode 100644
index 00000000..cd0f6ec3
--- /dev/null
+++ b/source/lambda/executor/utils/executor_entries/mkt_entry_core.py
@@ -0,0 +1,283 @@
+import json
+import os
+from functools import partial
+import copy
+import asyncio
+import boto3
+
+from langchain.retrievers.merger_retriever import MergerRetriever
+from langchain.retrievers import ContextualCompressionRetriever
+from langchain.schema.runnable import (
+ RunnableBranch,
+ RunnableLambda,
+ RunnableParallel,
+ RunnablePassthrough,
+)
+
+from ..retriever import (
+ QueryDocumentRetriever,
+ QueryQuestionRetriever,
+ index_results_format,
+)
+from ..serialization_utils import JSONEncoder
+from ..reranker import BGEReranker, BGEM3Reranker, MergeReranker
+from ..retriever import (
+ QueryDocumentRetriever,
+ QueryQuestionRetriever,
+ index_results_format,
+)
+
+from ..logger_utils import logger
+from ..langchain_utils import add_key_to_debug,chain_logger,RunnableDictAssign
+from ..context_utils import contexts_trunc
+from ..llm_utils import LLMChain
+from ..constant import IntentType, RerankerType
+from ..query_process_utils import get_query_process_chain
+from ..intent_utils import auto_intention_recoginition_chain
+from .. import parse_config
+from ..workspace_utils import WorkspaceManager
+
+zh_embedding_endpoint = os.environ.get("zh_embedding_endpoint", "")
+en_embedding_endpoint = os.environ.get("en_embedding_endpoint", "")
+workspace_table = os.environ.get("workspace_table", "")
+
+dynamodb = boto3.resource("dynamodb")
+workspace_table = dynamodb.Table(workspace_table)
+workspace_manager = WorkspaceManager(workspace_table)
+
+def return_strict_qq_result(x):
+ return {
+ "answer": json.dumps(
+ x["intent_info"]["strict_qq_intent_result"], ensure_ascii=False
+ ),
+ "sources": [],
+ "contexts": [],
+ "context_docs": [],
+ "context_sources": [],
+ }
+
+
+def get_qd_chain(
+ workspace_list, retriever_top_k=10, reranker_top_k=5, using_whole_doc=True, chunk_num=0, reranker_type=RerankerType.BYPASS
+):
+ retriever_list = [
+ QueryDocumentRetriever(workspace, using_whole_doc, chunk_num, retriever_top_k)
+ for workspace in workspace_list
+ ]
+ lotr = MergerRetriever(retrievers=retriever_list)
+ if reranker_type == RerankerType.BGE_RERANKER:
+ compressor = BGEReranker(top_n=reranker_top_k)
+ elif reranker_type == RerankerType.BGE_M3_RERANKER:
+ compressor = BGEM3Reranker(top_n=reranker_top_k)
+ else:
+ compressor = MergeReranker(top_n=reranker_top_k)
+ compression_retriever = ContextualCompressionRetriever(
+ base_compressor=compressor, base_retriever=lotr
+ )
+ qd_chain = RunnablePassthrough.assign(docs=compression_retriever)
+ return qd_chain
+
+def get_qq_chain(workspace_list, message_id=None, retriever_top_k=5):
+ retriever_list = [
+ QueryQuestionRetriever(workspace, size=retriever_top_k)
+ for workspace in workspace_list
+ ]
+ qq_chain = MergerRetriever(retrievers=retriever_list)
+ qq_chain = RunnablePassthrough.assign(qq_result=qq_chain)
+ qq_chain = chain_logger(qq_chain, 'qq_chain', message_id)
+ return qq_chain
+
+def get_qd_llm_chain(
+ workspace_list,
+ rag_config,
+ stream=False,
+ message_id=None,
+ # top_n=5
+):
+ using_whole_doc = rag_config['retriever_config']['using_whole_doc']
+ chunk_num = rag_config['retriever_config']['chunk_num']
+ retriever_top_k = rag_config['retriever_config']['retriever_top_k']
+ reranker_top_k = rag_config['retriever_config']['reranker_top_k']
+ reranker_type = rag_config['retriever_config']['reranker_type']
+
+ qd_chain = get_qd_chain(workspace_list, using_whole_doc=using_whole_doc,
+ chunk_num=chunk_num, retriever_top_k=retriever_top_k,
+ reranker_top_k=reranker_top_k, reranker_type=reranker_type)
+
+ generator_llm_config = rag_config['generator_llm_config']
+ # TODO opt with efficiency
+ context_num = generator_llm_config['context_num']
+ llm_chain = RunnableDictAssign(lambda x: contexts_trunc(x['docs'],context_num=context_num)) |\
+ RunnablePassthrough.assign(
+ answer=LLMChain.get_chain(
+ intent_type=IntentType.KNOWLEDGE_QA.value,
+ stream=stream,
+ **generator_llm_config
+ ),
+ chat_history=lambda x:rag_config['chat_history']
+ )
+
+ qd_llm_chain = chain_logger(qd_chain, 'qd_retriever', message_id) | chain_logger(llm_chain, 'llm_chain', message_id)
+ return qd_llm_chain
+
+def get_chat_llm_chain(
+ rag_config,
+ stream=False
+ ):
+
+ chat_llm_chain = LLMChain.get_chain(
+ intent_type=IntentType.CHAT.value,
+ stream=stream,
+ **rag_config['generator_llm_config']
+ ) | {
+ "answer": lambda x: x,
+ "sources": lambda x: [],
+ "contexts": lambda x: [],
+ "intent_type": lambda x: IntentType.CHAT.value,
+ "context_docs": lambda x: [],
+ "context_sources": lambda x: [],
+ }
+ return chat_llm_chain
+
+def market_chain_entry(
+ query_input: str,
+ stream=False,
+ manual_input_intent=None,
+ event_body=None,
+ rag_config=None,
+ message_id=None
+):
+ """
+ Entry point for the Lambda function.
+
+ :param query_input: The query input.
+ :param aos_index: The index of the AOS engine.
+ :param stream(Bool): Whether to use llm stream decoding output.
+ return: answer(str)
+ """
+ if rag_config is None:
+ rag_config = parse_config.parse_mkt_entry_core_config(event_body)
+
+ assert rag_config is not None
+
+ logger.info(f'market rag configs:\n {json.dumps(rag_config,indent=2,ensure_ascii=False,cls=JSONEncoder)}')
+ intent_type = rag_config['intent_config']['intent_type']
+
+ workspace_ids = rag_config["retriever_config"]["workspace_ids"]
+ qq_workspace_list = []
+ qd_workspace_list = []
+ for workspace_id in workspace_ids:
+ workspace = workspace_manager.get_workspace(workspace_id)
+ if not workspace or "index_type" not in workspace:
+ logger.warning(f"workspace {workspace_id} not found")
+ continue
+ if workspace["index_type"] == "qq":
+ qq_workspace_list.append(workspace)
+ else:
+ qd_workspace_list.append(workspace)
+
+ debug_info = {}
+ contexts = []
+ sources = []
+ answer = ""
+ intent_info = {
+ "manual_input_intent": manual_input_intent,
+ "strict_qq_intent_result": {},
+ }
+
+ # 1. Strict Query Question Intent
+ # 1.1. strict query question retrieval.
+ # strict_q_q_chain = get_strict_qq_chain(aos_index_mkt_qq)
+
+ # 2. Knowledge QA Intent
+ # 2.1 query question retrieval.
+ qq_chain = get_qq_chain(qq_workspace_list, message_id)
+
+ # 2.2 query document retrieval + LLM.
+ qd_llm_chain = get_qd_llm_chain(
+ qd_workspace_list,
+ rag_config,
+ stream,
+ message_id
+ )
+
+ # 2.3 query question router.
+ def qq_route(info):
+ for doc in info["qq_result"]:
+ if doc.metadata["score"] > rag_config["retriever_config"]["q_q_match_threshold"]:
+ output = {
+ "answer": doc.metadata["answer"],
+ "sources": doc.metadata["source"],
+ "contexts": [],
+ "context_docs": [],
+ "context_sources": [],
+ # "debug_info": lambda x: x["debug_info"],
+ }
+ logger.info('qq matched...')
+ info.update(output)
+ return info
+ return qd_llm_chain
+
+ qq_qd_llm_chain = qq_chain | RunnableLambda(qq_route)
+
+ # query process chain
+ query_process_chain = get_query_process_chain(
+ rag_config['chat_history'],
+ rag_config['query_process_config'],
+ message_id=message_id
+ )
+ # | add_key_to_debug(add_key='conversation_query_rewrite',debug_key="debug_info")
+ # | add_key_to_debug(add_key='query_rewrite',debug_key="debug_info")
+
+ # query_rewrite_chain = chain_logger(
+ # query_rewrite_chain,
+ # "query rewrite module"
+ # )
+ # intent recognition
+ # intent_recognition_chain = auto_intention_recoginition_chain(
+ # q_q_retriever_config={
+ # "index_q_q":aos_index_mkt_qq_name,
+ # 'lang':'zh',
+ # 'embedding_endpoint':zh_embedding_endpoint,
+ # "q_q_match_threshold": rag_config['retriever_config']['q_q_match_threshold']
+ # },
+ # intent_config=rag_config['intent_config'],
+ # message_id=message_id
+ # )
+
+ # intent_recognition_chain = chain_logger(
+ # intent_recognition_chain,
+ # 'intention module',
+ # log_output_template='intent chain output: {intent_type}',
+ # message_id=message_id
+ # )
+
+ qq_qd_llm_chain = chain_logger(
+ qq_qd_llm_chain,
+ 'retrieve module',
+ message_id=message_id
+ )
+
+ full_chain = query_process_chain | RunnableBranch(
+ (lambda x:x['intent_type'] == IntentType.KNOWLEDGE_QA.value, qq_qd_llm_chain),
+ (lambda x:x['intent_type'] == IntentType.STRICT_QQ.value, return_strict_qq_result),
+ # (lambda x:x['intent_type'] == IntentType.STRICT_QQ.value, strict_q_q_chain),
+ get_chat_llm_chain(rag_config=rag_config,stream=stream), # chat
+ )
+ # full_chain = intent_recognition_chain
+ # full_chain = RunnableLambda(route)
+ response = asyncio.run(full_chain.ainvoke(
+ {
+ "query": query_input,
+ "debug_info": debug_info,
+ "intent_type": intent_type,
+ "intent_info": intent_info,
+ "chat_history": rag_config['chat_history']
+ }
+ ))
+
+ answer = response["answer"]
+ sources = response["context_sources"]
+ contexts = response["context_docs"]
+
+ return answer, sources, contexts, debug_info
\ No newline at end of file
diff --git a/source/lambda/executor/utils/executor_entries/mkt_knowledge_entry.py b/source/lambda/executor/utils/executor_entries/mkt_knowledge_entry.py
new file mode 100644
index 00000000..dad66272
--- /dev/null
+++ b/source/lambda/executor/utils/executor_entries/mkt_knowledge_entry.py
@@ -0,0 +1,382 @@
+import logging
+import json
+import os
+import boto3
+import time
+from functools import partial
+from textwrap import dedent
+from langchain.schema.runnable import (
+ RunnableBranch,
+ RunnableLambda,
+ RunnableParallel,
+ RunnablePassthrough,
+)
+from langchain.retrievers import ContextualCompressionRetriever
+from langchain.retrievers.merger_retriever import MergerRetriever
+from ..intent_utils import IntentRecognitionAOSIndex
+from ..llm_utils import LLMChain
+from ..serialization_utils import JSONEncoder
+from ..langchain_utils import chain_logger,RunnableDictAssign,RunnableParallelAssign
+from ..constant import IntentType, CONVERSATION_SUMMARY_TYPE, RerankerType
+import asyncio
+
+from ..retriever import (
+ QueryDocumentRetriever,
+ QueryQuestionRetriever
+)
+from .. import parse_config
+from ..reranker import BGEReranker, MergeReranker, BGEM3Reranker
+from ..context_utils import contexts_trunc,retriever_results_format,retriever_results_filter
+from ..langchain_utils import RunnableDictAssign
+from ..preprocess_utils import is_api_query, language_check,query_translate,get_service_name
+from ..workspace_utils import WorkspaceManager
+
+logger = logging.getLogger('mkt_knowledge_entry')
+logger.setLevel(logging.INFO)
+
+zh_embedding_endpoint = os.environ.get("zh_embedding_endpoint", "")
+en_embedding_endpoint = os.environ.get("en_embedding_endpoint", "")
+workspace_table = os.environ.get("workspace_table", "")
+
+dynamodb = boto3.resource("dynamodb")
+workspace_table = dynamodb.Table(workspace_table)
+workspace_manager = WorkspaceManager(workspace_table)
+
+
+
+def mkt_fast_reply(
+ answer="很抱歉,我只能回答与亚马逊云科技产品和服务相关的咨询。",
+ fast_info=""
+ ):
+ output = {
+ "answer": answer,
+ "sources": [],
+ "contexts": [],
+ "context_docs": [],
+ "context_sources": []
+ }
+ logger.info(f'mkt_fast_reply: {fast_info}')
+ return output
+
+def market_chain_knowledge_entry(
+ query_input: str,
+ stream=False,
+ manual_input_intent=None,
+ event_body=None,
+ rag_config=None,
+ message_id=None
+):
+ """
+ Entry point for the Lambda function.
+
+ :param query_input: The query input.
+ :param aos_index: The index of the AOS engine.
+ :param stream(Bool): Whether to use llm stream decoding output.
+ return: answer(str)
+ """
+ if rag_config is None:
+ rag_config = parse_config.parse_mkt_entry_knowledge_config(event_body)
+
+ assert rag_config is not None
+
+ logger.info(f'market rag knowledge configs:\n {json.dumps(rag_config,indent=2,ensure_ascii=False,cls=JSONEncoder)}')
+
+ workspace_ids = rag_config["retriever_config"]["workspace_ids"]
+ qq_workspace_list = []
+ qd_workspace_list = []
+ for workspace_id in workspace_ids:
+ workspace = workspace_manager.get_workspace(workspace_id)
+ if not workspace or "index_type" not in workspace:
+ logger.warning(f"workspace {workspace_id} not found")
+ continue
+ if workspace["index_type"] == "qq":
+ qq_workspace_list.append(workspace)
+ else:
+ qd_workspace_list.append(workspace)
+
+ debug_info = {}
+ contexts = []
+ sources = []
+ answer = ""
+ intent_info = {
+ "manual_input_intent": manual_input_intent,
+ "strict_qq_intent_result": {},
+ }
+
+
+ ################################################################################
+ # step 1 conversation summary chain, rewrite query involve history conversation#
+ ################################################################################
+
+ conversation_query_rewrite_config = rag_config['query_process_config']['conversation_query_rewrite_config']
+ conversation_query_rewrite_result_key = conversation_query_rewrite_config['result_key']
+ cqr_llm_chain = LLMChain.get_chain(
+ intent_type=CONVERSATION_SUMMARY_TYPE,
+ **conversation_query_rewrite_config
+ )
+ cqr_llm_chain = RunnableBranch(
+ # single turn
+ (lambda x: not x['chat_history'],RunnableLambda(lambda x:x['query'])),
+ cqr_llm_chain
+ )
+
+ conversation_summary_chain = chain_logger(
+ RunnablePassthrough.assign(
+ **{conversation_query_rewrite_result_key:cqr_llm_chain}
+ # query=cqr_llm_chain
+ ),
+ "conversation_summary_chain",
+ log_output_template='conversation_summary_chain result: {conversation_query_rewrite}',
+ message_id=message_id
+ )
+
+ #######################
+ # step 2 query preprocess#
+ #######################
+ translate_config = rag_config['query_process_config']['translate_config']
+ translate_chain = RunnableLambda(
+ lambda x: query_translate(
+ x['query'],
+ lang=x['query_lang'],
+ translate_config=translate_config
+ )
+ )
+ lang_check_and_translate_chain = RunnablePassthrough.assign(
+ query_lang = RunnableLambda(lambda x:language_check(x['query']))
+ ) | RunnablePassthrough.assign(translated_text=translate_chain)
+
+ is_api_query_chain = RunnableLambda(lambda x:is_api_query(x['query']))
+ service_names_recognition_chain = RunnableLambda(lambda x:get_service_name(x['query']))
+
+ preprocess_chain = lang_check_and_translate_chain | RunnableParallelAssign(
+ is_api_query=is_api_query_chain,
+ service_names=service_names_recognition_chain
+ )
+
+ log_output_template=dedent("""
+ preprocess result:
+ query_lang: {query_lang}
+ translated_text: {translated_text}
+ is_api_query: {is_api_query}
+ service_names: {service_names}
+ """)
+ preprocess_chain = chain_logger(
+ preprocess_chain,
+ 'preprocess query chain',
+ message_id=message_id,
+ log_output_template=log_output_template
+ )
+
+ #####################################
+ # step 3.1 intent recognition chain #
+ #####################################
+ intent_recognition_index = IntentRecognitionAOSIndex(embedding_endpoint_name=zh_embedding_endpoint)
+ intent_index_ingestion_chain = chain_logger(
+ intent_recognition_index.as_ingestion_chain(),
+ "intent_index_ingestion_chain",
+ message_id=message_id
+ )
+ intent_index_check_exist_chain = RunnablePassthrough.assign(
+ is_intent_index_exist = intent_recognition_index.as_check_index_exist_chain()
+ )
+ intent_index_search_chain = chain_logger(
+ intent_recognition_index.as_search_chain(top_k=5),
+ "intent_index_search_chain",
+ message_id=message_id
+ )
+ intent_postprocess_chain = intent_recognition_index.as_intent_postprocess_chain(method='top_1')
+
+ intent_search_and_postprocess_chain = intent_index_search_chain | intent_postprocess_chain
+ intent_branch = RunnableBranch(
+ (lambda x: not x['is_intent_index_exist'], intent_index_ingestion_chain | intent_search_and_postprocess_chain),
+ intent_search_and_postprocess_chain
+ )
+ intent_recognition_chain = intent_index_check_exist_chain | intent_branch
+
+ ####################
+ # step 3.2 qq match#
+ ####################
+ qq_match_threshold = rag_config['retriever_config']['qq_config']['qq_match_threshold']
+ qq_retriver_top_k = rag_config['retriever_config']['qq_config']['retriever_top_k']
+ qq_query_key = rag_config['retriever_config']['qq_config']['query_key']
+ retriever_list = [
+ QueryQuestionRetriever(
+ workspace,
+ size=qq_retriver_top_k,
+ query_key=qq_query_key
+ )
+ for workspace in qq_workspace_list
+ ]
+ qq_chain = chain_logger(
+ MergerRetriever(retrievers=retriever_list) | \
+ RunnableLambda(retriever_results_format) |\
+ RunnableLambda(partial(
+ retriever_results_filter,
+ threshold=qq_match_threshold
+ ))
+ ,
+ 'qq_chain'
+ )
+
+ ############################
+ # step 4. qd retriever chain#
+ ############################
+ qd_config = rag_config['retriever_config']['qd_config']
+ using_whole_doc = qd_config['using_whole_doc']
+ context_num = qd_config['context_num']
+ retriever_top_k = qd_config['retriever_top_k']
+ # reranker_top_k = qd_config['reranker_top_k']
+ # enable_reranker = qd_config['enable_reranker']
+ reranker_type = rag_config['retriever_config']['qd_config']['reranker_type']
+ qd_query_key = rag_config['retriever_config']['qd_config']['query_key']
+ retriever_list = [
+ QueryDocumentRetriever(
+ workspace=workspace,
+ using_whole_doc=using_whole_doc,
+ context_num=context_num,
+ top_k=retriever_top_k,
+ query_key=qd_query_key
+ # "zh", zh_embedding_endpoint
+ )
+ for workspace in qd_workspace_list
+ ]
+
+ lotr = MergerRetriever(retrievers=retriever_list)
+ if reranker_type == RerankerType.BGE_RERANKER.value:
+ compressor = BGEReranker(query_key=qd_query_key)
+ elif reranker_type == RerankerType.BGE_M3_RERANKER.value:
+ compressor = BGEM3Reranker()
+ else:
+ compressor = MergeReranker()
+
+ compression_retriever = ContextualCompressionRetriever(
+ base_compressor=compressor, base_retriever=lotr
+ )
+ qd_chain = chain_logger(
+ RunnablePassthrough.assign(
+ docs=compression_retriever | RunnableLambda(retriever_results_format)
+ ),
+ "qd chain",
+ message_id=message_id
+ )
+
+ #####################
+ # step 5. llm chain #
+ #####################
+ generator_llm_config = rag_config['generator_llm_config']
+ context_num = generator_llm_config['context_num']
+ llm_chain = RunnableDictAssign(lambda x: contexts_trunc(x['docs'],context_num=context_num)) |\
+ RunnablePassthrough.assign(
+ answer=LLMChain.get_chain(
+ intent_type=IntentType.KNOWLEDGE_QA.value,
+ stream=stream,
+ **generator_llm_config
+ ),
+ chat_history=lambda x:rag_config['chat_history']
+ )
+
+ # llm_chain = chain_logger(llm_chain,'llm_chain', message_id=message_id)
+
+ ###########################
+ # step 6 synthesize chain #
+ ###########################
+
+ ######################
+ # step 6.1 rag chain #
+ ######################
+ qd_match_threshold = rag_config['retriever_config']['qd_config']['qd_match_threshold']
+ qd_fast_reply_branch = RunnablePassthrough.assign(
+ filtered_docs = RunnableLambda(lambda x: retriever_results_filter(x['docs'],threshold=qd_match_threshold))
+ ) | RunnableBranch(
+ (
+ lambda x: not x['filtered_docs'],
+ RunnableLambda(lambda x: mkt_fast_reply(fast_info="insufficient context"))
+ ),
+ llm_chain
+ )
+
+ rag_chain = chain_logger(
+ qd_chain | qd_fast_reply_branch,
+ 'rag chain'
+ )
+
+ ######################################
+ # step 6.2 fast reply based on intent#
+ ######################################
+ log_output_template=dedent("""
+ qq_result num: {qq_result_num}
+ intent recognition type: {intent_type}
+ """)
+ qq_and_intention_type_recognition_chain = chain_logger(
+ RunnableParallelAssign(
+ qq_result=qq_chain,
+ intent_type=intent_recognition_chain,
+ ) | RunnablePassthrough.assign(qq_result_num=lambda x:len(x['qq_result'])),
+ "intention module",
+ log_output_template=log_output_template,
+ message_id=message_id
+ )
+
+ allow_intents = [
+ IntentType.KNOWLEDGE_QA.value,
+ IntentType.MARKET_EVENT.value
+ ]
+ qq_and_intent_fast_reply_branch = RunnableBranch(
+ (lambda x: len(x['qq_result']) > 0,
+ RunnableLambda(
+ lambda x: mkt_fast_reply(
+ answer=sorted(x['qq_result'],key=lambda x:x['score'],reverse=True)[0]['answer'],
+ fast_info='qq matched'
+ ))
+ ),
+ (
+ lambda x: x['intent_type'] not in allow_intents,
+ RunnableLambda(lambda x: mkt_fast_reply(fast_info=f"unsupported intent type: {x['intent_type']}"))
+ ),
+ rag_chain
+ )
+
+ #######################
+ # step 6.3 full chain #
+ #######################
+
+ process_query_chain = conversation_summary_chain | preprocess_chain
+
+ process_query_chain = chain_logger(
+ process_query_chain,
+ "query process module",
+ message_id=message_id
+ )
+
+ qq_and_intent_fast_reply_branch = chain_logger(
+ qq_and_intent_fast_reply_branch,
+ "retrieve module",
+ message_id=message_id
+ )
+
+ full_chain = process_query_chain | qq_and_intention_type_recognition_chain | qq_and_intent_fast_reply_branch
+
+ full_chain = chain_logger(
+ full_chain,
+ 'full_chain'
+ )
+ start_time = time.time()
+
+ response = asyncio.run(full_chain.ainvoke(
+ {
+ "query": query_input,
+ "debug_info": debug_info,
+ # "intent_type": intent_type,
+ "intent_info": intent_info,
+ "chat_history": rag_config['chat_history'] if rag_config['use_history'] else [],
+ # "query_lang": "zh"
+ }
+ ))
+
+ print('invoke time',time.time()-start_time)
+
+ answer = response["answer"]
+ sources = response["context_sources"]
+ contexts = response["context_docs"]
+
+ return answer, sources, contexts, debug_info
\ No newline at end of file
diff --git a/source/lambda/executor/utils/executor_entries/mkt_knowledge_entry_langgraph.py b/source/lambda/executor/utils/executor_entries/mkt_knowledge_entry_langgraph.py
new file mode 100644
index 00000000..22e6a994
--- /dev/null
+++ b/source/lambda/executor/utils/executor_entries/mkt_knowledge_entry_langgraph.py
@@ -0,0 +1,435 @@
+import logging
+import json
+import os
+import boto3
+from functools import partial
+from textwrap import dedent
+from langchain.schema.runnable import (
+ RunnableBranch,
+ RunnableLambda,
+ RunnableParallel,
+ RunnablePassthrough,
+)
+from langchain.retrievers import ContextualCompressionRetriever
+from langchain.retrievers.merger_retriever import MergerRetriever
+
+from langgraph.graph import StateGraph, END
+from typing import TypedDict, Any
+
+from ..intent_utils import IntentRecognitionAOSIndex
+from ..llm_utils import LLMChain
+from ..serialization_utils import JSONEncoder
+from ..langchain_utils import chain_logger,RunnableDictAssign,RunnableParallelAssign
+from ..constant import IntentType, CONVERSATION_SUMMARY_TYPE
+import asyncio
+
+from ..retriever import (
+ QueryDocumentRetriever,
+ QueryQuestionRetriever,
+)
+from .. import parse_config
+from ..reranker import BGEReranker, MergeReranker
+from ..context_utils import contexts_trunc,retriever_results_format,retriever_results_filter
+from ..langchain_utils import RunnableDictAssign
+from ..preprocess_utils import is_api_query, language_check,query_translate,get_service_name
+from ..workspace_utils import WorkspaceManager
+
+logger = logging.getLogger('mkt_knowledge_entry')
+logger.setLevel(logging.INFO)
+
+zh_embedding_endpoint = os.environ.get("zh_embedding_endpoint", "")
+en_embedding_endpoint = os.environ.get("en_embedding_endpoint", "")
+workspace_table = os.environ.get("workspace_table", "")
+
+dynamodb = boto3.resource("dynamodb")
+workspace_table = dynamodb.Table(workspace_table)
+workspace_manager = WorkspaceManager(workspace_table)
+
+
+class AppState(TypedDict):
+ state: dict
+
+def mkt_fast_reply(state: AppState):
+ state_ori = state
+ state = state['state']
+ fast_info = state.get('fast_info',"")
+ answer = state.get(
+ "answer",
+ "很抱歉,我只能回答与亚马逊云科技产品和服务相关的咨询。"
+ )
+ output = {
+ "answer": answer,
+ "sources": [],
+ "contexts": [],
+ "context_docs": [],
+ "context_sources": []
+ }
+ logger.info(f'mkt_fast_reply: {fast_info}')
+ state_ori['state'] = output
+ return state_ori
+
+
+def conversation_query_rewrite(state: AppState):
+ state_ori = state
+ state = state['state']
+
+ rag_config = state['rag_config']
+ conversation_query_rewrite_config = rag_config['query_process_config']['conversation_query_rewrite_config']
+
+ cqr_llm_chain = LLMChain.get_chain(
+ intent_type=CONVERSATION_SUMMARY_TYPE,
+ **conversation_query_rewrite_config
+ )
+ conversation_summary_chain = chain_logger(
+ RunnableBranch(
+ (
+ lambda x: not x['chat_history'],
+ RunnableLambda(lambda x: x['query'])
+ ),
+ cqr_llm_chain
+ ),
+ "conversation_summary_chain",
+ log_output_template='conversation_summary_chain result: {output}',
+ message_id=state['message_id']
+ )
+
+ state['query'] = conversation_summary_chain.invoke(state)
+
+ return state_ori
+
+
+def query_preprocess(state: AppState):
+ state_ret = state
+ state = state['state']
+ rag_config = state['rag_config']
+ translate_config = rag_config['query_process_config']['translate_config']
+ translate_chain = RunnableLambda(
+ lambda x: query_translate(
+ x['query'],
+ lang=x['query_lang'],
+ translate_config=translate_config
+ )
+ )
+ lang_check_and_translate_chain = RunnablePassthrough.assign(
+ query_lang = RunnableLambda(lambda x:language_check(x['query']))
+ ) | RunnablePassthrough.assign(translated_text=translate_chain)
+
+ is_api_query_chain = RunnableLambda(lambda x:is_api_query(x['query']))
+ service_names_recognition_chain = RunnableLambda(lambda x:get_service_name(x['query']))
+
+ preprocess_chain = lang_check_and_translate_chain | RunnableParallelAssign(
+ is_api_query=is_api_query_chain,
+ service_names=service_names_recognition_chain
+ )
+
+ log_output_template=dedent("""
+ preprocess result:
+ query_lang: {query_lang}
+ translated_text: {translated_text}
+ is_api_query: {is_api_query}
+ service_names: {service_names}
+ """)
+ preprocess_chain = chain_logger(
+ preprocess_chain,
+ 'preprocess chain',
+ message_id=state['message_id'],
+ log_output_template=log_output_template
+ )
+ state = preprocess_chain.invoke(state)
+ state_ret['state'] = state
+ return state_ret
+
+def get_intent_recognition_with_index_chain(state):
+
+ intent_recognition_index = IntentRecognitionAOSIndex(
+ embedding_endpoint_name=state['intent_embedding_endpoint_name'])
+ intent_index_ingestion_chain = chain_logger(
+ intent_recognition_index.as_ingestion_chain(),
+ "intent_index_ingestion_chain",
+ message_id=state['message_id']
+ )
+ intent_index_check_exist_chain = RunnablePassthrough.assign(
+ is_intent_index_exist = intent_recognition_index.as_check_index_exist_chain()
+ )
+ intent_index_search_chain = chain_logger(
+ intent_recognition_index.as_search_chain(top_k=5),
+ "intent_index_search_chain",
+ message_id=state['message_id']
+ )
+ inten_postprocess_chain = intent_recognition_index.as_intent_postprocess_chain(method='top_1')
+
+ intent_search_and_postprocess_chain = intent_index_search_chain | inten_postprocess_chain
+ intent_branch = RunnableBranch(
+ (lambda x: not x['is_intent_index_exist'], intent_index_ingestion_chain | intent_search_and_postprocess_chain),
+ intent_search_and_postprocess_chain
+ )
+ intent_recognition_index_chain = intent_index_check_exist_chain | intent_branch
+ return intent_recognition_index_chain
+
+def get_qq_match_chain(state):
+ # qq_match
+ qq_workspace_list = state['qq_workspace_list']
+ rag_config = state['rag_config']
+
+ qq_match_threshold = rag_config['retriever_config']['qq_config']['qq_match_threshold']
+ qq_retriever_top_k = rag_config['retriever_config']['qq_config']['retriever_top_k']
+ retriever_list = [
+ QueryQuestionRetriever(
+ workspace,
+ # index=index["name"],
+ # vector_field=index["vector_field"],
+ # source_field=index["source_field"],
+ size=qq_retriever_top_k,
+ # lang=index["lang"],
+ # embedding_model_endpoint=index["embedding_endpoint"]
+ )
+ for workspace in qq_workspace_list
+ ]
+ qq_chain = MergerRetriever(retrievers=retriever_list) | \
+ RunnableLambda(retriever_results_format) |\
+ RunnableLambda(partial(
+ retriever_results_filter,
+ threshold=qq_match_threshold
+ ))
+ return qq_chain
+
+
+def qq_match_and_intent_recognition(state):
+ state_ret = state
+ state = state['state']
+ qq_chain = get_qq_match_chain(state)
+ intent_recognition_chain= get_intent_recognition_with_index_chain(state)
+
+ log_output_template=dedent("""
+ qq_result num: {qq_result_num}
+ intent recognition type: {intent_type}
+ """)
+ qq_and_intention_type_recognition_chain = chain_logger(
+ RunnableParallelAssign(
+ qq_result=qq_chain,
+ intent_type=intent_recognition_chain,
+ ) | RunnablePassthrough.assign(qq_result_num=lambda x:len(x['qq_result'])),
+ "qq_and_intention_type_recognition_chain",
+ log_output_template=log_output_template,
+ message_id=state['message_id']
+ )
+
+ state = qq_and_intention_type_recognition_chain.invoke(state)
+ state_ret['state'] = state
+ return state_ret
+
+def qd_retriver(state):
+ state_ret = state
+ state = state['state']
+ rag_config = state['rag_config']
+ qd_config = rag_config['retriever_config']['qd_config']
+ using_whole_doc = qd_config['using_whole_doc']
+ context_num = qd_config['context_num']
+ retriever_top_k = qd_config['retriever_top_k']
+ reranker_top_k = qd_config['reranker_top_k']
+ enable_reranker = qd_config['enable_reranker']
+
+ qd_workspace_list = state['qd_workspace_list']
+
+ retriever_list = [
+ QueryDocumentRetriever(
+ workspace=workspace,
+ using_whole_doc=using_whole_doc,
+ context_num=context_num,
+ top_k=retriever_top_k,
+ # "zh", zh_embedding_endpoint
+ )
+ for workspace in qd_workspace_list
+ ]
+
+ lotr = MergerRetriever(retrievers=retriever_list)
+ if enable_reranker:
+ compressor = BGEReranker(top_n=reranker_top_k)
+ else:
+ compressor = MergeReranker(top_n=reranker_top_k)
+ compression_retriever = ContextualCompressionRetriever(
+ base_compressor=compressor, base_retriever=lotr
+ )
+ qd_chain = RunnablePassthrough.assign(
+ docs=compression_retriever | RunnableLambda(retriever_results_format)
+ )
+ state = qd_chain.invoke(state)
+ state_ret['state'] = state
+ return state_ret
+
+def context_filter(state):
+ state_ret = state
+ state = state['state']
+ rag_config = state['rag_config']
+ qd_match_threshold = rag_config['retriever_config']['qd_config']['qd_match_threshold']
+ filtered_docs = retriever_results_filter(state['docs'],threshold=qd_match_threshold)
+ state['filtered_docs'] = filtered_docs
+ return state_ret
+
+def llm(state):
+ state_ret = state
+ state = state['state']
+ message_id = state['message_id']
+ stream = state['stream']
+ rag_config = state['rag_config']
+ generator_llm_config = rag_config['generator_llm_config']
+ context_num = generator_llm_config['context_num']
+ llm_chain = RunnableDictAssign(lambda x: contexts_trunc(x['docs'],context_num=context_num)) |\
+ RunnablePassthrough.assign(
+ answer=LLMChain.get_chain(
+ intent_type=IntentType.KNOWLEDGE_QA.value,
+ stream=stream,
+ **generator_llm_config
+ ),
+ chat_history=lambda x:rag_config['chat_history']
+ )
+
+ llm_chain = chain_logger(llm_chain,'llm_chain', message_id=message_id)
+ state = llm_chain.invoke(state)
+ state_ret['state'] = state
+ return state_ret
+
+
+def decide_intent(state):
+ state = state['state']
+ allow_intents = [
+ IntentType.KNOWLEDGE_QA.value,
+ IntentType.MARKET_EVENT.value
+ ]
+
+ if len(state['qq_result']) > 0:
+ state['answer'] = sorted(state['qq_result'],key=lambda x:x['score'],reverse=True)[0]['answer']
+ state['fast_info'] = 'qq_matched'
+ return 'mkt_fast_reply'
+
+ if state['intent_type'] not in allow_intents:
+ state['fast_info'] = f"unsupported intent type: {state['intent_type']}"
+ return 'mkt_fast_reply'
+
+ return 'qd_retriver'
+
+
+def decide_if_context_sufficient(state):
+ state = state['state']
+ if not state['filtered_docs']:
+ state['fast_info'] = ' insufficient context to answer the question'
+ return 'mkt_fast_reply'
+ return 'llm'
+
+def market_chain_knowledge_entry(
+ query_input: str,
+ stream=False,
+ manual_input_intent=None,
+ event_body=None,
+ rag_config=None,
+ message_id=None
+):
+ """
+ Entry point for the Lambda function.
+
+ :param query_input: The query input.
+ :param aos_index: The index of the AOS engine.
+ :param stream(Bool): Whether to use llm stream decoding output.
+ return: answer(str)
+ """
+ if rag_config is None:
+ rag_config = parse_config.parse_mkt_entry_knowledge_config(event_body)
+
+ assert rag_config is not None
+
+ logger.info(f'market rag knowledge configs:\n {json.dumps(rag_config,indent=2,ensure_ascii=False,cls=JSONEncoder)}')
+
+ workspace_ids = rag_config["retriever_config"]["workspace_ids"]
+ qq_workspace_list = []
+ qd_workspace_list = []
+ for workspace_id in workspace_ids:
+ workspace = workspace_manager.get_workspace(workspace_id)
+ if not workspace or "index_type" not in workspace:
+ logger.warning(f"workspace {workspace_id} not found")
+ continue
+ if workspace["index_type"] == "qq":
+ qq_workspace_list.append(workspace)
+ else:
+ qd_workspace_list.append(workspace)
+
+
+ debug_info = {}
+ contexts = []
+ sources = []
+ answer = ""
+
+ workflow = StateGraph(AppState)
+ workflow.add_node('mkt_fast_reply',mkt_fast_reply)
+ workflow.add_node('conversation_query_rewrite',conversation_query_rewrite)
+ workflow.add_node('query_preprocess',query_preprocess)
+ workflow.add_node('qq_match_and_intent_recognition',qq_match_and_intent_recognition)
+ workflow.add_node('qd_retriver',qd_retriver)
+ workflow.add_node('context_filter',context_filter)
+ workflow.add_node('llm',llm)
+
+ # start node
+ workflow.set_entry_point("conversation_query_rewrite")
+ # termial node
+ workflow.add_edge('mkt_fast_reply', END)
+ workflow.add_edge('llm', END)
+
+ # norm edge
+ workflow.add_edge('conversation_query_rewrite','query_preprocess')
+ workflow.add_edge(
+ 'query_preprocess',
+ 'qq_match_and_intent_recognition'
+ )
+
+ workflow.add_edge('qd_retriver','context_filter')
+
+ # conditional edges
+ workflow.add_conditional_edges(
+ 'qq_match_and_intent_recognition',
+ decide_intent,
+ {
+ "mkt_fast_reply": "mkt_fast_reply",
+ "qd_retriver": "qd_retriver"
+ })
+
+ workflow.add_conditional_edges(
+ "context_filter",
+ decide_if_context_sufficient,
+ {
+ "mkt_fast_reply":'mkt_fast_reply',
+ "llm":"llm"
+ }
+ )
+
+ app = workflow.compile()
+
+ inputs = {
+ "query": query_input,
+ "debug_info": debug_info,
+ # "intent_type": intent_type,
+ # "intent_info": intent_info,
+ "chat_history": rag_config['chat_history'],
+ "rag_config": rag_config,
+ "message_id": message_id,
+ "stream": stream,
+ "qq_workspace_list": qq_workspace_list,
+ "qd_workspace_list": qd_workspace_list,
+ "intent_embedding_endpoint_name": zh_embedding_endpoint
+ # "query_lang": "zh"
+ }
+ response = app.invoke({'state':inputs})['state']
+ # response = asyncio.run(full_chain.ainvoke(
+ # {
+ # "query": query_input,
+ # "debug_info": debug_info,
+ # # "intent_type": intent_type,
+ # # "intent_info": intent_info,
+ # "chat_history": rag_config['chat_history'],
+ # # "query_lang": "zh"
+ # }
+ # ))
+
+ answer = response["answer"]
+ sources = response["context_sources"]
+ contexts = response["context_docs"]
+
+ return answer, sources, contexts, debug_info
\ No newline at end of file
diff --git a/source/lambda/executor/utils/executor_entries/retriever_entries.py b/source/lambda/executor/utils/executor_entries/retriever_entries.py
new file mode 100644
index 00000000..ef68967a
--- /dev/null
+++ b/source/lambda/executor/utils/executor_entries/retriever_entries.py
@@ -0,0 +1,153 @@
+import asyncio
+import json
+
+from .mkt_entry_core import (
+ QueryQuestionRetriever,
+ get_query_process_chain,
+ auto_intention_recoginition_chain,
+ get_qd_chain
+)
+from langchain.schema.runnable import (
+ RunnableBranch,
+ RunnableLambda,
+ RunnableParallel,
+ RunnablePassthrough,
+)
+from ..time_utils import timeit
+from ..langchain_utils import chain_logger
+from .. import parse_config
+
+def get_strict_qq_chain(strict_q_q_index):
+ def get_strict_qq_result(docs, threshold=0.7):
+ results = []
+ for doc in docs:
+ if doc.metadata["score"] < threshold:
+ break
+ results.append(
+ {
+ "score": doc.metadata["score"],
+ "source": doc.metadata["source"],
+ "answer": doc.metadata["answer"],
+ "question": doc.metadata["question"],
+ }
+ )
+ return results
+
+
+ mkt_q_q_retriever = QueryQuestionRetriever(
+ index=strict_q_q_index,
+ vector_field="vector_field",
+ source_field="file_path",
+ size=5,
+ )
+ strict_q_q_chain = mkt_q_q_retriever | RunnableLambda(get_strict_qq_result)
+ return strict_q_q_chain
+
+
+def main_qq_retriever_entry(
+ query_input: str,
+ index: str,
+):
+ """
+ Entry point for the Lambda function.
+
+ :param query_input: The query input.
+ :param aos_index: The index of the AOS engine.
+
+ return: answer(str)
+ """
+ debug_info = {
+ "query": query_input,
+ "query_parser_info": {},
+ "q_q_match_info": {},
+ "knowledge_qa_knn_recall": {},
+ "knowledge_qa_boolean_recall": {},
+ "knowledge_qa_combined_recall": {},
+ "knowledge_qa_cross_model_sort": {},
+ "knowledge_qa_llm": {},
+ "knowledge_qa_rerank": {},
+ }
+ full_chain = get_strict_qq_chain(index)
+ response = full_chain.invoke({"query": query_input, "debug_info": debug_info})
+ return response
+
+
+@timeit
+def main_qd_retriever_entry(
+ query_input: str,
+ aos_index: str,
+ event_body=None,
+ manual_input_intent=None,
+ message_id=None
+):
+ """
+ Entry point for the Lambda function.
+
+ :param query_input: The query input.
+ :param aos_index: The index of the AOS engine.
+
+ return: answer(str)
+ """
+
+ rag_config=parse_config.parse_rag_config(event_body)
+
+ debug_info = {
+ "query": query_input,
+ "query_parser_info": {},
+ "q_q_match_info": {},
+ "knowledge_qa_knn_recall": {},
+ "knowledge_qa_boolean_recall": {},
+ "knowledge_qa_combined_recall": {},
+ "knowledge_qa_cross_model_sort": {},
+ "knowledge_qa_llm": {},
+ "knowledge_qa_rerank": {},
+ }
+ retriever_top_k = rag_config['retriever_config']['retriever_top_k']
+ using_whole_doc = rag_config['retriever_config']['using_whole_doc']
+ chunk_num = rag_config['retriever_config']['chunk_num']
+ query_process_chain = get_query_process_chain(
+ rag_config['chat_history'],
+ rag_config['query_process_config']['query_rewrite_config'],
+ rag_config['query_process_config']['conversation_query_rewrite_config'],
+ rag_config['query_process_config']['hyde_config']
+ )
+ intent_type = rag_config['intent_config']['intent_type']
+ intent_info = {
+ "manual_input_intent": manual_input_intent,
+ "strict_qq_intent_result": {},
+ }
+ intent_recognition_chain = auto_intention_recoginition_chain("aos_index_mkt_qq", message_id=message_id)
+ intent_recognition_chain = chain_logger(
+ intent_recognition_chain,
+ 'intention module',
+ log_output_template='intent chain output: {intent_type}',
+ message_id=message_id
+ )
+ qd_chain = get_qd_chain(
+ [aos_index], using_whole_doc=using_whole_doc, chunk_num=chunk_num, retriever_top_k=retriever_top_k, reranker_top_k=10
+ )
+ full_chain = query_process_chain | intent_recognition_chain | qd_chain
+ response = asyncio.run(full_chain.ainvoke({
+ "query": query_input,
+ "debug_info": debug_info,
+ "intent_type": intent_type,
+ "intent_info": intent_info,
+ }))
+ doc_list = []
+ for doc in response["docs"]:
+ doc_list.append({"page_content": doc.page_content, "metadata": doc.metadata})
+ return doc_list, debug_info
+
+
+
+def get_retriever_response(docs, debug_info):
+ response = {"statusCode": 200, "headers": {"Content-Type": "application/json"}}
+ resp_header = {
+ "Content-Type": "application/json",
+ "Access-Control-Allow-Headers": "Content-Type,X-Amz-Date,Authorization,X-Api-Key,X-Amz-Security-Token",
+ "Access-Control-Allow-Origin": "*",
+ "Access-Control-Allow-Methods": "*",
+ }
+ response["body"] = json.dumps({"docs": docs, "debug_info": debug_info})
+ response["headers"] = resp_header
+ return response
\ No newline at end of file
diff --git a/source/lambda/executor/utils/intent_utils.py b/source/lambda/executor/utils/intent_utils.py
index 7a69f6bb..1f3781be 100644
--- a/source/lambda/executor/utils/intent_utils.py
+++ b/source/lambda/executor/utils/intent_utils.py
@@ -235,7 +235,7 @@ def auto_intention_recoginition_chain(
# embedding_endpoint="",
# q_q_match_threshold=0.9,
intent_if_fail=IntentType.KNOWLEDGE_QA.value,
-
+ message_id=None
):
"""
@@ -253,6 +253,7 @@ def get_strict_intent(x):
x["intent_info"]["strict_qq_intent_result"] = x["q_q_match_res"]["answer"]
return x["intent_type"]
+
q_q_retriever = QueryQuestionRetriever(
index=q_q_retriever_config['index_q_q'],
vector_field="vector_field",
@@ -294,7 +295,8 @@ def get_strict_intent(x):
sub_intent_chain = chain_logger(
sub_intent_chain,
"sub intent chain",
- log_output_template='\nis_api_query: {is_api_query}.\nservice_names: {service_names}'
+ log_output_template='\nis_api_query: {is_api_query}.\nservice_names: {service_names}',
+ message_id=message_id
)
chain = intent_type_chain | RunnableBranch(
@@ -303,6 +305,19 @@ def get_strict_intent(x):
)
return chain
+
+
+
+
+# intent_recognition_with_opensearch
+def create_opensearch_index(opensearch_client):
+ pass
+
+
+def intent_recognition_with_openserach_chain(opensearch_client,top_k=5):
+ pass
+
+
diff --git a/source/lambda/executor/utils/intent_utils/__init__.py b/source/lambda/executor/utils/intent_utils/__init__.py
new file mode 100644
index 00000000..fe1c4c5f
--- /dev/null
+++ b/source/lambda/executor/utils/intent_utils/__init__.py
@@ -0,0 +1,2 @@
+from .intent_utils import auto_intention_recoginition_chain
+from .intent_aos_utils import IntentRecognitionAOSIndex
\ No newline at end of file
diff --git a/source/lambda/executor/utils/intent_utils/intent_aos_utils.py b/source/lambda/executor/utils/intent_utils/intent_aos_utils.py
new file mode 100644
index 00000000..5cb818c9
--- /dev/null
+++ b/source/lambda/executor/utils/intent_utils/intent_aos_utils.py
@@ -0,0 +1,187 @@
+from .. import retriever
+# from ..retriever import QueryDocumentRetriever, QueryQuestionRetriever,index_results_format
+from ..constant import IntentType,INTENT_RECOGNITION_TYPE
+# from functools import partial
+from langchain.schema.runnable import RunnablePassthrough, RunnableBranch, RunnableLambda
+# from ..llm_utils import Model as LLM_Model
+# from ..llm_utils.llm_chains import LLMChain
+# from langchain.prompts import PromptTemplate
+# import re
+
+from functools import lru_cache,partial
+import hashlib
+import traceback
+import threading
+import boto3
+import logging
+from ..prompt_template import INTENT_RECOGINITION_PROMPT_TEMPLATE_CLUADE,INTENT_RECOGINITION_EXAMPLE_TEMPLATE
+import os
+import json
+from typing import List,Dict
+from random import Random
+# from ..preprocess_utils import is_api_query,get_service_name
+from ..langchain_utils import chain_logger,RunnableNoneAssign
+from ..embeddings_utils import BGEEmbeddingSagemakerEndpoint
+from langchain_community.vectorstores.opensearch_vector_search import (
+ OpenSearchVectorSearch
+ )
+from langchain_community.embeddings.sagemaker_endpoint import (
+ SagemakerEndpointEmbeddings
+)
+
+from opensearchpy import RequestsHttpConnection
+from requests_aws4auth import AWS4Auth
+from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
+from langchain.docstore.document import Document
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+opensearch_client_lock = threading.Lock()
+abs_file_dir = os.path.dirname(__file__)
+intent_example_path = os.path.join(
+ abs_file_dir,
+ "intent_examples",
+ "examples.json"
+)
+
+class LangchainOpenSearchClient:
+ instance = None
+ def __new__(cls,index_name,embedding_endpoint_name,host=os.environ.get('aos_endpoint',None)):
+ identity = f'{index_name}_{host}_{embedding_endpoint_name}'
+ with opensearch_client_lock:
+ if cls.instance is not None and cls.instance._identity == identity:
+ return cls.instance
+ obj = cls.create(index_name,embedding_endpoint_name,host=host)
+ obj._identity = identity
+ cls.instance = obj
+ return obj
+
+ @classmethod
+ def create(cls,
+ index_name,
+ embedding_endpoint_name,
+ host=os.environ.get('aos_endpoint',None),
+ region_name=os.environ['AWS_REGION'],
+ ):
+ embedding = BGEEmbeddingSagemakerEndpoint(
+ endpoint_name=embedding_endpoint_name,
+ region_name=region_name
+ )
+ port = int(os.environ.get('AOS_PORT',443))
+ opensearch_url = f'https://{host}:{port}'
+ credentials = boto3.Session().get_credentials()
+ awsauth = AWS4Auth(credentials.access_key, credentials.secret_key, region_name, 'es', session_token=credentials.token)
+ opensearch_client = OpenSearchVectorSearch(
+ index_name=index_name,
+ embedding_function=embedding,
+ opensearch_url=opensearch_url
+ )
+ return opensearch_client
+
+
+class IntentRecognitionAOSIndex:
+ def __init__(
+ self,
+ intent_example_path=intent_example_path,
+ index_name=None,
+ embedding_endpoint_name=None,
+ host=os.environ.get('aos_endpoint',None)
+ ):
+ if index_name is None:
+ index_name = self.create_index_name(
+ embedding_endpoint_name=embedding_endpoint_name,
+ intent_example_path=intent_example_path
+ )
+ self.index_name = index_name
+ self.host = host
+ self.embedding_endpoint_name = embedding_endpoint_name
+ self.opensearch_client = LangchainOpenSearchClient(
+ index_name=index_name,
+ embedding_endpoint_name=embedding_endpoint_name,
+ host=host
+ )
+
+ @staticmethod
+ @lru_cache()
+ def create_index_name(
+ embedding_endpoint_name,
+ intent_example_path=intent_example_path
+ ):
+ index_name = f"intent_recognition_{embedding_endpoint_name}_{hashlib.md5(open(intent_example_path,'rb').read()).hexdigest()}"
+ return index_name
+
+ def check_index_exist(self):
+ if_exist = self.opensearch_client.client.indices.exists(self.index_name)
+ logger.info(f'is {self.index_name} exist: {if_exist}')
+ return if_exist
+
+ def ingestion_intent_data(self):
+ docs = []
+ intent_examples = json.load(open(intent_example_path))['examples']
+ for intent_name,examples in intent_examples.items():
+ for example in examples:
+ doc = Document(
+ page_content=example,
+ metadata = {"intent":intent_name}
+ )
+ docs.append(doc)
+ logger.info(f'ingestion intent doc, num: {len(docs)}, index_name: {self.index_name}')
+ self.opensearch_client.add_documents(
+ docs
+ )
+ logger.info(f'ingestion intent doc, num: {len(docs)}')
+
+ def search(self,query,top_k=5):
+
+ r_docs = self.opensearch_client.similarity_search_with_score(
+ query = query,
+ k=top_k
+ )
+ # r_docs = opensearch_client.similarity_search(datum['question'],k=1)
+ ret = [{
+ "candidate_query":r_doc[0].page_content,
+ "intent":r_doc[0].metadata["intent"],
+ "score":r_doc[1],
+ "origin_query":query
+ }
+ for r_doc in r_docs
+ ]
+ return ret
+
+ def intent_postprocess_top_1(self,retriever_list:list[dict]):
+ retriever_list = sorted(retriever_list, key = lambda x: x['score'])
+ intent = retriever_list[-1]['intent']
+ assert IntentType.has_value(intent), intent
+ return intent
+
+ def as_check_index_exist_chain(self):
+ return RunnableLambda(lambda x: self.check_index_exist())
+
+ def as_search_chain(self,top_k=5):
+ return RunnableLambda(lambda x: self.search(x['query'],top_k=top_k))
+
+ def as_ingestion_chain(self):
+ chain = RunnableNoneAssign(lambda x: self.ingestion_intent_data())
+ return chain
+
+ def as_intent_postprocess_chain(self,method='top_1'):
+ if method == 'top_1':
+ chain = RunnableLambda(self.intent_postprocess_top_1)
+ return chain
+ else:
+ raise TypeError(f'invalid method {method}')
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/source/lambda/executor/utils/intent_examples/chit_chat.example b/source/lambda/executor/utils/intent_utils/intent_examples/chit_chat.example
similarity index 100%
rename from source/lambda/executor/utils/intent_examples/chit_chat.example
rename to source/lambda/executor/utils/intent_utils/intent_examples/chit_chat.example
diff --git a/source/lambda/executor/utils/intent_examples/examples.json b/source/lambda/executor/utils/intent_utils/intent_examples/examples.json
similarity index 87%
rename from source/lambda/executor/utils/intent_examples/examples.json
rename to source/lambda/executor/utils/intent_utils/intent_examples/examples.json
index cbf293f5..35a5a3a7 100644
--- a/source/lambda/executor/utils/intent_examples/examples.json
+++ b/source/lambda/executor/utils/intent_utils/intent_examples/examples.json
@@ -9,7 +9,13 @@
"intent": "knowledge_qa",
"describe": "AWS Question and Answer (Q&A)",
"index": "B"
+ },
+ {
+ "intent": "market_event",
+ "describe": "AWS market event",
+ "index": "C"
}
+
],
"examples": {
"chat": [
@@ -70,6 +76,12 @@
"what is baywatch?",
"怎么看现有的Capacity?",
"申请gpu foob的流程是什么?"
+ ],
+ "market_event":[
+ "2024北京国际车展上,亚马逊云科技会参加吗?",
+ "3月份在深圳有生成式AI的活动吗?",
+ "2024年会举办出海全球化论坛吗?",
+ "2024年出海全球化论坛的会议日程是什么?"
]
}
}
\ No newline at end of file
diff --git a/source/lambda/executor/utils/intent_examples/knowledge_qa.example b/source/lambda/executor/utils/intent_utils/intent_examples/knowledge_qa.example
similarity index 100%
rename from source/lambda/executor/utils/intent_examples/knowledge_qa.example
rename to source/lambda/executor/utils/intent_utils/intent_examples/knowledge_qa.example
diff --git a/source/lambda/executor/utils/intent_utils/intent_utils.py b/source/lambda/executor/utils/intent_utils/intent_utils.py
new file mode 100644
index 00000000..510309f9
--- /dev/null
+++ b/source/lambda/executor/utils/intent_utils/intent_utils.py
@@ -0,0 +1,330 @@
+from .. import retriever
+from ..retriever import QueryDocumentRetriever, QueryQuestionRetriever,index_results_format
+from ..constant import IntentType,INTENT_RECOGNITION_TYPE
+from functools import partial
+from langchain.schema.runnable import RunnablePassthrough, RunnableBranch, RunnableLambda
+# from ..llm_utils import Model as LLM_Model
+from ..llm_utils.llm_chains import LLMChain
+# from langchain.prompts import PromptTemplate
+import re
+import traceback
+# from ..prompt_template import INTENT_RECOGINITION_PROMPT_TEMPLATE_CLUADE,INTENT_RECOGINITION_EXAMPLE_TEMPLATE
+import os
+import json
+from random import Random
+from ..preprocess_utils import is_api_query,get_service_name
+from ..langchain_utils import chain_logger
+
+abs_file_dir = os.path.dirname(__file__)
+
+# intent_map = {
+# "闲聊": IntentType.CHAT.value,
+# "知识问答": IntentType.KNOWLEDGE_QA.value
+# }
+
+# class IntentUtils:
+# def __init__(self,
+# intent_save_path=os.path.join(abs_file_dir,"intent_examples/examples.json"),
+# example_template=INTENT_RECOGINITION_EXAMPLE_TEMPLATE,
+# llm_model_id = 'anthropic.claude-v2:1',
+# llm_model_kwargs={"temperature":0,
+# "max_tokens_to_sample": 2000,
+# "stop_sequences": ["\n\n","\n\nHuman:"]
+# },
+# seed = 42
+# ):
+# self.intent_few_shot_examples = json.load(open(intent_save_path))
+# self.intent_indexs = {intent_d['intent']:intent_d['index'] for intent_d in self.intent_few_shot_examples['intents']}
+# self.index_intents = {v:k for k,v in self.intent_indexs.items()}
+# self.intents = list(self.intent_few_shot_examples['examples'].keys())
+# self.few_shot_examples = self.create_few_shot_examples()
+# Random(seed).shuffle(self.few_shot_examples)
+# self.examples_str = self.create_few_shot_example_string(example_template=example_template)
+# self.categories_str = self.create_all_labels_string()
+# self.intent_recognition_template = PromptTemplate.from_template(INTENT_RECOGINITION_PROMPT_TEMPLATE_CLUADE)
+# self.llm = LLM_Model.get_model(llm_model_id,model_kwargs=llm_model_kwargs)
+# self.intent_recognition_llm_chain = self.intent_recognition_template | self.llm
+# def create_few_shot_examples(self):
+# ret = []
+# for intent in self.intents:
+# examples = self.intent_few_shot_examples['examples'][intent]
+# for query in examples:
+# ret.append({
+# "intent":intent,
+# "query": query
+# })
+# return ret
+
+# def create_few_shot_example_string(self,example_template=INTENT_RECOGINITION_EXAMPLE_TEMPLATE):
+# example_strs = []
+# intent_indexs = self.intent_indexs
+# for example in self.few_shot_examples:
+# example_strs.append(
+# example_template.format(
+# label=intent_indexs[example['intent']],
+# query=example['query']
+# )
+# )
+# return '\n\n'.join(example_strs)
+
+# def create_all_labels_string(self):
+# intent_few_shot_examples = self.intent_few_shot_examples
+# label_strs = []
+# labels = intent_few_shot_examples['intents']
+# for i,label in enumerate(labels):
+# label_strs.append(f"({label['index']}) {label['describe']}")
+# return "\n".join(label_strs)
+
+# def postprocess(self,output:str):
+# out = output.strip()
+# assert out, output
+# return self.index_intents[out[0]]
+
+# intention_obj = IntentUtils()
+
+# def create_few_shot_example_string(examples):
+# example_strs = []
+# for example in examples:
+# example_strs.append(
+# INTENT_RECOGINITION_EXAMPLE_TEMPLATE.format(
+# label=example['label'],
+# query=example['query']
+# )
+# )
+# return '\n'.join(example_strs)
+
+# def create_all_labels_string(labels):
+# label_strs = []
+# for i,label in enumerate(labels):
+# label_strs.append(f"- {label}")
+# return "\n".join(label_strs)
+
+
+# def postprocess(output:str):
+# output = f'{output}'
+# r = re.findall('(.*?)',output,re.S)
+# assert r, output
+# r = [rr.strip() for rr in r]
+# r = [rr for rr in r if rr]
+# assert r, output
+# return r[0]
+
+# def get_intent_with_claude(query,intent_if_fail,debug_info):
+# predict_label = None
+# try:
+# r = intention_obj.intent_recognition_llm_chain.invoke({
+# "categories":intention_obj.categories_str,
+# "examples":intention_obj.examples_str,
+# 'query':query})
+# predict_label = intention_obj.postprocess(r)
+# except:
+# print(traceback.format_exc())
+# predict_label
+
+
+# intent = predict_label or intent_if_fail
+# debug_info['intent_debug_info'] = {
+# 'llm_output':r,
+# 'origin_intent':predict_label,
+# 'intent': intent
+# }
+# return intent
+
+# def get_intent(query,intent_type,qq_index=None):
+# assert IntentType.has_value(intent_type),intent_type
+# if intent_type != IntentType.AUTO:
+# return intent_type
+
+ # return get_intent_with_claude(query)
+
+# def auto_intention_recoginition_chain(
+# index_q_q,
+# lang="zh",
+# embedding_endpoint="",
+# q_q_match_threshold=0.9,
+# intent_if_fail=IntentType.KNOWLEDGE_QA.value
+# ):
+# """
+
+# Args:
+# index_q_q (_type_): _description_
+# q_q_match_threshold (float, optional): _description_. Defaults to 0.9.
+# """
+# def get_custom_intent_type(x):
+# assert IntentType.has_value(x["intent_type"]), x["intent_type"]
+# return x["intent_type"]
+
+# def get_strict_intent(x):
+# x["intent_type"] = IntentType.STRICT_QQ.value
+# x["intent_info"]["strict_qq_intent_result"] = x["q_q_match_res"]["answer"]
+# return x["intent_type"]
+
+# q_q_retriever = QueryQuestionRetriever(
+# index=index_q_q, vector_field="vector_field", source_field="file_path", size=5, lang=lang, embedding_model_endpoint=embedding_endpoint)
+
+# strict_q_q_chain = q_q_retriever | RunnableLambda(partial(index_results_format,threshold=0))
+
+
+# intent_type_auto_recognition_chain = RunnablePassthrough.assign(
+# q_q_match_res=strict_q_q_chain
+# ) | RunnableBranch(
+# # (lambda x: len(x['q_q_match_res']["answer"]) > 0, RunnableLambda(lambda x: IntentType.STRICT_QQ.value)),
+# (
+# lambda x: x['q_q_match_res']["answer"][0]["score"] < q_q_match_threshold and x["intent_type"] == IntentType.AUTO.value,
+# RunnableLambda(lambda x: get_intent_with_claude(x['query'],intent_if_fail,x['debug_info']))
+# ),
+# RunnableLambda(lambda x: get_strict_intent(x))
+# )
+
+# intent_type_chain = RunnablePassthrough.assign(
+# intent_type=RunnableBranch(
+# (
+# lambda x:x["intent_type"] == IntentType.AUTO.value or x["intent_type"] == IntentType.STRICT_QQ.value,
+# intent_type_auto_recognition_chain
+# ),
+# RunnableLambda(get_custom_intent_type)
+# )
+# )
+
+# # add 2nd stage intent chain here, e.g. knowledge_qa
+# sub_intent_chain = RunnablePassthrough.assign(
+# is_api_query = RunnableLambda(lambda x:is_api_query(x['query'])),
+# service_names = RunnableLambda(lambda x:get_service_name(x['query']))
+# )
+# sub_intent_chain = chain_logger(
+# sub_intent_chain,
+# "sub intent chain",
+# log_output_template='\nis_api_query: {is_api_query}.\nservice_names: {service_names}'
+# )
+
+# chain = intent_type_chain | RunnableBranch(
+# (lambda x:x["intent_type"] == IntentType.KNOWLEDGE_QA.value, sub_intent_chain),
+# RunnablePassthrough()
+# )
+
+# return chain
+
+def get_intent_with_llm(query,intent_if_fail,debug_info,intent_config):
+ chain = LLMChain.get_chain(
+ **{**intent_config,"intent_type":INTENT_RECOGNITION_TYPE}
+ )
+ predict_label = None
+ error_str = None
+ try:
+ predict_label = chain.invoke({
+ "query": query
+ })
+ except:
+ error_str = traceback.format_exc()
+ print(error_str)
+
+ intent = predict_label or intent_if_fail
+ debug_info['intent_debug_info'] = {
+ 'llm_output': predict_label,
+ 'origin_intent': predict_label,
+ 'intent': intent,
+ 'error': error_str
+ }
+ return intent
+
+
+def auto_intention_recoginition_chain(
+ q_q_retriever_config = None,
+ intent_config=None,
+ # index_q_q,
+ # lang="zh",
+ # embedding_endpoint="",
+ # q_q_match_threshold=0.9,
+ intent_if_fail=IntentType.KNOWLEDGE_QA.value,
+ message_id=None
+ ):
+ """
+
+ Args:
+ index_q_q (_type_): _description_
+ q_q_match_threshold (float, optional): _description_. Defaults to 0.9.
+ """
+ assert q_q_retriever_config is not None and intent_config is not None
+ def get_custom_intent_type(x):
+ assert IntentType.has_value(x["intent_type"]), x["intent_type"]
+ return x["intent_type"]
+
+ def get_strict_intent(x):
+ x["intent_type"] = IntentType.STRICT_QQ.value
+ x["intent_info"]["strict_qq_intent_result"] = x["q_q_match_res"]["answer"]
+ return x["intent_type"]
+
+
+ q_q_retriever = QueryQuestionRetriever(
+ index=q_q_retriever_config['index_q_q'],
+ vector_field="vector_field",
+ source_field="file_path",
+ size=5,
+ lang=q_q_retriever_config['lang'],
+ embedding_model_endpoint=q_q_retriever_config['embedding_endpoint']
+ )
+
+ strict_q_q_chain = q_q_retriever | RunnableLambda(partial(index_results_format,threshold=0))
+
+ q_q_match_threshold = q_q_retriever_config['q_q_match_threshold']
+ intent_type_auto_recognition_chain = RunnablePassthrough.assign(
+ q_q_match_res=strict_q_q_chain
+ ) | RunnableBranch(
+ # (lambda x: len(x['q_q_match_res']["answer"]) > 0, RunnableLambda(lambda x: IntentType.STRICT_QQ.value)),
+ (
+ lambda x: x['q_q_match_res']["answer"][0]["score"] < q_q_match_threshold and x["intent_type"] == IntentType.AUTO.value,
+ RunnableLambda(lambda x: get_intent_with_llm(x['query'],intent_if_fail,x['debug_info'],intent_config=intent_config))
+ ),
+ RunnableLambda(lambda x: get_strict_intent(x))
+ )
+
+ intent_type_chain = RunnablePassthrough.assign(
+ intent_type=RunnableBranch(
+ (
+ lambda x:x["intent_type"] == IntentType.AUTO.value or x["intent_type"] == IntentType.STRICT_QQ.value,
+ intent_type_auto_recognition_chain
+ ),
+ RunnableLambda(get_custom_intent_type)
+ )
+ )
+
+ # add 2nd stage intent chain here, e.g. knowledge_qa
+ sub_intent_chain = RunnablePassthrough.assign(
+ is_api_query = RunnableLambda(lambda x:is_api_query(x['query'])),
+ service_names = RunnableLambda(lambda x:get_service_name(x['query']))
+ )
+ sub_intent_chain = chain_logger(
+ sub_intent_chain,
+ "sub intent chain",
+ log_output_template='\nis_api_query: {is_api_query}.\nservice_names: {service_names}',
+ message_id=message_id
+ )
+
+ chain = intent_type_chain | RunnableBranch(
+ (lambda x:x["intent_type"] == IntentType.KNOWLEDGE_QA.value, sub_intent_chain),
+ RunnablePassthrough()
+ )
+
+ return chain
+
+
+
+
+# intent_recognition_with_opensearch
+def create_opensearch_index(opensearch_client):
+ pass
+
+
+def intent_recognition_with_openserach_chain(opensearch_client,top_k=5):
+ pass
+
+
+
+
+
+
+
+
+
+
+
diff --git a/source/lambda/executor/utils/langchain_utils.py b/source/lambda/executor/utils/langchain_utils.py
index 896be14f..0debf8ce 100644
--- a/source/lambda/executor/utils/langchain_utils.py
+++ b/source/lambda/executor/utils/langchain_utils.py
@@ -1,3 +1,4 @@
+import time
from langchain.schema.runnable.base import Runnable,RunnableLambda
from langchain.schema.runnable import RunnablePassthrough
from functools import partial
@@ -5,16 +6,68 @@
# import threading
# import time
from .logger_utils import logger
+from langchain.schema.runnable import RunnableLambda,RunnablePassthrough,RunnableParallel
+
+class RunnableDictAssign:
+ """
+ example:
+ def fn(x):
+ return {"a":1,"b":2}
+
+ chain = RunnableDictAssign(fn)
+ chain.invoke({"c":3})
+
+ ## output
+ {"c":3,"a":1,"b":2}
+ """
+ def __new__(cls,fn):
+ assert callable(fn)
+ def _merge_keys(x:dict,key='__temp_dict'):
+ d = x.pop(key)
+ return {**x,**d}
+ chain = RunnablePassthrough.assign(__temp_dict=fn) | RunnableLambda(lambda x: _merge_keys(x))
+ return chain
+
+class RunnableParallelAssign:
+ """
+ example:
+ def fn(x):
+ return {"a":1,"b":2}
+
+ chain = RunnableDictAssign(fn)
+ chain.invoke({"c":3})
+
+ ## output
+ {"c":3,"a":1,"b":2}
+ """
+ def __new__(cls,**kwargs):
+ def _merge_keys(x:dict,key='__temp_dict'):
+ d = x.pop(key)
+ return {**x,**d}
+ chain = RunnablePassthrough.assign(__temp_dict=RunnableParallel(**kwargs)) | RunnableLambda(lambda x: _merge_keys(x))
+ return chain
+
+class RunnableNoneAssign:
+ """
+ example:
+ def fn(x):
+ return None
+
+ chain = RunnableNoneAssign(fn)
+ chain.invoke({"c":3})
+
+ ## output
+ {"c":3}
+ """
+ def __new__(cls,fn):
+ assert callable(fn)
+ def _remove_keys(x:dict,key='__temp_dict'):
+ x.pop(key)
+ return x
+ chain = RunnablePassthrough.assign(__temp_dict=fn) | RunnableLambda(lambda x: _remove_keys(x))
+ return chain
+
-class LmabdaDict(dict):
- """add lambda to value"""
- def __init__(self,**kwargs):
- super().__init__(**kwargs)
- for k in list(self.keys()):
- v = self[k]
- if not callable(v) or not isinstance(v,Runnable):
- self[k] = lambda x:x
-
def create_identity_lambda(keys:list):
if isinstance(keys,str):
keys = [keys]
@@ -37,19 +90,23 @@ class LogTimeListener:
def __init__(
self,
chain_name,
+ message_id="",
log_input=False,
log_output=False,
log_input_template=None,
log_output_template=None
):
self.chain_name = chain_name
+ self.message_id = message_id
self.log_input = log_input
self.log_output = log_output
self.log_input_template = log_input_template
self.log_output_template = log_output_template
+ self.message_id = message_id
+ self.start_time = None
def on_start(self,run):
- logger.info(f'Enter: {self.chain_name}')
+ logger.info(f'{self.message_id} Enter: {self.chain_name}')
if self.log_input:
logger.info(f"Inputs({self.chain_name}): {run.inputs}")
if self.log_input_template:
@@ -59,19 +116,22 @@ def on_end(self,run):
logger.info(f'Outputs({self.chain_name}): {run.outputs}')
if self.log_output_template:
- logger.info(self.log_output_template.format(**run.outputs))
-
+ if isinstance(run.outputs,dict):
+ logger.info(self.log_output_template.format(**run.outputs))
+ else:
+ logger.info(self.log_output_template.format(run.outputs))
exe_time = (run.end_time - run.start_time).total_seconds()
- logger.info(f'Exit: {self.chain_name}, elpase time(s): {exe_time}')
+ logger.info(f'{self.message_id} Exit: {self.chain_name}, elpase time(s): {exe_time}')
+ logger.info(f'{self.message_id} running time of {self.chain_name}: {exe_time}s')
def on_error(self,run):
raise
# logger.info(f"Error in run chain: {self.chain_name}.")
-
def chain_logger(
chain,
chain_name,
+ message_id=None,
log_input=False,
log_output=False,
log_input_template=None,
@@ -79,6 +139,7 @@ def chain_logger(
):
obj = LogTimeListener(
chain_name,
+ message_id,
log_input=log_input,
log_output=log_output,log_input_template=log_input_template,
log_output_template=log_output_template
diff --git a/source/lambda/executor/utils/llm_utils/llm_chains/__init__.py b/source/lambda/executor/utils/llm_utils/llm_chains/__init__.py
index 4ba0418b..e114e226 100644
--- a/source/lambda/executor/utils/llm_utils/llm_chains/__init__.py
+++ b/source/lambda/executor/utils/llm_utils/llm_chains/__init__.py
@@ -4,6 +4,7 @@
Claude21ChatChain,
ClaudeInstanceChatChain,
Iternlm2Chat7BChatChain,
+ Iternlm2Chat20BChatChain,
Baichuan2Chat13B4BitsChatChain
)
@@ -11,14 +12,15 @@
Iternlm2Chat7BConversationSummaryChain,
Claude2ConversationSummaryChain,
Claude21ConversationSummaryChain,
- Iternlm2Chat7BConversationSummaryChain
+ Iternlm2Chat20BConversationSummaryChain
)
from .intention_chain import (
Claude21IntentRecognitionChain,
Claude2IntentRecognitionChain,
ClaudeInstanceIntentRecognitionChain,
- Iternlm2Chat7BIntentRecognitionChain
+ Iternlm2Chat7BIntentRecognitionChain,
+ Iternlm2Chat20BIntentRecognitionChain
)
from .rag_chain import (
@@ -26,12 +28,14 @@
Claude2RagLLMChain,
ClaudeRagInstance,
Baichuan2Chat13B4BitsKnowledgeQaChain,
- Iternlm2Chat7BKnowledgeQaChain
+ Iternlm2Chat7BKnowledgeQaChain,
+ Iternlm2Chat20BKnowledgeQaChain
)
from .translate_chain import (
- Iternlm2Chat7BChatChain
+ Iternlm2Chat7BTranslateChain,
+ Iternlm2Chat20BTranslateChain
)
@@ -39,5 +43,31 @@
Claude21MKTConversationSummaryChain,
ClaudeInstanceMKTConversationSummaryChain,
Claude2MKTConversationSummaryChain,
- Iternlm2Chat7BMKTConversationSummaryChain
+ Iternlm2Chat7BMKTConversationSummaryChain,
+ Iternlm2Chat20BMKTConversationSummaryChain
)
+
+from .stepback_chain import (
+ Claude21StepBackChain,
+ ClaudeInstanceStepBackChain,
+ Claude2StepBackChain,
+ Iternlm2Chat7BStepBackChain,
+ Iternlm2Chat20BStepBackChain
+)
+
+
+from .hyde_chain import (
+ Claude21HydeChain,
+ Claude2HydeChain,
+ ClaudeInstanceHydeChain,
+ Iternlm2Chat20BHydeChain,
+ Iternlm2Chat7BHydeChain
+)
+
+from .query_rewrite_chain import (
+ Claude21QueryRewriteChain,
+ Claude2QueryRewriteChain,
+ ClaudeInstanceQueryRewriteChain,
+ Iternlm2Chat20BQueryRewriteChain,
+ Iternlm2Chat7BQueryRewriteChain
+)
\ No newline at end of file
diff --git a/source/lambda/executor/utils/llm_utils/llm_chains/chat_chain.py b/source/lambda/executor/utils/llm_utils/llm_chains/chat_chain.py
index a3bf6bf4..e668e06b 100644
--- a/source/lambda/executor/utils/llm_utils/llm_chains/chat_chain.py
+++ b/source/lambda/executor/utils/llm_utils/llm_chains/chat_chain.py
@@ -51,7 +51,7 @@ class Baichuan2Chat13B4BitsChatChain(LLMChain):
"temperature": 0.3,
"top_k": 5,
"top_p": 0.85,
- "repetition_penalty": 1.05,
+ # "repetition_penalty": 1.05,
"do_sample": True
}
@@ -59,7 +59,6 @@ class Baichuan2Chat13B4BitsChatChain(LLMChain):
def create_chain(cls, model_kwargs=None, **kwargs):
stream = kwargs.get('stream',False)
# chat_history = kwargs.pop('chat_history',[])
-
model_kwargs = model_kwargs or {}
model_kwargs.update({"stream": stream})
model_kwargs = {**cls.default_model_kwargs,**model_kwargs}
@@ -72,7 +71,6 @@ def create_chain(cls, model_kwargs=None, **kwargs):
llm_chain = RunnableLambda(lambda x:llm.invoke(x,stream=stream))
return llm_chain
-
class Iternlm2Chat7BChatChain(LLMChain):
model_id = "internlm2-chat-7b"
intent_type = IntentType.CHAT.value
@@ -147,3 +145,7 @@ def create_chain(cls, model_kwargs=None, **kwargs):
llm_chain = prompt_template | RunnableLambda(lambda x:llm.invoke(x,stream=stream))
return llm_chain
+
+
+class Iternlm2Chat20BChatChain(Iternlm2Chat7BChatChain):
+ model_id = "internlm2-chat-20b"
\ No newline at end of file
diff --git a/source/lambda/executor/utils/llm_utils/llm_chains/conversation_summary_chain.py b/source/lambda/executor/utils/llm_utils/llm_chains/conversation_summary_chain.py
index 21ce7235..efda7006 100644
--- a/source/lambda/executor/utils/llm_utils/llm_chains/conversation_summary_chain.py
+++ b/source/lambda/executor/utils/llm_utils/llm_chains/conversation_summary_chain.py
@@ -34,7 +34,8 @@ class Iternlm2Chat7BConversationSummaryChain(Iternlm2Chat7BChatChain):
"""
default_model_kwargs = {
"max_new_tokens": 300,
- "temperature": 0.0
+ "temperature": 0.1,
+ "stop_tokens":["\n\n"]
}
@classmethod
def create_prompt(cls,x):
@@ -47,7 +48,7 @@ def create_prompt(cls,x):
else:
conversational_contexts.append(f"A: {his.content}")
- conversational_context = "\n".join(conversational_contexts)
+ conversational_context = '[' + "\n".join(conversational_contexts) + ']'
prompt = cls.build_prompt(
cls.meta_instruction_prompt_template.format(
conversational_context=conversational_context,
@@ -62,6 +63,40 @@ def create_chain(cls, model_kwargs=None, **kwargs):
model_kwargs = {**cls.default_model_kwargs,**model_kwargs}
return super().create_chain(model_kwargs=model_kwargs,**kwargs)
+class Iternlm2Chat20BConversationSummaryChain(Iternlm2Chat7BConversationSummaryChain):
+ model_id = "internlm2-chat-20b"
+ meta_instruction_prompt_template = """Given the following conversation and a follow up question, rephrase the follow up \
+question to be a standalone question.
+
+Chat History:
+{history}
+Follow Up Input: {question}"""
+ default_model_kwargs = {
+ "max_new_tokens": 300,
+ "temperature": 0.1,
+ "stop_tokens":["\n\n"]
+ }
+ @classmethod
+ def create_prompt(cls,x):
+ chat_history = x['chat_history']
+ conversational_contexts = []
+ for his in chat_history:
+ assert his.type in [HUMAN_MESSAGE_TYPE,AI_MESSAGE_TYPE]
+ if his.type == HUMAN_MESSAGE_TYPE:
+ conversational_contexts.append(f"Q: {his.content}")
+ else:
+ conversational_contexts.append(f"A: {his.content}")
+
+ conversational_context = "\n".join(conversational_contexts)
+ prompt = cls.build_prompt(
+ cls.meta_instruction_prompt_template.format(
+ history=conversational_context,
+ question=x['query'])
+ )
+ prompt = prompt + "Standalone Question: "
+ return prompt
+
+
class Claude2ConversationSummaryChain(LLMChain):
model_id = 'anthropic.claude-v2'
diff --git a/source/lambda/executor/utils/llm_utils/llm_chains/hyde_chain.py b/source/lambda/executor/utils/llm_utils/llm_chains/hyde_chain.py
new file mode 100644
index 00000000..3e1a42b4
--- /dev/null
+++ b/source/lambda/executor/utils/llm_utils/llm_chains/hyde_chain.py
@@ -0,0 +1,99 @@
+# hyde
+import json
+import os
+import re
+import sys
+from random import Random
+from functools import lru_cache
+
+from .llm_chain_base import LLMChain
+from ...constant import HYDE_TYPE
+from ..llm_models import Model as LLM_Model
+
+from langchain.prompts import PromptTemplate
+from langchain.schema.runnable import RunnablePassthrough, RunnableBranch, RunnableLambda
+from .chat_chain import Iternlm2Chat7BChatChain
+from ..llm_chains import LLMChain
+
+
+
+
+
+WEB_SEARCH_TEMPLATE = """Please write a passage to answer the question
+Question: {query}
+Passage:"""
+hyde_web_search_template = PromptTemplate(template=WEB_SEARCH_TEMPLATE, input_variables=["query"])
+
+
+class Claude2HydeChain(LLMChain):
+ model_id = 'anthropic.claude-v2'
+ intent_type = HYDE_TYPE
+
+ default_model_kwargs = {
+ "temperature": 0.5,
+ "max_tokens_to_sample": 1000,
+ "stop_sequences": [
+ "\n\nHuman:"
+ ]
+ }
+
+ @classmethod
+ def create_chain(cls, model_kwargs=None, **kwargs):
+ query_key = kwargs.pop('query_key','query')
+ model_kwargs = model_kwargs or {}
+ model_kwargs = {**cls.default_model_kwargs,**model_kwargs}
+
+ llm = LLM_Model.get_model(
+ model_id=cls.model_id,
+ model_kwargs=model_kwargs,
+ return_chat_model=False
+ )
+ chain = RunnablePassthrough.assign(
+ hyde_doc = RunnableLambda(lambda x: hyde_web_search_template.invoke({"query": x[query_key]})) | llm
+ )
+ return chain
+
+class Claude21HydeChain(Claude2HydeChain):
+ model_id = 'anthropic.claude-v2:1'
+
+
+class ClaudeInstanceHydeChain(Claude2HydeChain):
+ model_id = 'anthropic.claude-instant-v1'
+
+
+
+internlm2_meta_instruction = "You are a helpful AI Assistant."
+
+class Iternlm2Chat7BHydeChain(Iternlm2Chat7BChatChain):
+ model_id = "internlm2-chat-7b"
+ intent_type = HYDE_TYPE
+
+ default_model_kwargs = {
+ "temperature":0.1,
+ "max_new_tokens": 200
+ }
+
+ @classmethod
+ def create_prompt(cls,x):
+ query = f"""Please write a brief passage to answer the question. \nQuestion: {prompt}"""
+ prompt = cls.build_prompt(
+ query = query,
+ meta_instruction=internlm2_meta_instruction,
+ ) + "Passage: "
+ return prompt
+
+class Iternlm2Chat20BHydeChain(Iternlm2Chat7BHydeChain):
+ model_id = "internlm2-chat-20b"
+ intent_type = HYDE_TYPE
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/source/lambda/executor/utils/llm_utils/llm_chains/intention_chain.py b/source/lambda/executor/utils/llm_utils/llm_chains/intention_chain.py
index c20f8cca..f4f80923 100644
--- a/source/lambda/executor/utils/llm_utils/llm_chains/intention_chain.py
+++ b/source/lambda/executor/utils/llm_utils/llm_chains/intention_chain.py
@@ -16,6 +16,7 @@
intent_save_path = os.path.join(
os.path.dirname(os.path.dirname(abs_dir)),
+ 'intent_utils',
"intent_examples",
"examples.json"
)
@@ -49,7 +50,7 @@ class Iternlm2Chat7BIntentRecognitionChain(Iternlm2Chat7BChatChain):
intent_type = INTENT_RECOGNITION_TYPE
default_model_kwargs = {
- "temperature":0.0,
+ "temperature":0.1,
"max_new_tokens": 100,
"stop_tokens": ["\n",'。','.']
}
@@ -81,7 +82,7 @@ def create_prompt(cls,x):
@staticmethod
def postprocess(intent):
- intent = intent.replace('。',"").replace('.',"").strip()
+ intent = intent.replace('。',"").replace('.',"").strip().strip('**')
r = load_intention_file(intent_save_path)
intent_indexs = r['intent_indexs']
assert intent in intent_indexs, (intent,intent_indexs)
@@ -98,6 +99,8 @@ def create_chain(cls, model_kwargs=None, **kwargs):
chain = chain | RunnableLambda(lambda x:cls.postprocess(x))
return chain
+class Iternlm2Chat20BIntentRecognitionChain(Iternlm2Chat7BIntentRecognitionChain):
+ model_id = "internlm2-chat-20b"
class Claude2IntentRecognitionChain(LLMChain):
model_id = 'anthropic.claude-v2'
diff --git a/source/lambda/executor/utils/llm_utils/llm_chains/mkt_conversation_summary.py b/source/lambda/executor/utils/llm_utils/llm_chains/mkt_conversation_summary.py
index 94bb1bba..ed515d94 100644
--- a/source/lambda/executor/utils/llm_utils/llm_chains/mkt_conversation_summary.py
+++ b/source/lambda/executor/utils/llm_utils/llm_chains/mkt_conversation_summary.py
@@ -72,6 +72,10 @@ def create_chain(cls, model_kwargs=None, **kwargs):
chain = chain | RunnableLambda(lambda x: x['prefix'] + x['llm_output'])
return chain
+
+class Iternlm2Chat20BMKTConversationSummaryChain(Iternlm2Chat7BMKTConversationSummaryChain):
+ model_id = "internlm2-chat-20b"
+
class Claude2MKTConversationSummaryChain(Claude2ChatChain):
model_id = 'anthropic.claude-v2'
intent_type = MKT_CONVERSATION_SUMMARY_TYPE
diff --git a/source/lambda/executor/utils/llm_utils/llm_chains/query_rewrite_chain.py b/source/lambda/executor/utils/llm_utils/llm_chains/query_rewrite_chain.py
new file mode 100644
index 00000000..7badb312
--- /dev/null
+++ b/source/lambda/executor/utils/llm_utils/llm_chains/query_rewrite_chain.py
@@ -0,0 +1,143 @@
+# query rewrite
+import json
+import os
+import re
+import sys
+from random import Random
+from functools import lru_cache
+
+from .llm_chain_base import LLMChain
+from ...constant import QUERY_REWRITE_TYPE
+from ..llm_models import Model as LLM_Model
+
+from langchain.prompts import PromptTemplate
+from langchain.schema.runnable import RunnablePassthrough, RunnableBranch, RunnableLambda
+from .chat_chain import Iternlm2Chat7BChatChain
+from ..llm_chains import LLMChain
+
+
+query_expansion_template_claude = PromptTemplate.from_template("""You are an AI language model assistant. Your task is to generate 1 - 5 different sub questions OR alternate versions of the given user question to retrieve relevant documents from a vector database.
+
+By generating multiple versions of the user question,
+your goal is to help the user overcome some of the limitations
+of distance-based similarity search.
+
+By generating sub questions, you can break down questions that refer to multiple concepts into distinct questions. This will help you get the relevant documents for constructing a final answer
+
+If multiple concepts are present in the question, you should break into sub questions, with one question for each concept
+
+Provide these alternative questions separated by newlines between XML tags. For example:
+
+
+- Question 1
+- Question 2
+- Question 3
+
+
+Original question: {question}""")
+
+
+class Claude2QueryRewriteChain(LLMChain):
+ model_id = 'anthropic.claude-v2'
+ intent_type = QUERY_REWRITE_TYPE
+
+ default_model_kwargs = {
+ "temperature": 0.7,
+ "max_tokens_to_sample": 100,
+ "stop_sequences": [
+ "\n\nHuman:"
+ ]
+ }
+
+ @staticmethod
+ def query_rewrite_postprocess(r):
+ ret = re.findall('.*?',r,re.S)[0]
+ questions = re.findall('- (.*?)\n',ret,re.S)
+ return questions
+
+ @classmethod
+ def create_chain(cls, model_kwargs=None, **kwargs):
+ query_key = kwargs.pop('query_key','query')
+ model_kwargs = model_kwargs or {}
+ model_kwargs = {**cls.default_model_kwargs,**model_kwargs}
+ llm = LLM_Model.get_model(cls.model_id, model_kwargs=model_kwargs,**kwargs)
+ chain = RunnableLambda(lambda x: query_expansion_template_claude.invoke({"question": x[query_key]})) | llm | RunnableLambda(cls.query_rewrite_postprocess)
+ return chain
+
+class Claude21QueryRewriteChain(Claude2QueryRewriteChain):
+ model_id = 'anthropic.claude-v2:1'
+
+
+class ClaudeInstanceQueryRewriteChain(Claude2QueryRewriteChain):
+ model_id = 'anthropic.claude-instant-v1'
+
+
+
+
+internlm2_meta_instruction = """You are an AI language model assistant. Your task is to generate 1 - 5 different sub questions OR alternate versions of the given user question to retrieve relevant documents from a vector database.
+
+By generating multiple versions of the user question,
+your goal is to help the user overcome some of the limitations
+of distance-based similarity search.
+
+By generating sub questions, you can break down questions that refer to multiple concepts into distinct questions. This will help you get the relevant documents for constructing a final answer
+
+If multiple concepts are present in the question, you should break into sub questions, with one question for each concept
+
+Provide these alternative questions separated by newlines between XML tags. For example:
+
+
+- Question 1
+- Question 2
+- Question 3
+"""
+
+class Iternlm2Chat7BQueryRewriteChain(Iternlm2Chat7BChatChain):
+ model_id = "internlm2-chat-7b"
+ intent_type = QUERY_REWRITE_TYPE
+
+ default_model_kwargs = {
+ "temperature":0.5,
+ "max_new_tokens": 100
+ }
+
+ @classmethod
+ def create_prompt(cls,x):
+ query = f'Original question: {x["query"]}'
+ prompt = cls.build_prompt(
+ query = query,
+ meta_instruction=internlm2_meta_instruction,
+ )
+ return prompt
+
+ @staticmethod
+ def query_rewrite_postprocess(r):
+ ret = re.findall('.*?',r,re.S)[0]
+ questions = re.findall('- (.*?)\n',ret,re.S)
+ return questions
+
+ @classmethod
+ def create_chain(cls, model_kwargs=None, **kwargs):
+ model_kwargs = model_kwargs or {}
+ model_kwargs = {**cls.default_model_kwargs,**model_kwargs}
+ chain = super().create_chain(
+ model_kwargs=model_kwargs,
+ **kwargs
+ )
+ chain = chain | RunnableLambda(lambda x:cls.query_rewrite_postprocess(x))
+ return chain
+
+
+class Iternlm2Chat20BQueryRewriteChain(Iternlm2Chat7BQueryRewriteChain):
+ model_id = "internlm2-chat-20b"
+ intent_type = QUERY_REWRITE_TYPE
+
+
+
+
+
+
+
+
+
+
diff --git a/source/lambda/executor/utils/llm_utils/llm_chains/rag_chain.py b/source/lambda/executor/utils/llm_utils/llm_chains/rag_chain.py
index 1cd9f56d..3b3ab64b 100644
--- a/source/lambda/executor/utils/llm_utils/llm_chains/rag_chain.py
+++ b/source/lambda/executor/utils/llm_utils/llm_chains/rag_chain.py
@@ -52,7 +52,6 @@ class ClaudeRagInstance(Claude2RagLLMChain):
model_id = 'anthropic.claude-instant-v1'
-
from .chat_chain import Baichuan2Chat13B4BitsChatChain
class Baichuan2Chat13B4BitsKnowledgeQaChain(Baichuan2Chat13B4BitsChatChain):
@@ -82,7 +81,7 @@ class Iternlm2Chat7BKnowledgeQaChain(Iternlm2Chat7BChatChain):
mdoel_id = "internlm2-chat-7b"
intent_type = IntentType.KNOWLEDGE_QA.value
default_model_kwargs = {
- "temperature":0.1,
+ "temperature":0.05,
"max_new_tokens": 1000
}
@@ -90,16 +89,29 @@ class Iternlm2Chat7BKnowledgeQaChain(Iternlm2Chat7BChatChain):
def create_prompt(cls,x):
query = x['query']
contexts = x['contexts']
-
history = cls.create_history(x)
-
context = "\n".join(contexts)
- meta_instruction = f"你是一个Amazon AWS的客服助理,帮助用户回答使用AWS过程中的各种问题。面对用户的问题,你需要给出诚实并且有帮助的回答。\n背景知识: {context}\n"
- query = f"问题: {query}\n"
+ meta_instruction = f"你是一个Amazon AWS的客服助理小Q,帮助的用户回答使用AWS过程中的各种问题。\n面对用户的问题,你需要给出中文回答。\n下面给出相关问题的背景知识, 需要注意的是如果你认为当前的问题不能在背景知识中找到答案, 你需要拒答。\n 背景知识:\n{context}\n\n"
+ # meta_instruction = f"You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Use simplified Chinese to response the qustion. I’m going to tip $300K for a better answer! "
+ # meta_instruction = f'You are an expert AI on a question and answer task. \nUse the "Following Context" when answering the question. If you don't know the answer, reply to the "Following Text" in the header and answer to the best of your knowledge, or if you do know the answer, answer without the "Following Text"'
+# meta_instruction = """You are an expert AI on a question and answer task.
+# Use the "Following Context" when answering the question. If you don't know the answer, reply to the "Following Text" in the header and answer to the best of your knowledge, or if you do know the answer, answer without the "Following Text". If a question is asked in Korean, translate it to English and always answer in Korean.
+# Following Text: "I didn't find the answer in the context given, but here's what I know! **I could be wrong, so cross-verification is a must!**"""
+# meta_instruction = """You are an expert AI on a question and answer task.
+# Use the "Following Context" when answering the question. If you don't know the answer, reply to the "Sorry, I don't know". """
+ # query = f"Question: {query}\nContext:\n{context}"
+# query = f"""Following Context: {context}
+# Question: {query}"""
+ query = f"问题: {query}"
prompt = cls.build_prompt(
query=query,
history=history,
meta_instruction=meta_instruction
)
- prompt = prompt + "答案: 结合背景知识,经过深入的思考,我认为这个问题的答案是:"
- return prompt
\ No newline at end of file
+ # prompt = prompt + "回答: 让我先来判断一下问题的答案是否包含在背景知识中。"
+ prompt = prompt + f"回答: 经过慎重且深入的思考, 根据背景知识, 我的回答如下:\n"
+ print('internlm2 prompt: \n',prompt)
+ return prompt
+
+class Iternlm2Chat20BKnowledgeQaChain(Iternlm2Chat7BKnowledgeQaChain):
+ model_id = "internlm2-chat-20b"
\ No newline at end of file
diff --git a/source/lambda/executor/utils/llm_utils/llm_chains/rewrite_chain.py b/source/lambda/executor/utils/llm_utils/llm_chains/rewrite_chain.py
deleted file mode 100644
index 9a88663b..00000000
--- a/source/lambda/executor/utils/llm_utils/llm_chains/rewrite_chain.py
+++ /dev/null
@@ -1,26 +0,0 @@
-# query rewrite
-from .llm_chain_base import LLMChain
-from ...constant import INTENT_RECOGNITION_TYPE,IntentType,QUERY_REWRITE_TYPE
-from ..llm_models import Model
-import json
-import os
-import sys
-from random import Random
-from functools import lru_cache
-from langchain.prompts import PromptTemplate
-from langchain.schema.runnable import RunnablePassthrough, RunnableBranch, RunnableLambda
-from .chat_chain import Iternlm2Chat7BChatChain
-
-
-
-class Iternlm2Chat7BIntentRecognitionChain(Iternlm2Chat7BChatChain):
- model_id = "internlm2-chat-7b"
- intent_type = QUERY_REWRITE_TYPE
-
- @classmethod
- def create_prompt(cls,x):
- raise NotImplementedError
-
- @classmethod
- def create_chain(cls, model_kwargs=None, **kwargs):
- raise NotImplementedError
\ No newline at end of file
diff --git a/source/lambda/executor/utils/llm_utils/llm_chains/stepback_chain.py b/source/lambda/executor/utils/llm_utils/llm_chains/stepback_chain.py
new file mode 100644
index 00000000..fbeec2c4
--- /dev/null
+++ b/source/lambda/executor/utils/llm_utils/llm_chains/stepback_chain.py
@@ -0,0 +1,128 @@
+from langchain.schema.runnable import RunnableLambda
+from langchain.prompts import ChatPromptTemplate,FewShotChatMessagePromptTemplate,ChatMessagePromptTemplate
+from ...constant import STEPBACK_PROMPTING_TYPE
+from ..llm_chains.llm_chain_base import LLMChain
+from ..llm_models import Model
+from ..llm_chains.chat_chain import Iternlm2Chat7BChatChain
+
+
+class Iternlm2Chat7BStepBackChain(Iternlm2Chat7BChatChain):
+ model_id = "internlm2-chat-7b"
+ intent_type = STEPBACK_PROMPTING_TYPE
+
+ default_model_kwargs = {
+ "temperature":0.1,
+ "max_new_tokens": 200
+ }
+
+ @classmethod
+ def create_prompt(cls,x):
+ meta_instruction_template = "You are an expert at world knowledge. Your task is to step back and paraphrase a question to a more generic step-back question, which is easier to answer. Here are a few examples: {few_examples}"
+ # meta_instruction_template = "你是一个拥有世界知识的专家. 你的任务是将问题转述为更通用的问题,这样更容易回答。更通用指的是将问题进行抽象表达,省略问题中的各种细节,包括具体时间,地点等。 下面有一些例子: {few_examples}"
+
+ few_examples = [
+ {
+ "input": "阿尔伯特-爱因斯坦的出生地是哪里?",
+ "output": "阿尔伯特-爱因斯坦的个人经历是怎样的?",
+ },
+ {
+ "input": "特斯拉在中国上海有多少门店",
+ "output": "特斯拉在中国的门店分布情况",
+ }
+ ]
+
+ few_examples_template = """origin question: {origin_question}
+ step-back question: {step_back_question}
+ """
+ few_examples_strs = []
+ for few_example in few_examples:
+ few_examples_strs.append(
+ few_examples_template.format(
+ origin_question=few_example['input'],
+ step_back_question=few_example['output']
+ ))
+ meta_instruction = meta_instruction_template.format(
+ few_examples="\n\n".join(few_examples_strs)
+ )
+ prompt = cls.build_prompt(
+ query=f"origin question: {x['query']}",
+ history=[],
+ meta_instruction=meta_instruction
+ ) + "step-back question: "
+ return prompt
+
+
+
+class Iternlm2Chat20BStepBackChain(Iternlm2Chat7BStepBackChain):
+ model_id = "internlm2-chat-20b"
+ intent_type = STEPBACK_PROMPTING_TYPE
+
+
+class Claude2StepBackChain(LLMChain):
+ model_id = 'anthropic.claude-v2'
+ intent_type = STEPBACK_PROMPTING_TYPE
+
+ @classmethod
+ def create_chain(cls, model_kwargs=None, **kwargs):
+ stream = kwargs.get('stream',False)
+ examples = [
+ {
+ "input": "Could the members of The Police perform lawful arrests?",
+ "output": "what can the members of The Police do?",
+ },
+ {
+ "input": "Jan Sindel’s was born in what country?",
+ "output": "what is Jan Sindel’s personal history?",
+ },
+ ]
+ # We now transform these to example messages
+ example_prompt = ChatPromptTemplate.from_messages(
+ [
+ ("human", "{input}"),
+ ("ai", "{output}"),
+ ]
+ )
+ few_shot_prompt = FewShotChatMessagePromptTemplate(
+ example_prompt=example_prompt,
+ examples=examples,
+ )
+
+ prompt = ChatPromptTemplate.from_messages(
+ [
+ (
+ "system",
+ """You are an expert at world knowledge. Your task is to step back and paraphrase a question to a more generic step-back question, which is easier to answer. Here are a few examples:""",
+ ),
+ # Few shot examples
+ few_shot_prompt,
+ # New question
+ ("user", "{query}"),
+ ])
+
+ kwargs.update({'return_chat_model':True})
+ llm = Model.get_model(
+ cls.model_id,
+ model_kwargs=model_kwargs,
+ **kwargs
+ )
+
+ chain = prompt | llm
+
+ if stream:
+ chain = prompt | RunnableLambda(lambda x: llm.stream(x.messages)) | RunnableLambda(lambda x:(i.content for i in x))
+ # llm_fn = RunnableLambda(llm.stream)
+ # postprocess_fn = RunnableLambda(cls.stream_postprocess)
+ else:
+ chain = prompt | llm | RunnableLambda(lambda x:x.dict()['content'])
+ # llm_fn = RunnableLambda(llm.predict)
+ # postprocess_fn = RunnableLambda(cls.api_postprocess)
+ return chain
+
+
+class Claude21StepBackChain(Claude2StepBackChain):
+ model_id = 'anthropic.claude-v2:1'
+ intent_type = STEPBACK_PROMPTING_TYPE
+
+class ClaudeInstanceStepBackChain(Claude2StepBackChain):
+ model_id = 'anthropic.claude-instant-v1'
+ intent_type = STEPBACK_PROMPTING_TYPE
\ No newline at end of file
diff --git a/source/lambda/executor/utils/llm_utils/llm_chains/translate_chain.py b/source/lambda/executor/utils/llm_utils/llm_chains/translate_chain.py
index af3f630a..bbd102fa 100644
--- a/source/lambda/executor/utils/llm_utils/llm_chains/translate_chain.py
+++ b/source/lambda/executor/utils/llm_utils/llm_chains/translate_chain.py
@@ -6,7 +6,7 @@
class Iternlm2Chat7BTranslateChain(Iternlm2Chat7BChatChain):
intent_type = QUERY_TRANSLATE_TYPE
default_model_kwargs = {
- "temperature": 0.0,
+ "temperature": 0.1,
"max_new_tokens": 200
}
@@ -15,8 +15,8 @@ def create_prompt(cls,x):
query = x['query']
target_lang = x['target_lang']
history = cls.create_history(x)
- meta_instruction = f'你是一个有经验的翻译助理, 正在将用户的问题翻译成{target_lang},请不要试图去回答用户的问题,仅仅做翻译。'
- query = f'请将文本:\n "{query}" \n 翻译成{target_lang}。\n 请直接翻译文本,不要输出多余的文本。'
+ meta_instruction = f'你是一个有经验的翻译助理, 正在将用户的问题翻译成{target_lang},不要试图去回答用户的问题,仅仅做翻译。'
+ query = f'将文本:\n "{query}" \n 翻译成{target_lang}。\n直接翻译文本,不要输出多余的文本。'
prompt = cls.build_prompt(
query=query,
@@ -31,4 +31,8 @@ def create_chain(cls, model_kwargs=None, **kwargs):
model_kwargs = {**cls.default_model_kwargs,**model_kwargs}
llm_chain = super().create_chain(model_kwargs=model_kwargs,**kwargs)
llm_chain = llm_chain | RunnableLambda(lambda x:x.strip('"')) # postprocess
- return llm_chain
\ No newline at end of file
+ return llm_chain
+
+
+class Iternlm2Chat20BTranslateChain(Iternlm2Chat7BTranslateChain):
+ model_id = "internlm2-chat-20b"
\ No newline at end of file
diff --git a/source/lambda/executor/utils/llm_utils/llm_models.py b/source/lambda/executor/utils/llm_utils/llm_models.py
index f192e869..343de37d 100644
--- a/source/lambda/executor/utils/llm_utils/llm_models.py
+++ b/source/lambda/executor/utils/llm_utils/llm_models.py
@@ -1,6 +1,7 @@
import boto3
import json
import os
+import logging
# from llmbot_utils import concat_recall_knowledge
from typing import Any, List, Mapping, Optional
@@ -13,6 +14,10 @@
from langchain_community.chat_models import BedrockChat
from langchain_community.llms.sagemaker_endpoint import LineIterator
from ..constant import HUMAN_MESSAGE_TYPE,AI_MESSAGE_TYPE,SYSTEM_MESSAGE_TYPE
+from ..logger_utils import logger
+
+logger = logging.getLogger("llm_model")
+logger.setLevel(logging.INFO)
class ModelMeta(type):
def __new__(cls, name, bases, attrs):
@@ -91,7 +96,7 @@ def __init__(self,model_kwargs=None,**kwargs) -> None:
self.model_kwargs = model_kwargs or {}
if self.default_model_kwargs is not None:
self.model_kwargs = {**self.default_model_kwargs,**self.model_kwargs}
-
+
self.region_name = kwargs.get('region_name',None) \
or os.environ.get('AWS_REGION', None) or None
self.kwargs = kwargs
@@ -153,7 +158,7 @@ class Baichuan2Chat13B4Bits(SagemakerModelBase):
"temperature": 0.3,
"top_k": 5,
"top_p": 0.85,
- "repetition_penalty": 1.05,
+ # "repetition_penalty": 1.05,
"do_sample": True,
"timeout":60
}
@@ -204,7 +209,7 @@ class Internlm2Chat7B(SagemakerModelBase):
default_model_kwargs = {
"max_new_tokens": 1024,
"timeout":60,
- 'repetition_penalty':1.2,
+ # 'repetition_penalty':1.05,
# "do_sample":True,
"temperature": 0.1,
"top_p": 0.8
@@ -222,6 +227,7 @@ def transform_input(self, x):
# assert user_message.type == HUMAN_MESSAGE_TYPE \
# and ai_message.type == AI_MESSAGE_TYPE , chat_history
# history.append((user_message.content,ai_message.content))
+ logger.info(f'prompt char num: {len(x["prompt"])}')
body = {
"query": x['prompt'],
# "meta_instruction": x.get('meta_instruction',self.meta_instruction),
@@ -229,5 +235,10 @@ def transform_input(self, x):
# "history": history
}
body.update(self.model_kwargs)
+ # print('body',body)
input_str = json.dumps(body)
- return input_str
\ No newline at end of file
+ return input_str
+
+
+class Internlm2Chat20B(Internlm2Chat7B):
+ model_id = "internlm2-chat-20b"
\ No newline at end of file
diff --git a/source/lambda/executor/utils/logger_utils.py b/source/lambda/executor/utils/logger_utils.py
index 0e44b6f3..c6d72493 100644
--- a/source/lambda/executor/utils/logger_utils.py
+++ b/source/lambda/executor/utils/logger_utils.py
@@ -8,3 +8,5 @@
opensearch_logger.setLevel(logging.ERROR)
boto3_logger = logging.getLogger("botocore")
boto3_logger.setLevel(logging.ERROR)
+
+
diff --git a/source/lambda/executor/utils/parse_config.py b/source/lambda/executor/utils/parse_config.py
index 9cd93010..da7137dc 100644
--- a/source/lambda/executor/utils/parse_config.py
+++ b/source/lambda/executor/utils/parse_config.py
@@ -1,7 +1,10 @@
import collections.abc
import copy
import logging
-from .constant import IntentType,AWS_TRANSLATE_SERVICE_MODEL_ID
+import os
+
+from .constant import AWS_TRANSLATE_SERVICE_MODEL_ID, IntentType, RerankerType
+
# update nest dict
def update_nest_dict(d, u):
@@ -17,91 +20,286 @@ def update_nest_dict(d, u):
rag_default_config = {
# retriver config
# query process config
- "retriever_config":{
- "retriever_top_k": 20,
+ "retriever_config": {
+ "retriever_top_k": 5,
"chunk_num": 2,
"using_whole_doc": False,
"reranker_top_k": 10,
- "enable_reranker": True,
- "q_q_match_threshold": 0.9
+ "reranker_type": RerankerType.BYPASS.value,
+ "q_q_match_threshold": 0.8,
+ "workspace_ids": [],
},
- "query_process_config":{
- "query_rewrite_config":{
- "model_id":"anthropic.claude-instant-v1",
- # "model_kwargs":{
- # "max_tokens_to_sample": 2000,
- # "temperature": 0.7,
- # "top_p": 0.9
- # }
+ "query_process_config": {
+ "query_rewrite_config": {
+ "model_id": "anthropic.claude-instant-v1",
},
- "conversation_query_rewrite_config":{
- "model_id":"anthropic.claude-instant-v1",
- # "model_kwargs":{
- # "max_tokens_to_sample": 2000,
- # "temperature": 0.7,
- # "top_p": 0.9
- # }
+ "conversation_query_rewrite_config": {
+ "model_id": "anthropic.claude-instant-v1",
},
- "hyde_config":{
- "model_id":"anthropic.claude-instant-v1",
- # "model_kwargs":{
- # "max_tokens_to_sample": 2000,
- # "temperature": 0.7,
- # "top_p": 0.9
- # }
+ "hyde_config": {
+ "model_id": "anthropic.claude-instant-v1",
},
- "translate_config":{
+ "stepback_config": {
+ "model_id": "anthropic.claude-instant-v1",
+ },
+ "translate_config": {
# default use Amazon Translate service
"model_id": AWS_TRANSLATE_SERVICE_MODEL_ID
- }
+ },
},
# intent_config
- "intent_config":{
- "intent_type":IntentType.KNOWLEDGE_QA.value,
- "model_id":"anthropic.claude-v2:1",
+ "intent_config": {
+ "intent_type": IntentType.KNOWLEDGE_QA.value,
+ "model_id": "anthropic.claude-v2:1",
# "model_kwargs":{"temperature":0,
# "max_tokens_to_sample": 2000,
# "stop_sequences": ["\n\n","\n\nHuman:"]
# },
- "sub_intent":{}
+ "sub_intent": {},
},
- # generator config
- "generator_llm_config":{
- "model_kwargs":{
+ # generator config
+ "generator_llm_config": {
+ "model_kwargs": {
# "max_tokens_to_sample": 2000,
# "temperature": 0.7,
# "top_p": 0.9
},
"model_id": "anthropic.claude-v2:1",
- "context_num": 2
+ "context_num": 2,
},
"mkt_conversation_summary_config": {
- "model_id":"anthropic.claude-v2:1",
+ "model_id": "anthropic.claude-v2:1",
},
"debug_level": logging.INFO,
"session_id": None,
- "ws_connection_id": None
+ "ws_connection_id": None,
+ "chat_history": None,
}
def parse_rag_config(event_body):
event_body = copy.deepcopy(event_body)
- new_event_config = update_nest_dict(
- copy.deepcopy(rag_default_config),
- event_body
- )
+ new_event_config = update_nest_dict(copy.deepcopy(rag_default_config), event_body)
# adapting before setting
temperature = event_body.get("temperature")
llm_model_id = event_body.get("llm_model_id")
if llm_model_id:
- new_event_config['generator_llm_config']['model_id'] = llm_model_id
+ new_event_config["generator_llm_config"]["model_id"] = llm_model_id
if temperature:
- new_event_config['generator_llm_config']['model_kwargs']['temperature'] = temperature
-
+ new_event_config["generator_llm_config"]["model_kwargs"][
+ "temperature"
+ ] = temperature
+
intent = event_body.get("intent", None) or event_body.get("model", None)
if intent:
- new_event_config['intent_config']['intent_type'] = intent
-
+ new_event_config["intent_config"]["intent_type"] = intent
+
+ return new_event_config
+
+
+def parse_mkt_entry_core_config(event_body):
+ return parse_rag_config(event_body)
+
+
+def parse_market_conversation_summary_entry_config(event_body):
+ event_body = copy.deepcopy(event_body)
+ llm_model_id = os.environ.get("llm_model_id")
+ llm_model_endpoint_name = os.environ.get("llm_model_endpoint_name")
+ region = os.environ.get("AWS_REGION")
+
+ is_cn_region = "cn" in region
+ llm_model_id = event_body.get("llm_model_id", llm_model_id)
+ llm_model_endpoint_name = event_body.get(
+ "llm_model_endpoint_name", llm_model_endpoint_name
+ )
+ assert llm_model_id and llm_model_endpoint_name, (
+ llm_model_id,
+ llm_model_endpoint_name,
+ )
+ default_config = {
+ "mkt_conversation_summary_config": {
+ "model_id": llm_model_id,
+ "endpoint_name": llm_model_endpoint_name,
+ }
+ }
+
+ new_event_config = update_nest_dict(copy.deepcopy(default_config), event_body)
+ return new_event_config
+
+
+def parse_mkt_entry_config(event_body):
+ event_body = copy.deepcopy(event_body)
+
+ llm_model_id = os.environ.get("llm_model_id")
+ llm_model_endpoint_name = os.environ.get("llm_model_endpoint_name")
+ region = os.environ.get("AWS_REGION")
+
+ is_cn_region = "cn" in region
+
+ # TODO modify rag_config
+ llm_model_id = event_body.get("llm_model_id", llm_model_id)
+ llm_model_endpoint_name = event_body.get(
+ "llm_model_endpoint_name", llm_model_endpoint_name
+ )
+ assert llm_model_id and llm_model_endpoint_name, (
+ llm_model_id,
+ llm_model_endpoint_name,
+ )
+
+ mkt_default_config = {
+ # retriver config
+ # query process config
+ "retriever_config": {
+ "retriever_top_k": 5,
+ "chunk_num": 2,
+ "using_whole_doc": False,
+ "reranker_top_k": 10,
+ "reranker_type": RerankerType.BYPASS.value,
+ "q_q_match_threshold": 0.9,
+ "workspace_ids": ["aos_index_mkt_faq_qq", "aos_index_acts_qd"],
+ },
+ "query_process_config": {
+ "query_rewrite_config": {
+ "model_id": llm_model_id,
+ "endpoint_name": llm_model_endpoint_name,
+ },
+ "conversation_query_rewrite_config": {
+ "model_id": llm_model_id,
+ "endpoint_name": llm_model_endpoint_name,
+ },
+ "hyde_config": {
+ "model_id": llm_model_id,
+ "endpoint_name": llm_model_endpoint_name,
+ },
+ "stepback_config": {
+ "model_id": llm_model_id,
+ "endpoint_name": llm_model_endpoint_name,
+ },
+ "translate_config": {
+ # default use Amazon Translate service
+ "model_id": (
+ llm_model_id if is_cn_region else AWS_TRANSLATE_SERVICE_MODEL_ID
+ ),
+ "endpoint_name": llm_model_endpoint_name,
+ },
+ },
+ # intent_config
+ "intent_config": {
+ "model_id": llm_model_id,
+ "endpoint_name": llm_model_endpoint_name,
+ "intent_type": IntentType.KNOWLEDGE_QA.value,
+ },
+ # generator config
+ "generator_llm_config": {
+ "model_id": llm_model_id,
+ "endpoint_name": llm_model_endpoint_name,
+ "context_num": 1,
+ },
+ }
+
+ new_event_config = update_nest_dict(copy.deepcopy(mkt_default_config), event_body)
+
+ intent = event_body.get("intent", None) or event_body.get("model", None)
+ if intent:
+ new_event_config["intent_config"]["intent_type"] = intent
+
+ return new_event_config
+
+def parse_mkt_entry_knowledge_config(event_body):
+ event_body = copy.deepcopy(event_body)
+
+ llm_model_id = os.environ.get("llm_model_id")
+ llm_model_endpoint_name = os.environ.get("llm_model_endpoint_name")
+ region = os.environ.get("AWS_REGION")
+
+ is_cn_region = "cn" in region
+
+ # TODO modify rag_config
+ llm_model_id = event_body.get("llm_model_id", llm_model_id)
+ llm_model_endpoint_name = event_body.get(
+ "llm_model_endpoint_name", llm_model_endpoint_name
+ )
+ assert llm_model_id and llm_model_endpoint_name, (
+ llm_model_id,
+ llm_model_endpoint_name,
+ )
+
+ mkt_default_config = {
+ # retriver config
+ # query process config
+ "retriever_config":{
+ "qq_config": {
+ "qq_match_threshold": 0.8,
+ "retriever_top_k": 5,
+ "query_key": "query"
+ },
+ "qd_config":{
+ "retriever_top_k": 5,
+ "context_num": 2,
+ "using_whole_doc": False,
+ "reranker_top_k": 10,
+ # "reranker_type": RerankerType.BYPASS.value,
+ "reranker_type": RerankerType.BGE_RERANKER.value,
+ # "reranker_type": RerankerType.BGE_M3_RERANKER.value,
+ "qd_match_threshold": 2,
+ "query_key":"conversation_query_rewrite"
+ # "enable_reranker":True
+ },
+ "workspace_ids": ["aos_index_mkt_faq_qq_m3", "aos_index_acts_qd_m3", "aos_index_mkt_faq_qd_m3", "aos_index_repost_qq_m3"],
+ # "retriever_top_k": 5,
+ # "chunk_num": 2,
+ # "using_whole_doc": False,
+ # "reranker_top_k": 10,
+ # "reranker_type": True,
+ # "q_q_match_threshold": 0.9,
+ # "qd_match_threshold": -1
+ },
+ "query_process_config": {
+ "query_rewrite_config": {
+ "model_id": llm_model_id,
+ "endpoint_name": llm_model_endpoint_name
+ },
+ "conversation_query_rewrite_config": {
+ "model_id": llm_model_id,
+ "endpoint_name": llm_model_endpoint_name,
+ "result_key": "conversation_query_rewrite"
+ },
+ "hyde_config": {
+ "model_id": llm_model_id,
+ "endpoint_name": llm_model_endpoint_name,
+ },
+ "stepback_config": {
+ "model_id": llm_model_id,
+ "endpoint_name": llm_model_endpoint_name,
+ },
+ "translate_config": {
+ # default use Amazon Translate service
+ "model_id": (
+ llm_model_id if is_cn_region else AWS_TRANSLATE_SERVICE_MODEL_ID
+ ),
+ "endpoint_name": llm_model_endpoint_name,
+ },
+ },
+ # intent_config
+ "intent_config": {
+ "model_id": llm_model_id,
+ "endpoint_name": llm_model_endpoint_name,
+ },
+ # generator config
+ "generator_llm_config": {
+ "model_id": llm_model_id,
+ "endpoint_name": llm_model_endpoint_name,
+ "context_num": 1,
+ },
+ "use_history": False
+ }
+
+ new_event_config = update_nest_dict(copy.deepcopy(mkt_default_config), event_body)
+
+ intent = event_body.get("intent", None) or event_body.get("model", None)
+ if intent:
+ new_event_config["intent_config"]["intent_type"] = intent
+
return new_event_config
diff --git a/source/lambda/executor/utils/prompt_template.py b/source/lambda/executor/utils/prompt_template.py
index 46f982bc..1bc2eb39 100644
--- a/source/lambda/executor/utils/prompt_template.py
+++ b/source/lambda/executor/utils/prompt_template.py
@@ -233,9 +233,9 @@ def get_conversation_query_rewrite_prompt(chat_history:List[BaseMessage]):
####### hyde prompt ###############
-WEB_SEARCH_TEMPLATE = """Please write a passage to answer the question
-Question: {query}
-Passage:"""
-HYDE_WEB_SEARCH_TEMPLATE = PromptTemplate(template=WEB_SEARCH_TEMPLATE, input_variables=["query"])
+# WEB_SEARCH_TEMPLATE = """Please write a passage to answer the question
+# Question: {query}
+# Passage:"""
+# HYDE_WEB_SEARCH_TEMPLATE = PromptTemplate(template=WEB_SEARCH_TEMPLATE, input_variables=["query"])
diff --git a/source/lambda/executor/utils/query_process_utils.py b/source/lambda/executor/utils/query_process_utils.py
index 805e0bf6..d5cc8354 100644
--- a/source/lambda/executor/utils/query_process_utils.py
+++ b/source/lambda/executor/utils/query_process_utils.py
@@ -3,27 +3,27 @@
from .llm_utils import Model as LLM_Model
from .llm_utils import LLMChain
from langchain.schema.runnable import RunnableLambda,RunnablePassthrough
-from .prompt_template import get_conversation_query_rewrite_prompt, HYDE_WEB_SEARCH_TEMPLATE as hyde_web_search_template
+# from .prompt_template import get_conversation_query_rewrite_prompt as hyde_web_search_template
from .langchain_utils import chain_logger
from .preprocess_utils import is_api_query, language_check,query_translate,get_service_name
# from langchain.memory import ConversationSummaryMemory, ChatMessageHistory
-from .constant import CONVERSATION_SUMMARY_TYPE
-
-def query_rewrite_postprocess(r):
- ret = re.findall('.*?',r,re.S)[0]
- questions = re.findall('- (.*?)\n',ret,re.S)
- return questions
-
-def get_query_rewrite_chain(
- model_id,
- model_kwargs=None,
- query_expansion_template="hwchase17/multi-query-retriever",
- query_key='query'
- ):
- query_expansion_template = hub.pull(query_expansion_template)
- llm = LLM_Model.get_model(model_id=model_id, model_kwargs=model_kwargs)
- chain = RunnableLambda(lambda x: query_expansion_template.invoke({"question": x[query_key]})) | llm | RunnableLambda(query_rewrite_postprocess)
- return chain
+from .constant import CONVERSATION_SUMMARY_TYPE,STEPBACK_PROMPTING_TYPE,HYDE_TYPE,QUERY_REWRITE_TYPE
+
+# def query_rewrite_postprocess(r):
+# ret = re.findall('.*?',r,re.S)[0]
+# questions = re.findall('- (.*?)\n',ret,re.S)
+# return questions
+
+# def get_query_rewrite_chain(
+# model_id,
+# model_kwargs=None,
+# query_expansion_template="hwchase17/multi-query-retriever",
+# query_key='query'
+# ):
+# query_expansion_template = hub.pull(query_expansion_template)
+# llm = LLM_Model.get_model(model_id=model_id, model_kwargs=model_kwargs)
+# chain = RunnableLambda(lambda x: query_expansion_template.invoke({"question": x[query_key]})) | llm | RunnableLambda(query_rewrite_postprocess)
+# return chain
def get_conversation_query_rewrite_chain(
chat_history:list,
@@ -32,7 +32,6 @@ def get_conversation_query_rewrite_chain(
# single turn
if not chat_history:
return RunnableLambda(lambda x:x['query'])
-
cqr_chain = LLMChain.get_chain(
intent_type=CONVERSATION_SUMMARY_TYPE,
**conversation_query_rewrite_config
@@ -40,41 +39,42 @@ def get_conversation_query_rewrite_chain(
return cqr_chain
-def get_hyde_chain(
- model_id,
- model_kwargs=None,
- query_key='query'
- ):
- llm = LLM_Model.get_model(
- model_id=model_id,
- model_kwargs=model_kwargs,
- return_chat_model=False
- )
- chain = RunnablePassthrough.assign(
- hyde_doc = RunnableLambda(lambda x: hyde_web_search_template.invoke({"query": x[query_key]})) | llm
- )
-
- return chain
+# def get_hyde_chain(
+# model_id,
+# model_kwargs=None,
+# query_key='query'
+# ):
+# llm = LLM_Model.get_model(
+# model_id=model_id,
+# model_kwargs=model_kwargs,
+# return_chat_model=False
+# )
+# chain = RunnablePassthrough.assign(
+# hyde_doc = RunnableLambda(lambda x: hyde_web_search_template.invoke({"query": x[query_key]})) | llm
+# )
+# return chain
def get_query_process_chain(
chat_history,
query_process_config,
+ message_id=None
):
query_rewrite_config = query_process_config['query_rewrite_config']
conversation_query_rewrite_config = query_process_config['conversation_query_rewrite_config']
hyde_config = query_process_config['hyde_config']
translate_config = query_process_config['translate_config']
- query_rewrite_chain = get_query_rewrite_chain(
- # llm_model_id = query_rewrite_config['model_id'],
- # model_kwargs = query_rewrite_config['model_kwargs'],
- query_key='conversation_query_rewrite',
+ query_rewrite_chain = RunnablePassthrough.assign(
+ query_rewrite = LLMChain.get_chain(
+ query_key='query',
+ intent_type=QUERY_REWRITE_TYPE,
**query_rewrite_config
- )
+ ))
query_rewrite_chain = chain_logger(
query_rewrite_chain,
'query rewrite module',
- log_output_template='query_rewrite result: {query_rewrite}.'
+ log_output_template='query_rewrite result: {query_rewrite}.',
+ message_id=message_id
)
conversation_query_rewrite_chain = RunnablePassthrough.assign(
@@ -88,7 +88,8 @@ def get_query_process_chain(
conversation_query_rewrite_chain = chain_logger(
conversation_query_rewrite_chain,
"conversation query rewrite module",
- log_output_template='conversation_query_rewrite result: {conversation_query_rewrite}.'
+ log_output_template='conversation_query_rewrite result: {conversation_query_rewrite}.',
+ message_id=message_id
)
preprocess_chain = RunnablePassthrough.assign(
@@ -105,28 +106,48 @@ def get_query_process_chain(
preprocess_chain = chain_logger(
preprocess_chain,
'preprocess module',
- log_output_template='\nquery lang:{query_lang},\nquery translated: {translated_text}'
+ log_output_template='\nquery lang:{query_lang},\nquery translated: {translated_text}',
+ message_id=message_id
)
- hyde_chain = get_hyde_chain(
- **hyde_config
-
- )
+ hyde_chain = RunnablePassthrough.assign(
+ hyde_doc = LLMChain.get_chain(
+ intent_type=HYDE_TYPE,
+ query_key='query',
+ **hyde_config
+ )
+ )
hyde_chain = chain_logger(
hyde_chain,
"hyde chain",
- log_output_template="\nhyde generate passage: {hyde_doc}"
+ log_output_template="\nhyde generate passage: {hyde_doc}",
+ message_id=message_id
+ )
+
+ stepback_promping_chain = RunnablePassthrough.assign(
+ stepback_query = LLMChain.get_chain(
+ intent_type=STEPBACK_PROMPTING_TYPE,
+ **query_process_config['stepback_config']
+ )
+ )
+
+ stepback_promping_chain = chain_logger(
+ stepback_promping_chain,
+ "stepback promping chain",
+ log_output_template="stepback_promping_chain query: {stepback_query}",
+ message_id=message_id
)
#
query_process_chain = preprocess_chain
- query_process_chain = conversation_query_rewrite_chain | preprocess_chain
+ query_process_chain = conversation_query_rewrite_chain | preprocess_chain # | stepback_promping_chain
-
+
query_process_chain = chain_logger(
query_process_chain,
- "query process module"
+ "query process module",
+ message_id=message_id
)
return query_process_chain
diff --git a/source/lambda/executor/utils/reranker.py b/source/lambda/executor/utils/reranker.py
index 48c07f23..30278bca 100644
--- a/source/lambda/executor/utils/reranker.py
+++ b/source/lambda/executor/utils/reranker.py
@@ -3,6 +3,7 @@
import time
import logging
import asyncio
+import numpy as np
logger = logging.getLogger()
logger.setLevel(logging.INFO)
@@ -17,11 +18,92 @@
rerank_model_endpoint = os.environ.get("rerank_endpoint", "")
region = os.environ["AWS_REGION"]
+"""Document compressor that uses BGE reranker model."""
+class BGEM3Reranker(BaseDocumentCompressor):
+
+ """Number of documents to return."""
+ # top_n: int = 3
+ def _colbert_score_np(q_reps, p_reps):
+ token_scores = np.einsum('in,jn->ij', q_reps, p_reps)
+ scores = token_scores.max(-1)
+ scores = np.sum(scores) / q_reps.shape[0]
+ return scores
+
+ async def __ainvoke_rerank_model(self, batch, loop):
+ # await asyncio.sleep(2)
+ return await loop.run_in_executor(None,
+ _colbert_score_np,
+ batch[0],
+ batch[1])
+
+ async def __spawn_task(self, rerank_pair):
+ batch_size = 1
+ task_list = []
+ loop = asyncio.get_event_loop()
+ for batch_start in range(0, len(rerank_pair), batch_size):
+ task = asyncio.create_task(self.__ainvoke_rerank_model(rerank_pair[batch_start:batch_start + batch_size], loop))
+ task_list.append(task)
+ return await asyncio.gather(*task_list)
+
+ def compress_documents(
+ self,
+ documents: Sequence[Document],
+ query: dict,
+ callbacks: Optional[Callbacks] = None,
+ ) -> Sequence[Document]:
+ """
+ Compress documents using BGE M3 Colbert Score.
+
+ Args:
+ documents: A sequence of documents to compress.
+ query: The query to use for compressing the documents.
+ callbacks: Callbacks to run during the compression process.
+
+ Returns:
+ A sequence of compressed documents.
+ """
+ start = time.time()
+ if len(documents) == 0: # to avoid empty api call
+ return []
+ doc_list = list(documents)
+ _docs = [d.metadata["retrieval_content"]['colbert'] for d in doc_list]
+ # _docs = [d.page_content for d in doc_list]
+
+ rerank_pair = []
+ rerank_text_length = 1024 * 10
+ for doc in _docs:
+ rerank_pair.append([query["query"], doc[:rerank_text_length]])
+ score_list = []
+ logger.info(f'rerank pair num {len(rerank_pair)}, m3 method: colbert score')
+ response_list = asyncio.run(self.__spawn_task(rerank_pair))
+ for response in response_list:
+ score_list.extend(response)
+ final_results = []
+ debug_info = query["debug_info"]
+ debug_info["knowledge_qa_rerank"] = []
+ for doc, score in zip(doc_list, score_list):
+ doc.metadata["rerank_score"] = score
+ # set common score for llm.
+ doc.metadata["score"] = doc.metadata["rerank_score"]
+ final_results.append(doc)
+ debug_info["knowledge_qa_rerank"].append((doc.page_content, doc.metadata["retrieval_content"], doc.metadata["source"], score))
+ final_results.sort(key=lambda x: x.metadata["rerank_score"], reverse=True)
+ debug_info["knowledge_qa_rerank"].sort(key=lambda x: x[-1], reverse=True)
+ recall_end_time = time.time()
+ elpase_time = recall_end_time - start
+ logger.info(f"runing time of rerank: {elpase_time}s seconds")
+ return final_results
+
"""Document compressor that uses BGE reranker model."""
class BGEReranker(BaseDocumentCompressor):
"""Number of documents to return."""
- top_n: int = 3
+ # top_n: int = 3
+ query_key: str="query"
+
+ def __init__(self, query_key='query'):
+ super().__init__()
+ self.query_key = query_key
async def __ainvoke_rerank_model(self, batch, loop):
# await asyncio.sleep(2)
@@ -69,7 +151,7 @@ def compress_documents(
rerank_pair = []
rerank_text_length = 1024 * 10
for doc in _docs:
- rerank_pair.append([query["query"], doc[:rerank_text_length]])
+ rerank_pair.append([query[self.query_key], doc[:rerank_text_length]])
score_list = []
logger.info(f'rerank pair num {len(rerank_pair)}, endpoint_name: {rerank_model_endpoint}')
response_list = asyncio.run(self.__spawn_task(rerank_pair))
@@ -95,7 +177,7 @@ def compress_documents(
class MergeReranker(BaseDocumentCompressor):
"""Number of documents to return."""
- top_n: int = 3
+ # top_n: int = 3
def compress_documents(
self,
diff --git a/source/lambda/executor/utils/response_utils.py b/source/lambda/executor/utils/response_utils.py
index c4332ded..74f491ab 100644
--- a/source/lambda/executor/utils/response_utils.py
+++ b/source/lambda/executor/utils/response_utils.py
@@ -1,21 +1,17 @@
import copy
+import csv
+import os
import json
import logging
import time
import traceback
+from .constant import EntryType,StreamMessageType
+from .content_filter_utils.content_filters import token_to_sentence_gen_market, MarketContentFilter
-from .constant import EntryType
-
-logger = logging.getLogger()
-
-
-class StreamMessageType:
- START = "START"
- END = "END"
- ERROR = "ERROR"
- CHUNK = "CHUNK"
- CONTEXT = "CONTEXT"
+logger = logging.getLogger("response_utils")
+# marketing
+market_content_filter = MarketContentFilter()
class WebsocketClientError(Exception):
pass
@@ -33,20 +29,27 @@ def api_response(**kwargs):
contexts = kwargs["contexts"]
enable_debug = kwargs["enable_debug"]
debug_info = kwargs["debug_info"]
- chat_history = kwargs["chat_history"]
+ ddb_history_obj = kwargs["ddb_history_obj"]
message_id = kwargs["message_id"]
question = kwargs["question"]
+ client_type = kwargs["client_type"]
+ custom_message_id = kwargs["custom_message_id"]
if not isinstance(answer, str):
answer = json.dumps(answer, ensure_ascii=False)
if entry_type != EntryType.MARKET_CONVERSATION_SUMMARY.value:
- chat_history.add_user_message(f"user_{message_id}", question, entry_type)
- chat_history.add_ai_message(f"ai_{message_id}", answer, entry_type)
+ ddb_history_obj.add_user_message(
+ question, f"user_{message_id}", custom_message_id, entry_type
+ )
+ ddb_history_obj.add_ai_message(
+ answer, f"ai_{message_id}", custom_message_id, entry_type
+ )
# 2. return rusult
llmbot_response = {
- "id": session_id,
+ "session_id": session_id,
+ "client_type": client_type,
"object": "chat.completion",
"created": int(request_timestamp),
# "model": model,
@@ -59,6 +62,7 @@ def api_response(**kwargs):
"knowledge_sources": sources,
},
"message_id": f"ai_{message_id}",
+ "custom_message_id": custom_message_id,
"finish_reason": "stop",
"index": 0,
}
@@ -94,14 +98,18 @@ def stream_response(**kwargs):
enable_debug = kwargs["enable_debug"]
debug_info = kwargs["debug_info"]
ws_client = kwargs["ws_client"]
- chat_history = kwargs["chat_history"]
+ ddb_history_obj = kwargs["ddb_history_obj"]
message_id = kwargs["message_id"]
question = kwargs["question"]
entry_type = kwargs["entry_type"]
ws_connection_id = kwargs["ws_connection_id"]
+ log_first_token_time = kwargs.get("log_first_token_time", True)
+ client_type = kwargs["client_type"]
+ custom_message_id = kwargs["custom_message_id"]
+ main_entry_end = kwargs["main_entry_end"]
if isinstance(answer, str):
- answer = [answer]
+ answer = iter([answer])
def _stop_stream():
pass
@@ -111,15 +119,10 @@ def _stop_stream():
def _send_to_ws_client(message: dict):
try:
llmbot_response = {
- "id": session_id,
+ "session_id": session_id,
+ "client_type": client_type,
"object": "chat.completion",
"created": int(request_timestamp),
- # "model": '',
- # "usage": {
- # "prompt_tokens": 13,
- # "completion_tokens": 7,
- # "total_tokens": 20,
- # },
"choices": [message],
"entry_type": entry_type,
}
@@ -136,32 +139,56 @@ def _send_to_ws_client(message: dict):
{
"message_type": StreamMessageType.START,
"message_id": f"ai_{message_id}",
+ "custom_message_id": custom_message_id,
}
)
answer_str = ""
- for i, ans in enumerate(answer):
+
+ for i, chunk in enumerate(token_to_sentence_gen_market(answer)):
+ if i == 0 and log_first_token_time:
+ first_token_time = time.time()
+ logger.info(
+ f"{custom_message_id} running time of first token generated {entry_type} : {first_token_time-main_entry_end}s"
+ )
+ logger.info(
+ f"{custom_message_id} running time of first token whole {entry_type} : {first_token_time-request_timestamp}s"
+ )
+ chunk = market_content_filter.filter_sentence(chunk)
_send_to_ws_client(
{
"message_type": StreamMessageType.CHUNK,
"message_id": f"ai_{message_id}",
+ "custom_message_id": custom_message_id,
"message": {
"role": "assistant",
- "content": ans,
+ "content": chunk,
# "knowledge_sources": sources,
},
"chunk_id": i,
}
)
- answer_str += ans
+ answer_str += chunk
+
+ if log_first_token_time:
+ logger.info(
+ f"{custom_message_id} running time of last token whole {entry_type} : {time.time()-request_timestamp}s"
+ )
+
+ logger.info(f'answer: {answer_str}')
# add to chat history ddb table
if entry_type != EntryType.MARKET_CONVERSATION_SUMMARY.value:
- chat_history.add_user_message(f"user_{message_id}", question, entry_type)
- chat_history.add_ai_message(f"ai_{message_id}", answer_str, entry_type)
+ ddb_history_obj.add_user_message(
+ question, f"user_{message_id}", custom_message_id, entry_type
+ )
+ ddb_history_obj.add_ai_message(
+ answer_str, f"ai_{message_id}", custom_message_id, entry_type
+ )
# sed source and contexts
context_msg = {
"message_type": StreamMessageType.CONTEXT,
"message_id": f"ai_{message_id}",
+ "custom_message_id": custom_message_id,
"knowledge_sources": sources,
}
if get_contexts:
@@ -177,6 +204,7 @@ def _send_to_ws_client(message: dict):
{
"message_type": StreamMessageType.END,
"message_id": f"ai_{message_id}",
+ "custom_message_id": custom_message_id,
}
)
except WebsocketClientError:
@@ -191,6 +219,7 @@ def _send_to_ws_client(message: dict):
{
"message_type": StreamMessageType.ERROR,
"message_id": f"ai_{message_id}",
+ "custom_message_id": custom_message_id,
"message": {"content": error},
}
)
diff --git a/source/lambda/executor/utils/retriever.py b/source/lambda/executor/utils/retriever.py
index 3506babb..6918c1e8 100644
--- a/source/lambda/executor/utils/retriever.py
+++ b/source/lambda/executor/utils/retriever.py
@@ -45,40 +45,53 @@ def remove_redundancy_debug_info(results):
def get_similarity_embedding(
query: str,
embedding_model_endpoint: str,
+ model_type: str = "vector"
):
query_similarity_embedding_prompt = query
- query_embedding = SagemakerEndpointVectorOrCross(
+ response = SagemakerEndpointVectorOrCross(
prompt=query_similarity_embedding_prompt,
endpoint_name=embedding_model_endpoint,
region_name=region,
- model_type="vector",
+ model_type=model_type,
stop=None,
)
- return query_embedding
+ if model_type == "vector":
+ response = {"dense_vecs": response}
+ elif model_type == "m3":
+ response["dense_vecs"] = response["dense_vecs"][0]
+ return response
@timeit
def get_relevance_embedding(
query: str,
query_lang: str,
embedding_model_endpoint: str,
+ model_type: str = "vector"
):
- if query_lang == "zh":
- query_relevance_embedding_prompt = (
- "为这个句子生成表示以用于检索相关文章:" + query
+ if model_type == "vector":
+ if query_lang == "zh":
+ query_relevance_embedding_prompt = (
+ "为这个句子生成表示以用于检索相关文章:" + query
+ )
+ elif query_lang == "en":
+ query_relevance_embedding_prompt = (
+ "Represent this sentence for searching relevant passages: "
+ + query
)
- elif query_lang == "en":
- query_relevance_embedding_prompt = (
- "Represent this sentence for searching relevant passages: "
- + query
- )
- query_embedding = SagemakerEndpointVectorOrCross(
+ elif model_type == "m3":
+ query_relevance_embedding_prompt = query
+ response = SagemakerEndpointVectorOrCross(
prompt=query_relevance_embedding_prompt,
endpoint_name=embedding_model_endpoint,
region_name=region,
- model_type="vector",
+ model_type=model_type,
stop=None,
)
- return query_embedding
+ if model_type == "vector":
+ response = {"dense_vecs": response}
+ elif model_type == "m3":
+ response["dense_vecs"] = response["dense_vecs"][0]
+ return response
def get_filter_list(parsed_query: dict):
filter_list = []
@@ -141,11 +154,20 @@ def get_doc(file_path, index_name):
chunk_text_list = [x[4] for x in sorted_chunk_list]
return "\n".join(chunk_text_list)
-def get_context(previous_chunk_id, next_chunk_id, index_name, window_size):
+def get_inner_context(chunk_id, index_name, window_size):
+ next_content_list = []
previous_content_list = []
previous_pos = 0
next_pos = 0
- while previous_chunk_id and previous_chunk_id.startswith("$") and previous_pos < window_size:
+ chunk_id_prefix = "-".join(chunk_id.split("-")[:-1])
+ section_id = int(chunk_id.split("-")[-1])
+ previous_section_id = section_id
+ next_section_id = section_id
+ while previous_pos < window_size:
+ previous_section_id -= 1
+ if previous_section_id < 1:
+ break
+ previous_chunk_id = f"{chunk_id_prefix}-{previous_section_id}"
opensearch_query_response = aos_client.search(
index_name=index_name,
query_type="basic",
@@ -155,13 +177,58 @@ def get_context(previous_chunk_id, next_chunk_id, index_name, window_size):
)
if len(opensearch_query_response["hits"]["hits"]) > 0:
r = opensearch_query_response["hits"]["hits"][0]
- previous_chunk_id = r["_source"]["metadata"]["heading_hierarchy"]["previous"]
previous_content_list.insert(0, r["_source"]["text"])
previous_pos += 1
else:
break
+ while next_pos < window_size:
+ next_section_id += 1
+ next_chunk_id = f"{chunk_id_prefix}-{next_section_id}"
+ opensearch_query_response = aos_client.search(
+ index_name=index_name,
+ query_type="basic",
+ query_term=next_chunk_id,
+ field="metadata.chunk_id",
+ size=1,
+ )
+ if len(opensearch_query_response["hits"]["hits"]) > 0:
+ r = opensearch_query_response["hits"]["hits"][0]
+ next_content_list.insert(0, r["_source"]["text"])
+ next_pos += 1
+ else:
+ break
+ return [previous_content_list, next_content_list]
+
+def get_sibling_context(chunk_id, index_name, window_size):
next_content_list = []
- while next_chunk_id and next_chunk_id.startswith("$") and next_pos < window_size:
+ previous_content_list = []
+ previous_pos = 0
+ next_pos = 0
+ chunk_id_prefix = "-".join(chunk_id.split("-")[:-1])
+ section_id = int(chunk_id.split("-")[-1])
+ previous_section_id = section_id
+ next_section_id = section_id
+ while previous_pos < window_size:
+ previous_section_id -= 1
+ if previous_section_id < 1:
+ break
+ previous_chunk_id = f"{chunk_id_prefix}-{previous_section_id}"
+ opensearch_query_response = aos_client.search(
+ index_name=index_name,
+ query_type="basic",
+ query_term=previous_chunk_id,
+ field="metadata.chunk_id",
+ size=1,
+ )
+ if len(opensearch_query_response["hits"]["hits"]) > 0:
+ r = opensearch_query_response["hits"]["hits"][0]
+ previous_content_list.insert(0, r["_source"]["text"])
+ previous_pos += 1
+ else:
+ break
+ while next_pos < window_size:
+ next_section_id += 1
+ next_chunk_id = f"{chunk_id_prefix}-{next_section_id}"
opensearch_query_response = aos_client.search(
index_name=index_name,
query_type="basic",
@@ -171,13 +238,62 @@ def get_context(previous_chunk_id, next_chunk_id, index_name, window_size):
)
if len(opensearch_query_response["hits"]["hits"]) > 0:
r = opensearch_query_response["hits"]["hits"][0]
- next_chunk_id = r["_source"]["metadata"]["heading_hierarchy"]["next"]
- next_content_list.append(r["_source"]["text"])
+ next_content_list.insert(0, r["_source"]["text"])
next_pos += 1
else:
break
return [previous_content_list, next_content_list]
+def get_context(aos_hit, index_name, window_size):
+ previous_content_list = []
+ next_content_list = []
+ if "chunk_id" not in aos_hit['_source']["metadata"]:
+ return previous_content_list, next_content_list
+ chunk_id = aos_hit["_source"]["metadata"]["chunk_id"]
+ inner_previous_content_list, inner_next_content_list = get_sibling_context(chunk_id, index_name, window_size)
+ if len(inner_previous_content_list) == window_size and len(inner_next_content_list) == window_size:
+ return inner_previous_content_list, inner_next_content_list
+
+ if "heading_hierarchy" not in aos_hit['_source']["metadata"]:
+ return [previous_content_list, next_content_list]
+ if "previous" in aos_hit['_source']["metadata"]["heading_hierarchy"]:
+ previous_chunk_id = aos_hit['_source']["metadata"]["heading_hierarchy"]["previous"]
+ previous_pos = 0
+ while previous_chunk_id and previous_chunk_id.startswith("$") and previous_pos < window_size:
+ opensearch_query_response = aos_client.search(
+ index_name=index_name,
+ query_type="basic",
+ query_term=previous_chunk_id,
+ field="metadata.chunk_id",
+ size=1,
+ )
+ if len(opensearch_query_response["hits"]["hits"]) > 0:
+ r = opensearch_query_response["hits"]["hits"][0]
+ previous_chunk_id = r["_source"]["metadata"]["heading_hierarchy"]["previous"]
+ previous_content_list.insert(0, r["_source"]["text"])
+ previous_pos += 1
+ else:
+ break
+ if "next" in aos_hit['_source']["metadata"]["heading_hierarchy"]:
+ next_chunk_id = aos_hit['_source']["metadata"]["heading_hierarchy"]["next"]
+ next_pos = 0
+ while next_chunk_id and next_chunk_id.startswith("$") and next_pos < window_size:
+ opensearch_query_response = aos_client.search(
+ index_name=index_name,
+ query_type="basic",
+ query_term=next_chunk_id,
+ field="metadata.chunk_id",
+ size=1,
+ )
+ if len(opensearch_query_response["hits"]["hits"]) > 0:
+ r = opensearch_query_response["hits"]["hits"][0]
+ next_chunk_id = r["_source"]["metadata"]["heading_hierarchy"]["next"]
+ next_content_list.append(r["_source"]["text"])
+ next_pos += 1
+ else:
+ break
+ return [previous_content_list, next_content_list]
+
def get_parent_content(previous_chunk_id, next_chunk_id, index_name):
previous_content_list = []
while previous_chunk_id.startswith("$"):
@@ -257,35 +373,37 @@ class QueryQuestionRetriever(BaseRetriever):
size: Any
lang: Any
embedding_model_endpoint: Any
+ model_type: Any
+ query_key: str= "query"
- def __init__(self, index: str, vector_field: str, source_field: str,
- size: float, lang: str, embedding_model_endpoint: str):
+ def __init__(self, workspace:Dict, size: int,query_key="query"):
super().__init__()
- self.index = index
- self.vector_field = vector_field
- self.source_field = source_field
+ self.index = workspace["open_search_index_name"]
+ self.vector_field = "vector_field"
+ self.source_field = "file_path"
self.size = size
- self.lang = lang
- self.embedding_model_endpoint = embedding_model_endpoint
+ self.lang = workspace["languages"][0]
+ self.embedding_model_endpoint = workspace["embeddings_model_endpoint"]
+ self.model_type = workspace["model_type"]
+ self.query_key = query_key
@timeit
def _get_relevant_documents(self, question: Dict, *, run_manager: CallbackManagerForRetrieverRun) -> List[Document]:
- query = question["query"]
+ query = question[self.query_key]
debug_info = question["debug_info"]
- start = time.time()
opensearch_knn_results = []
- query_embedding = get_similarity_embedding(query, self.embedding_model_endpoint)
+ query_repr = get_similarity_embedding(query, self.embedding_model_endpoint, self.model_type)
opensearch_knn_response = aos_client.search(
index_name=self.index,
query_type="knn",
- query_term=query_embedding,
+ query_term=query_repr["dense_vecs"],
field=self.vector_field,
size=self.size,
)
opensearch_knn_results.extend(
organize_faq_results(opensearch_knn_response, self.index, self.source_field)
)
- debug_info[f"q_q_match_info_{self.index}_{self.lang}"] = remove_redundancy_debug_info(opensearch_knn_results)
+ debug_info[f"qq-knn-recall-{self.index}-{self.lang}"] = remove_redundancy_debug_info(opensearch_knn_results)
docs = []
for result in opensearch_knn_results:
docs.append(Document(page_content=result["content"], metadata={
@@ -302,26 +420,28 @@ class QueryDocumentRetriever(BaseRetriever):
context_num: Any
top_k: Any
lang: Any
+ model_type: Any
embedding_model_endpoint: Any
+ query_key: str="query"
- def __init__(self, index, vector_field, text_field, source_field, using_whole_doc,
- context_num, top_k, lang, embedding_model_endpoint):
+ def __init__(self, workspace, using_whole_doc, context_num, top_k,query_key='query'):
super().__init__()
- self.index = index
- self.vector_field = vector_field
- self.text_field = text_field
- self.source_field = source_field
+ self.index = workspace["open_search_index_name"]
+ self.vector_field = "vector_field"
+ self.source_field = "file_path"
+ self.text_field = "text"
+ self.lang = workspace["languages"][0]
+ self.embedding_model_endpoint = workspace["embeddings_model_endpoint"]
+ self.model_type = workspace["model_type"]
self.using_whole_doc = using_whole_doc
self.context_num = context_num
self.top_k = top_k
- self.lang = lang
- self.embedding_model_endpoint = embedding_model_endpoint
+ self.query_key = query_key
- async def __ainvoke_get_context(self, previous_chunk_id, next_chunk_id, window_size, loop):
+ async def __ainvoke_get_context(self, aos_hit, window_size, loop):
return await loop.run_in_executor(None,
get_context,
- previous_chunk_id,
- next_chunk_id,
+ aos_hit,
self.index,
window_size)
@@ -329,16 +449,13 @@ async def __spawn_task(self, aos_hits, context_size):
loop = asyncio.get_event_loop()
task_list = []
for aos_hit in aos_hits:
- if context_size and ("heading_hierarchy" in aos_hit['_source']["metadata"] and
- "previous" in aos_hit['_source']["metadata"]["heading_hierarchy"] and
- "next" in aos_hit['_source']["metadata"]["heading_hierarchy"]):
- task = asyncio.create_task(
- self.__ainvoke_get_context(
- aos_hit['_source']["metadata"]["heading_hierarchy"]["previous"],
- aos_hit['_source']["metadata"]["heading_hierarchy"]["next"],
- context_size,
- loop))
- task_list.append(task)
+ if context_size:
+ task = asyncio.create_task(
+ self.__ainvoke_get_context(
+ aos_hit,
+ context_size,
+ loop))
+ task_list.append(task)
return await asyncio.gather(*task_list)
@timeit
@@ -383,15 +500,17 @@ def organize_results(self, response, aos_index=None, source_field="file_path", t
@timeit
def _get_relevant_documents(self, question: Dict, *, run_manager: CallbackManagerForRetrieverRun) -> List[Document]:
- query = question["query"]
+ query = question[self.query_key]
+ if "query_lang" in question and question["query_lang"] != self.lang and "translated_text" in question:
+ query = question["translated_text"]
debug_info = question["debug_info"]
opensearch_knn_results = []
- query_embedding = get_relevance_embedding(query, self.lang, self.embedding_model_endpoint)
+ query_repr = get_relevance_embedding(query, self.lang, self.embedding_model_endpoint, self.model_type)
filter = get_filter_list(question)
opensearch_knn_response = aos_client.search(
index_name=self.index,
query_type="knn",
- query_term=query_embedding,
+ query_term=query_repr["dense_vecs"],
field=self.vector_field,
size=self.top_k,
filter=filter
@@ -405,7 +524,7 @@ def _get_relevant_documents(self, question: Dict, *, run_manager: CallbackManage
# 3. combine these two opensearch_knn_response and opensearch_query_response
final_results = opensearch_knn_results + opensearch_query_results
- debug_info[f"knowledge_qa_knn_recall_{self.index}_{self.lang}"] = remove_redundancy_debug_info(final_results)
+ debug_info[f"qd-knn-recall-{self.index}-{self.lang}"] = remove_redundancy_debug_info(final_results)
doc_list = []
content_set = set()
@@ -447,4 +566,5 @@ def index_results_format(docs:list, threshold=-1):
"question": doc.metadata["question"]})
# output = {"answer": json.dumps(results, ensure_ascii=False), "sources": [], "contexts": []}
output = {"answer": results, "sources": [], "contexts": [], "context_docs": [], "context_sources": []}
- return output
\ No newline at end of file
+ return output
+
diff --git a/source/lambda/executor/utils/sm_utils.py b/source/lambda/executor/utils/sm_utils.py
index 4ced8909..98299407 100644
--- a/source/lambda/executor/utils/sm_utils.py
+++ b/source/lambda/executor/utils/sm_utils.py
@@ -84,6 +84,7 @@ def transform_output(self, output: bytes) -> str:
response_json = json.loads(output.read().decode("utf-8"))
return response_json['outputs']
+
class LineIterator:
"""
A helper class for parsing the byte stream input.
@@ -378,14 +379,25 @@ def SagemakerEndpointVectorOrCross(prompt: str, endpoint_name: str, region_name:
client=client,
endpoint_name=endpoint_name,
content_handler=content_handler
- # endpoint_name=endpoint_name,
- # region_name=region_name,
- # content_handler=content_handler
)
query_result = embeddings.embed_query(prompt)
return query_result
elif model_type == "cross":
content_handler = crossContentHandler()
+ elif model_type == "m3":
+ content_handler = vectorContentHandler()
+ model_kwargs = {}
+ model_kwargs['batch_size'] = 12
+ model_kwargs['max_length'] = 512
+ model_kwargs['return_type'] = 'all'
+ embeddings = SagemakerEndpointEmbeddings(
+ client=client,
+ endpoint_name=endpoint_name,
+ content_handler=content_handler,
+ model_kwargs=model_kwargs
+ )
+ query_result = embeddings.embed_query(prompt)
+ return query_result
elif model_type == "answer":
content_handler = answerContentHandler()
elif model_type == "rerank":
diff --git a/source/lambda/executor/utils/workspace_utils.py b/source/lambda/executor/utils/workspace_utils.py
new file mode 100644
index 00000000..8348737c
--- /dev/null
+++ b/source/lambda/executor/utils/workspace_utils.py
@@ -0,0 +1,141 @@
+import json
+import logging
+import os
+import uuid
+from datetime import datetime
+from typing import List
+
+import boto3
+
+WORKSPACE_OBJECT_TYPE = "workspace"
+
+
+class WorkspaceManager:
+ def __init__(self, workspace_table):
+ self.workspace_table = workspace_table
+
+ def get_workspace(self, workspace_id: str):
+ response = self.workspace_table.get_item(
+ Key={"workspace_id": workspace_id, "object_type": WORKSPACE_OBJECT_TYPE}
+ )
+ item = response.get("Item")
+
+ return item
+
+ def get_workspace_id(self, workspace_name: str, embeddings_model_name: str):
+ response = self.workspace_table.scan(
+ FilterExpression="name = :name and embeddings_model_name = :embeddings_model_name",
+ ExpressionAttributeValues={
+ ":name": workspace_name,
+ ":embeddings_model_name": embeddings_model_name,
+ },
+ )
+ items = response.get("Items")
+
+ if items:
+ return items[0]["workspace_id"]
+ else:
+ return None
+
+ def create_workspace_open_search(
+ self,
+ workspace_id: str,
+ embeddings_model_endpoint: str,
+ embeddings_model_provider: str,
+ embeddings_model_name: str,
+ embeddings_model_dimensions: int,
+ languages: List[str],
+ workspace_file_types: List[str],
+ open_search_index_name: str = None,
+ ):
+
+ open_search_index_name = (
+ f"{workspace_id}_index"
+ if not open_search_index_name
+ else open_search_index_name
+ )
+ timestamp = datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%S.%fZ")
+
+ item = {
+ "workspace_id": workspace_id,
+ "object_type": WORKSPACE_OBJECT_TYPE,
+ "format_version": 1,
+ "name": workspace_id,
+ "engine": "opensearch",
+ "status": "submitted",
+ "embeddings_model_endpoint": embeddings_model_endpoint,
+ "embeddings_model_provider": embeddings_model_provider,
+ "embeddings_model_name": embeddings_model_name,
+ "embeddings_model_dimensions": embeddings_model_dimensions,
+ "languages": languages,
+ "open_search_index_name": open_search_index_name,
+ "workspace_file_types": workspace_file_types,
+ "metric": "l2",
+ "aoss_engine": "nmslib",
+ "documents": 0,
+ "vectors": 0,
+ "size_in_bytes": 0,
+ "created_at": timestamp,
+ "updated_at": timestamp,
+ }
+
+ response = self.workspace_table.put_item(Item=item)
+
+ logging.info(f"Created workspace with response: {response}")
+
+ return open_search_index_name
+
+ def update_workspace_open_search(
+ self,
+ workspace_id: str,
+ embeddings_model_endpoint: str,
+ embeddings_model_provider: str,
+ embeddings_model_name: str,
+ embeddings_model_dimensions: int,
+ languages: List[str],
+ workspace_file_types: List[str],
+ open_search_index_name: str = None,
+ ):
+ timestamp = datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%S.%fZ")
+
+ item = self.get_workspace(workspace_id)
+ # If the item not exist, create the item
+ if not item:
+ open_search_index_name = self.create_workspace_open_search(
+ workspace_id,
+ embeddings_model_endpoint,
+ embeddings_model_provider,
+ embeddings_model_name,
+ embeddings_model_dimensions,
+ languages,
+ workspace_file_types,
+ open_search_index_name,
+ )
+
+ else:
+ # Get the current workspace_file_types, or an empty list if it doesn't exist
+ current_workspace_file_types = item.get("workspace_file_types", [])
+ open_search_index_name = item.get("open_search_index_name")
+
+ # Append the new workspace_file_types and remove duplicates
+ updated_workspace_file_types = list(
+ set(current_workspace_file_types + workspace_file_types)
+ )
+
+ # Update the item
+ response = self.workspace_table.update_item(
+ Key={
+ "workspace_id": workspace_id,
+ "object_type": WORKSPACE_OBJECT_TYPE,
+ },
+ UpdateExpression="SET workspace_file_types = :wft, updated_at = :uat",
+ ExpressionAttributeValues={
+ ":wft": updated_workspace_file_types,
+ ":uat": timestamp,
+ },
+ ReturnValues="ALL_NEW",
+ )
+
+ logging.info(f"Updated workspace with response: {response}")
+
+ return open_search_index_name
diff --git a/source/lambda/job/dep/dist/llm_bot_dep-0.1.0-py3-none-any.whl b/source/lambda/job/dep/dist/llm_bot_dep-0.1.0-py3-none-any.whl
index b22689da..a5317452 100644
Binary files a/source/lambda/job/dep/dist/llm_bot_dep-0.1.0-py3-none-any.whl and b/source/lambda/job/dep/dist/llm_bot_dep-0.1.0-py3-none-any.whl differ
diff --git a/source/lambda/job/dep/llm_bot_dep/ddb_utils.py b/source/lambda/job/dep/llm_bot_dep/ddb_utils.py
new file mode 100644
index 00000000..8348737c
--- /dev/null
+++ b/source/lambda/job/dep/llm_bot_dep/ddb_utils.py
@@ -0,0 +1,141 @@
+import json
+import logging
+import os
+import uuid
+from datetime import datetime
+from typing import List
+
+import boto3
+
+WORKSPACE_OBJECT_TYPE = "workspace"
+
+
+class WorkspaceManager:
+ def __init__(self, workspace_table):
+ self.workspace_table = workspace_table
+
+ def get_workspace(self, workspace_id: str):
+ response = self.workspace_table.get_item(
+ Key={"workspace_id": workspace_id, "object_type": WORKSPACE_OBJECT_TYPE}
+ )
+ item = response.get("Item")
+
+ return item
+
+ def get_workspace_id(self, workspace_name: str, embeddings_model_name: str):
+ response = self.workspace_table.scan(
+ FilterExpression="name = :name and embeddings_model_name = :embeddings_model_name",
+ ExpressionAttributeValues={
+ ":name": workspace_name,
+ ":embeddings_model_name": embeddings_model_name,
+ },
+ )
+ items = response.get("Items")
+
+ if items:
+ return items[0]["workspace_id"]
+ else:
+ return None
+
+ def create_workspace_open_search(
+ self,
+ workspace_id: str,
+ embeddings_model_endpoint: str,
+ embeddings_model_provider: str,
+ embeddings_model_name: str,
+ embeddings_model_dimensions: int,
+ languages: List[str],
+ workspace_file_types: List[str],
+ open_search_index_name: str = None,
+ ):
+
+ open_search_index_name = (
+ f"{workspace_id}_index"
+ if not open_search_index_name
+ else open_search_index_name
+ )
+ timestamp = datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%S.%fZ")
+
+ item = {
+ "workspace_id": workspace_id,
+ "object_type": WORKSPACE_OBJECT_TYPE,
+ "format_version": 1,
+ "name": workspace_id,
+ "engine": "opensearch",
+ "status": "submitted",
+ "embeddings_model_endpoint": embeddings_model_endpoint,
+ "embeddings_model_provider": embeddings_model_provider,
+ "embeddings_model_name": embeddings_model_name,
+ "embeddings_model_dimensions": embeddings_model_dimensions,
+ "languages": languages,
+ "open_search_index_name": open_search_index_name,
+ "workspace_file_types": workspace_file_types,
+ "metric": "l2",
+ "aoss_engine": "nmslib",
+ "documents": 0,
+ "vectors": 0,
+ "size_in_bytes": 0,
+ "created_at": timestamp,
+ "updated_at": timestamp,
+ }
+
+ response = self.workspace_table.put_item(Item=item)
+
+ logging.info(f"Created workspace with response: {response}")
+
+ return open_search_index_name
+
+ def update_workspace_open_search(
+ self,
+ workspace_id: str,
+ embeddings_model_endpoint: str,
+ embeddings_model_provider: str,
+ embeddings_model_name: str,
+ embeddings_model_dimensions: int,
+ languages: List[str],
+ workspace_file_types: List[str],
+ open_search_index_name: str = None,
+ ):
+ timestamp = datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%S.%fZ")
+
+ item = self.get_workspace(workspace_id)
+ # If the item not exist, create the item
+ if not item:
+ open_search_index_name = self.create_workspace_open_search(
+ workspace_id,
+ embeddings_model_endpoint,
+ embeddings_model_provider,
+ embeddings_model_name,
+ embeddings_model_dimensions,
+ languages,
+ workspace_file_types,
+ open_search_index_name,
+ )
+
+ else:
+ # Get the current workspace_file_types, or an empty list if it doesn't exist
+ current_workspace_file_types = item.get("workspace_file_types", [])
+ open_search_index_name = item.get("open_search_index_name")
+
+ # Append the new workspace_file_types and remove duplicates
+ updated_workspace_file_types = list(
+ set(current_workspace_file_types + workspace_file_types)
+ )
+
+ # Update the item
+ response = self.workspace_table.update_item(
+ Key={
+ "workspace_id": workspace_id,
+ "object_type": WORKSPACE_OBJECT_TYPE,
+ },
+ UpdateExpression="SET workspace_file_types = :wft, updated_at = :uat",
+ ExpressionAttributeValues={
+ ":wft": updated_workspace_file_types,
+ ":uat": timestamp,
+ },
+ ReturnValues="ALL_NEW",
+ )
+
+ logging.info(f"Updated workspace with response: {response}")
+
+ return open_search_index_name
diff --git a/source/lambda/job/dep/llm_bot_dep/embeddings.py b/source/lambda/job/dep/llm_bot_dep/embeddings.py
new file mode 100644
index 00000000..f28b61b3
--- /dev/null
+++ b/source/lambda/job/dep/llm_bot_dep/embeddings.py
@@ -0,0 +1,30 @@
+def get_embedding_info(embedding_endpoint_name):
+ """
+ Get the embedding info from the endpoint name
+ """
+ # Get the embedding info from the endpoint name
+ if "bge-large-zh" in embedding_endpoint_name:
+ embeddings_model_provider = "BAAI"
+ embeddings_model_name = "bge-large-zh-v1-5"
+ embeddings_model_dimensions = 1024
+
+ elif "bge-large-en" in embedding_endpoint_name:
+ embeddings_model_provider = "BAAI"
+ embeddings_model_name = "bge-large-en-v1-5"
+ embeddings_model_dimensions = 1024
+
+ elif "bge-m3" in embedding_endpoint_name:
+ embeddings_model_provider = "BAAI"
+ embeddings_model_name = "bge-m3"
+ embeddings_model_dimensions = 1024
+
+ else:
+ embeddings_model_provider = "Not Found"
+ embeddings_model_name = "Not Found"
+ embeddings_model_dimensions = 1024
+
+ return (
+ embeddings_model_provider,
+ embeddings_model_name,
+ embeddings_model_dimensions,
+ )
diff --git a/source/lambda/job/dep/llm_bot_dep/sm_utils.py b/source/lambda/job/dep/llm_bot_dep/sm_utils.py
index ca6c01ec..a00cf885 100644
--- a/source/lambda/job/dep/llm_bot_dep/sm_utils.py
+++ b/source/lambda/job/dep/llm_bot_dep/sm_utils.py
@@ -1,22 +1,25 @@
"""
Helper functions for using Samgemaker Endpoint via LangChain
"""
-import sys
-import time
+
import json
import logging
+import sys
+import time
import traceback
-from typing import List, Dict, Any, Optional
+from typing import Any, Dict, List, Optional
+
+from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.embeddings import SagemakerEndpointEmbeddings
from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
from langchain.llms.sagemaker_endpoint import SagemakerEndpoint
-from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.utils import enforce_stop_tokens
logger = logging.getLogger()
# logging.basicConfig(format='%(asctime)s,%(module)s,%(processName)s,%(levelname)s,%(message)s', level=logging.INFO, stream=sys.stderr)
logger.setLevel(logging.INFO)
+
# extend the SagemakerEndpointEmbeddings class from langchain to provide a custom embedding function, wrap the embedding & injection logic into a single class
class SagemakerEndpointEmbeddingsJumpStart(SagemakerEndpointEmbeddings):
def embed_documents(
@@ -37,13 +40,16 @@ def embed_documents(
_chunk_size = len(texts) if chunk_size > len(texts) else chunk_size
st = time.time()
for i in range(0, len(texts), _chunk_size):
- response = self._embedding_func(texts[i:i + _chunk_size])
+ response = self._embedding_func(texts[i : i + _chunk_size])
results.extend(response)
time_taken = time.time() - st
- logger.info(f"got results for {len(texts)} in {time_taken}s, length of embeddings list is {len(results)}")
+ logger.info(
+ f"got results for {len(texts)} in {time_taken}s, length of embeddings list is {len(results)}"
+ )
return results
+
class SagemakerEndpointEmbeddingsJumpStartDGR(SagemakerEndpointEmbeddings):
def embed_documents(
self, texts: List[str], chunk_size: int = 5
@@ -63,7 +69,9 @@ def embed_documents(
_chunk_size = len(texts) if chunk_size > len(texts) else chunk_size
st = time.time()
for i in range(0, len(texts), _chunk_size):
- embedding_texts = [text[:(512-56)] for text in texts[i:i + _chunk_size]]
+ embedding_texts = [
+ text[: (512 - 56)] for text in texts[i : i + _chunk_size]
+ ]
try:
response = self._embedding_func(embedding_texts)
except Exception as error:
@@ -71,7 +79,9 @@ def embed_documents(
print(f"embedding endpoint error: {texts}", error)
results.extend(response)
time_taken = time.time() - st
- logger.info(f"got results for {len(texts)} in {time_taken}s, length of embeddings list is {len(results)}")
+ logger.info(
+ f"got results for {len(texts)} in {time_taken}s, length of embeddings list is {len(results)}"
+ )
return results
@@ -82,7 +92,7 @@ class ContentHandler(EmbeddingsContentHandler):
def transform_input(self, prompt: str, model_kwargs={}) -> bytes:
input_str = json.dumps({"inputs": prompt, **model_kwargs})
- return input_str.encode('utf-8')
+ return input_str.encode("utf-8")
def transform_output(self, output: bytes) -> str:
response_json = json.loads(output.read().decode("utf-8"))
@@ -92,46 +102,99 @@ def transform_output(self, output: bytes) -> str:
return embeddings
-def create_embedding_with_multiple_model(embeddings_model_list: List[str], aws_region: str, file_type: str):
+def create_embeddings_with_single_model(
+ embeddings_model: str, aws_region: str, file_type: str
+):
+ embeddings_result = None
+ if file_type.lower() == "jsonl":
+ if "zh" in embeddings_model.lower():
+ content_handler = SimilarityZhContentHandler()
+ elif "en" in embeddings_model.lower():
+ content_handler = SimilarityEnContentHandler()
+ else:
+ if "zh" in embeddings_model.lower():
+ content_handler = RelevanceZhContentHandler()
+ elif "en" in embeddings_model.lower():
+ content_handler = RelevanceEnContentHandler()
+
+ embeddings_result = create_sagemaker_embeddings_from_js_model(
+ embeddings_model, aws_region, content_handler
+ )
+
+ return embeddings_result
+
+
+def create_embeddings_with_m3_model(
+ embeddings_model: str, aws_region: str, file_type: str
+):
+ embeddings_result = None
+ if file_type.lower() == "jsonl":
+ content_handler = SimilarityM3ContentHandler()
+ else:
+ content_handler = RelevanceM3ContentHandler()
+
+ embeddings_result = create_sagemaker_embeddings_from_js_model(
+ embeddings_model, aws_region, content_handler
+ )
+
+ return embeddings_result
+
+
+def create_embedding_with_multiple_model(
+ embeddings_model_list: List[str], aws_region: str, file_type: str
+):
embedding_dict = {}
if file_type.lower() == "jsonl":
for embedding_model in embeddings_model_list:
if "zh" in embedding_model.lower():
content_handler_zh = SimilarityZhContentHandler()
- embedding_zh = create_sagemaker_embeddings_from_js_model(embedding_model, aws_region, content_handler_zh)
+ embedding_zh = create_sagemaker_embeddings_from_js_model(
+ embedding_model, aws_region, content_handler_zh
+ )
embedding_dict["zh"] = embedding_zh
elif "en" in embedding_model.lower():
content_handler_en = SimilarityEnContentHandler()
- embedding_en = create_sagemaker_embeddings_from_js_model(embedding_model, aws_region, content_handler_en)
+ embedding_en = create_sagemaker_embeddings_from_js_model(
+ embedding_model, aws_region, content_handler_en
+ )
embedding_dict["en"] = embedding_en
else:
for embedding_model in embeddings_model_list:
if "zh" in embedding_model.lower():
content_handler_zh = RelevanceZhContentHandler()
- embedding_zh = create_sagemaker_embeddings_from_js_model(embedding_model, aws_region, content_handler_zh)
+ embedding_zh = create_sagemaker_embeddings_from_js_model(
+ embedding_model, aws_region, content_handler_zh
+ )
embedding_dict["zh"] = embedding_zh
elif "en" in embedding_model.lower():
content_handler_en = RelevanceEnContentHandler()
- embedding_en = create_sagemaker_embeddings_from_js_model(embedding_model, aws_region, content_handler_en)
+ embedding_en = create_sagemaker_embeddings_from_js_model(
+ embedding_model, aws_region, content_handler_en
+ )
embedding_dict["en"] = embedding_en
return embedding_dict
-def create_sagemaker_embeddings_from_js_model(embeddings_model_endpoint_name: str, aws_region: str, content_handler) -> SagemakerEndpointEmbeddingsJumpStart:
- # all set to create the objects for the ContentHandler and
+def create_sagemaker_embeddings_from_js_model(
+ embeddings_model_endpoint_name: str, aws_region: str, content_handler
+) -> SagemakerEndpointEmbeddingsJumpStart:
+ # all set to create the objects for the ContentHandler and
# SagemakerEndpointEmbeddingsJumpStart classes
- logger.info(f'content_handler: {content_handler}, embeddings_model_endpoint_name: {embeddings_model_endpoint_name}, aws_region: {aws_region}')
+ logger.info(
+ f"content_handler: {content_handler}, embeddings_model_endpoint_name: {embeddings_model_endpoint_name}, aws_region: {aws_region}"
+ )
# note the name of the LLM Sagemaker endpoint, this is the model that we would
# be using for generating the embeddings
embeddings = SagemakerEndpointEmbeddingsJumpStart(
- endpoint_name = embeddings_model_endpoint_name,
- region_name = aws_region,
- content_handler = content_handler
+ endpoint_name=embeddings_model_endpoint_name,
+ region_name=aws_region,
+ content_handler=content_handler,
)
return embeddings
+
# Migrate the class from sm_utils.py in executor to here, there are 3 models including vector, cross and answer wrapper into class SagemakerEndpointVectorOrCross. TODO, to merge the class along with the previous class SagemakerEndpointEmbeddingsJumpStart
class vectorContentHandler(EmbeddingsContentHandler):
content_type = "application/json"
@@ -145,152 +208,15 @@ def transform_output(self, output: bytes) -> List[List[float]]:
response_json = json.loads(output.read().decode("utf-8"))
return response_json["sentence_embeddings"]
-# class crossContentHandler(LLMContentHandler):
-# content_type = "application/json"
-# accepts = "application/json"
-
-# def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
-# input_str = json.dumps({"inputs": prompt, "docs":model_kwargs["context"]})
-# return input_str.encode('utf-8')
-
-# def transform_output(self, output: bytes) -> str:
-# response_json = json.loads(output.read().decode("utf-8"))
-# return response_json['scores'][0][1]
-
-# class answerContentHandler(LLMContentHandler):
-# content_type = "application/json"
-# accepts = "application/json"
-
-# def transform_input(self, question: str, model_kwargs: Dict) -> bytes:
-
-# template_1 = '以下context xml tag内的文本内容为背景知识:\n\n{context}\n\n请根据背景知识, 回答这个问题:{question}'
-# context = model_kwargs["context"]
-
-# if len(context) == 0:
-# prompt = question
-# else:
-# prompt = template_1.format(context = model_kwargs["context"], question = question)
-
-# input_str = json.dumps({"inputs": prompt,
-# "history": model_kwargs["history"],
-# "parameters": model_kwargs["parameters"]})
-# return input_str.encode('utf-8')
-# def transform_output(self, output: bytes) -> str:
-# response_json = json.loads(output.read().decode("utf-8"))
-# return response_json['outputs']
-
-class LineIterator:
- """
- A helper class for parsing the byte stream input.
-
- The output of the model will be in the following format:
- ```
- b'{"outputs": [" a"]}\n'
- b'{"outputs": [" challenging"]}\n'
- b'{"outputs": [" problem"]}\n'
- ...
- ```
-
- While usually each PayloadPart event from the event stream will contain a byte array
- with a full json, this is not guaranteed and some of the json objects may be split across
- PayloadPart events. For example:
- ```
- {'PayloadPart': {'Bytes': b'{"outputs": '}}
- {'PayloadPart': {'Bytes': b'[" problem"]}\n'}}
- ```
-
- This class accounts for this by concatenating bytes written via the 'write' function
- and then exposing a method which will return lines (ending with a '\n' character) within
- the buffer via the 'scan_lines' function. It maintains the position of the last read
- position to ensure that previous bytes are not exposed again.
- """
-
- def __init__(self, stream):
- self.byte_iterator = iter(stream)
- self.buffer = io.BytesIO()
- self.read_pos = 0
-
- def __iter__(self):
- return self
-
- def __next__(self):
- while True:
- self.buffer.seek(self.read_pos)
- line = self.buffer.readline()
- if line and line[-1] == ord('\n'):
- self.read_pos += len(line)
- return line[:-1]
- try:
- chunk = next(self.byte_iterator)
- except StopIteration:
- if self.read_pos < self.buffer.getbuffer().nbytes:
- continue
- raise
- if 'PayloadPart' not in chunk:
- print('Unknown event type:' + chunk)
- continue
- self.buffer.seek(0, io.SEEK_END)
- self.buffer.write(chunk['PayloadPart']['Bytes'])
-
-class SagemakerEndpointStreaming(SagemakerEndpoint):
- # override the _call function to support streaming function with invoke_endpoint_with_response_stream
- def _call(
- self,
- prompt: str,
- stop: Optional[List[str]] = None,
- run_manager: Optional[CallbackManagerForLLMRun] = None,
- **kwargs: Any,
- ) -> str:
- """Call out to Sagemaker inference endpoint.
-
- Args:
- prompt: The prompt to pass into the model.
- stop: Optional list of stop words to use when generating.
-
- Returns:
- The string generated by the model.
-
- Example:
- .. code-block:: python
-
- response = se("Tell me a joke.")
- """
- _model_kwargs = self.model_kwargs or {}
- _model_kwargs = {**_model_kwargs, **kwargs}
- _endpoint_kwargs = self.endpoint_kwargs or {}
-
- body = self.content_handler.transform_input(prompt, _model_kwargs)
- # the content type should be application/json if we are using LMI container
- content_type = self.content_handler.content_type
- accepts = self.content_handler.accepts
-
- # send request
- try:
- response = self.client.invoke_endpoint_with_response_stream(
- EndpointName=self.endpoint_name,
- Body=body,
- ContentType=content_type,
- Accept=accepts,
- **_endpoint_kwargs,
- )
- except Exception as e:
- raise ValueError(f"Error raised by inference endpoint: {e}")
-
- # transform_output is not used here because the response is a stream
- for line in LineIterator(response['Body']):
- resp = json.loads(line)
- logging.info(resp.get("outputs")[0], end='')
-
- # enforce stop tokens if they are provided
- if stop is not None:
- # This is a bit hacky, but I can't figure out a better way to enforce
- # stop tokens when making calls to the sagemaker endpoint.
- text = enforce_stop_tokens(text, stop)
-
- return resp.get("outputs")[0]
-
-def SagemakerEndpointVectorOrCross(prompt: str, endpoint_name: str, region_name: str, model_type: str, stop: List[str], **kwargs) -> SagemakerEndpoint:
+def SagemakerEndpointVectorOrCross(
+ prompt: str,
+ endpoint_name: str,
+ region_name: str,
+ model_type: str,
+ stop: List[str],
+ **kwargs,
+) -> SagemakerEndpoint:
"""
original class invocation:
response = self.client.invoke_endpoint(
@@ -310,15 +236,11 @@ def SagemakerEndpointVectorOrCross(prompt: str, endpoint_name: str, region_name:
)
query_result = embeddings.embed_query(prompt)
return query_result
- # elif model_type == "cross":
- # content_handler = crossContentHandler()
- # elif model_type == "answer":
- # content_handler = answerContentHandler()
- # TODO: replace with SagemakerEndpointStreaming
+
genericModel = SagemakerEndpoint(
- endpoint_name = endpoint_name,
- region_name = region_name,
- content_handler = content_handler
+ endpoint_name=endpoint_name,
+ region_name=region_name,
+ content_handler=content_handler,
)
return genericModel(prompt=prompt, stop=stop, **kwargs)
@@ -333,16 +255,17 @@ def transform_input(self, prompt, model_kwargs={}) -> bytes:
new_prompt = [p for p in prompt]
input_str = json.dumps({"inputs": new_prompt, **model_kwargs})
- return input_str.encode('utf-8')
+ return input_str.encode("utf-8")
def transform_output(self, output: bytes) -> str:
response_json = json.loads(output.read().decode("utf-8"))
embeddings = response_json["sentence_embeddings"]
if len(embeddings) == 1:
return [embeddings[0]]
-
+
return embeddings
+
class RelevanceZhContentHandler(EmbeddingsContentHandler):
content_type = "application/json"
accepts = "application/json"
@@ -352,55 +275,114 @@ def transform_input(self, prompt, model_kwargs={}) -> bytes:
new_prompt = ["为这个句子生成表示以用于检索相关文章:" + p for p in prompt]
input_str = json.dumps({"inputs": new_prompt, **model_kwargs})
- return input_str.encode('utf-8')
+ return input_str.encode("utf-8")
+
+ def transform_output(self, output: bytes) -> str:
+ response_json = json.loads(output.read().decode("utf-8"))
+ embeddings = response_json["sentence_embeddings"]
+ if len(embeddings) == 1:
+ return [embeddings[0]]
+
+ return embeddings
+
+
+class SimilarityZhContentHandler(EmbeddingsContentHandler):
+ content_type = "application/json"
+ accepts = "application/json"
+
+ def transform_input(self, prompt, model_kwargs={}) -> bytes:
+ # add bge_prompt to each element in prompt
+ new_prompt = [p for p in prompt]
+ input_str = json.dumps({"inputs": new_prompt, **model_kwargs})
+
+ return input_str.encode("utf-8")
+
+ def transform_output(self, output: bytes) -> str:
+ response_json = json.loads(output.read().decode("utf-8"))
+ embeddings = response_json["sentence_embeddings"]
+ if len(embeddings) == 1:
+ return [embeddings[0]]
+
+ return embeddings
+
+
+class RelevanceM3ContentHandler(EmbeddingsContentHandler):
+ content_type = "application/json"
+ accepts = "application/json"
+
+ def transform_input(self, prompt, model_kwargs={}) -> bytes:
+ # add bge_prompt to each element in prompt
+ model_kwargs = {}
+ model_kwargs["batch_size"] = 12
+ model_kwargs["max_length"] = 512
+ model_kwargs["return_type"] = "all"
+ input_str = json.dumps({"inputs": prompt, **model_kwargs})
+
+ return input_str.encode("utf-8")
def transform_output(self, output: bytes) -> str:
response_json = json.loads(output.read().decode("utf-8"))
embeddings = response_json["sentence_embeddings"]
+
if len(embeddings) == 1:
return [embeddings[0]]
-
+
return embeddings
-class SimilarityEnContentHandler(EmbeddingsContentHandler):
+
+class SimilarityM3ContentHandler(EmbeddingsContentHandler):
content_type = "application/json"
accepts = "application/json"
def transform_input(self, prompt, model_kwargs={}) -> bytes:
# add bge_prompt to each element in prompt
new_prompt = [p for p in prompt]
+ model_kwargs = {}
+ model_kwargs["batch_size"] = 12
+ model_kwargs["max_length"] = 512
+ model_kwargs["return_type"] = "all"
input_str = json.dumps({"inputs": new_prompt, **model_kwargs})
- return input_str.encode('utf-8')
+ return input_str.encode("utf-8")
def transform_output(self, output: bytes) -> str:
response_json = json.loads(output.read().decode("utf-8"))
embeddings = response_json["sentence_embeddings"]
if len(embeddings) == 1:
return [embeddings[0]]
-
+
return embeddings
+
class RelevanceEnContentHandler(EmbeddingsContentHandler):
content_type = "application/json"
accepts = "application/json"
def transform_input(self, prompt, model_kwargs={}) -> bytes:
# add bge_prompt to each element in prompt
- new_prompt = ["Represent this sentence for searching relevant passages:" + p for p in prompt]
+ new_prompt = [
+ "Represent this sentence for searching relevant passages:" + p
+ for p in prompt
+ ]
input_str = json.dumps({"inputs": new_prompt, **model_kwargs})
- return input_str.encode('utf-8')
+ return input_str.encode("utf-8")
def transform_output(self, output: bytes) -> str:
response_json = json.loads(output.read().decode("utf-8"))
embeddings = response_json["sentence_embeddings"]
if len(embeddings) == 1:
return [embeddings[0]]
-
+
return embeddings
-def create_sagemaker_embeddings_from_js_model_dgr(embeddings_model_endpoint_name: str, aws_region: str, lang: str = "zh", type: str = "similarity") -> SagemakerEndpointEmbeddingsJumpStartDGR:
- # all set to create the objects for the ContentHandler and
+
+def create_sagemaker_embeddings_from_js_model_dgr(
+ embeddings_model_endpoint_name: str,
+ aws_region: str,
+ lang: str = "zh",
+ type: str = "similarity",
+) -> SagemakerEndpointEmbeddingsJumpStartDGR:
+ # all set to create the objects for the ContentHandler and
# SagemakerEndpointEmbeddingsJumpStart classes
if lang == "zh":
if type == "similarity":
@@ -412,12 +394,14 @@ def create_sagemaker_embeddings_from_js_model_dgr(embeddings_model_endpoint_name
content_handler = SimilarityEnContentHandler()
elif type == "relevance":
content_handler = RelevanceEnContentHandler()
- logger.info(f'content_handler: {content_handler}, embeddings_model_endpoint_name: {embeddings_model_endpoint_name}, aws_region: {aws_region}')
+ logger.info(
+ f"content_handler: {content_handler}, embeddings_model_endpoint_name: {embeddings_model_endpoint_name}, aws_region: {aws_region}"
+ )
# note the name of the LLM Sagemaker endpoint, this is the model that we would
# be using for generating the embeddings
- embeddings = SagemakerEndpointEmbeddingsJumpStartDGR(
- endpoint_name = embeddings_model_endpoint_name,
- region_name = aws_region,
- content_handler = content_handler
+ embeddings = SagemakerEndpointEmbeddingsJumpStartDGR(
+ endpoint_name=embeddings_model_endpoint_name,
+ region_name=aws_region,
+ content_handler=content_handler,
)
- return embeddings
\ No newline at end of file
+ return embeddings
diff --git a/source/lambda/job/dep/llm_bot_dep/splitter_utils.py b/source/lambda/job/dep/llm_bot_dep/splitter_utils.py
index ea810b2b..5936a413 100644
--- a/source/lambda/job/dep/llm_bot_dep/splitter_utils.py
+++ b/source/lambda/job/dep/llm_bot_dep/splitter_utils.py
@@ -1,6 +1,7 @@
import re
import logging
import uuid
+import traceback
from typing import Any, Dict, Iterator, List, Optional, Union
import boto3
from langchain.docstore.document import Document
@@ -160,7 +161,7 @@ def extract_headings(md_content: str):
lines = md_content.split("\n")
id_index_dict = {}
for line in lines:
- match = re.match(r"(#+)(.*)", line)
+ match = re.match(r"\s*(#+)(.*)", line)
if match:
header_index += 1
print(match.group)
@@ -258,6 +259,7 @@ def _get_current_heading_list(self, current_heading, current_heading_level_map):
title_list.append(current_heading_level_map[title_level])
joint_title_list = ' '.join(title_list)
except Exception as e:
+ traceback.print_exc()
print(f"Error: {e}")
return ""
diff --git a/source/lambda/job/dep/llm_bot_dep/storage_utils.py b/source/lambda/job/dep/llm_bot_dep/storage_utils.py
index 7acaab8f..812eb9c9 100644
--- a/source/lambda/job/dep/llm_bot_dep/storage_utils.py
+++ b/source/lambda/job/dep/llm_bot_dep/storage_utils.py
@@ -70,6 +70,7 @@ def save_content_to_s3(s3, document: Document, res_bucket: str, splitting_type:
logger_file = convert_to_logger(document)
# Extract the filename from the file_path in the metadata
file_path = document.metadata.get('file_path', '')
- filename = file_path.split('/')[-1].split('.')[0]
+ # filename = file_path.split('/')[-1].split('.')[0]
+ filename = file_path.replace('s3://', '').replace('/', '-').replace('.', '-')
# RecursiveCharacterTextSplitter have been rewrite to split based on chunk size & overlap, use seperate folder to store the logger file
upload_chunk_to_s3(s3, logger_file, res_bucket, filename, splitting_type)
diff --git a/source/lambda/job/glue-job-script-cn.py b/source/lambda/job/glue-job-script-cn.py
index e208939e..6d90dacb 100644
--- a/source/lambda/job/glue-job-script-cn.py
+++ b/source/lambda/job/glue-job-script-cn.py
@@ -5,12 +5,11 @@
import os
import sys
import time
-from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple
import traceback
+from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple
import boto3
import chardet
-import nltk
from awsglue.utils import getResolvedOptions
from boto3.dynamodb.conditions import Attr, Key
from langchain.docstore.document import Document
@@ -18,6 +17,8 @@
from langchain.vectorstores import OpenSearchVectorSearch
from llm_bot_dep import sm_utils
from llm_bot_dep.constant import SplittingType
+from llm_bot_dep.ddb_utils import WorkspaceManager
+from llm_bot_dep.embeddings import get_embedding_info
from llm_bot_dep.enhance_utils import EnhanceWithBedrock
from llm_bot_dep.loaders.auto import cb_process_object
from llm_bot_dep.storage_utils import save_content_to_s3
@@ -64,10 +65,8 @@
s3_bucket = args["S3_BUCKET"]
s3_prefix = args["S3_PREFIX"]
aosEndpoint = args["AOS_ENDPOINT"]
-aos_index = args["DOC_INDEX_TABLE"]
-# This index is used for the AOS injection, to allow user customize the index, otherwise default value is "chatbot-index" or set in CloudFormation parameter
-aos_custom_index = args["AOS_INDEX"]
-embeddingModelEndpointList = args["EMBEDDING_MODEL_ENDPOINT"].split(",")
+
+embeddingModelEndpoint = args["EMBEDDING_MODEL_ENDPOINT"]
etlModelEndpoint = args["ETL_MODEL_ENDPOINT"]
region = args["REGION"]
res_bucket = args["RES_BUCKET"]
@@ -76,22 +75,16 @@
# TODO, pass the bucket and prefix need to handle in current job directly
batchIndice = args["BATCH_INDICE"]
processedObjectsTable = args["ProcessedObjectsTable"]
+workspace_name = args["WORKSPACE_NAME"]
+workspaces_table = args["WORKSPACES_TABLE"]
content_type = args["CONTENT_TYPE"]
-_embedding_endpoint_name_list = args["EMBEDDING_MODEL_ENDPOINT"].split(",")
-_embedding_lang_list = args["EMBEDDING_LANG"].split(",")
-_embedding_type_list = args["EMBEDDING_TYPE"].split(",")
-embeddings_model_info_list = []
-for endpoint_name, lang, endpoint_type in zip(
- _embedding_endpoint_name_list, _embedding_lang_list, _embedding_type_list
-):
- embeddings_model_info_list.append(
- {"endpoint_name": endpoint_name, "lang": lang, "type": endpoint_type}
- )
s3 = boto3.client("s3")
smr_client = boto3.client("sagemaker-runtime")
dynamodb = boto3.resource("dynamodb")
table = dynamodb.Table(processedObjectsTable)
+workspaces_table = dynamodb.Table(workspaces_table)
+workspaces_manager = WorkspaceManager(workspaces_table)
ENHANCE_CHUNK_SIZE = 25000
# Make it 3600s for debugging purpose
@@ -107,8 +100,8 @@
)
MAX_OS_DOCS_PER_PUT = 8
-# Set the NLTK data path to the /tmp directory for AWS Glue jobs
-nltk.data.path.append("/tmp/nltk_data")
+
+supported_file_types = ["pdf", "txt", "doc", "md", "html", "json", "jsonl", "csv"]
def decode_file_content(content: str, default_encoding: str = "utf-8"):
@@ -153,40 +146,6 @@ def iterate_s3_files(bucket: str, prefix: str) -> Generator:
)
currentIndice += 1
continue
- """
- WHY this code block is commented out? we used to record the processed object in DynamoDB in case of redundant operation for the same object
- """
- # # Truncate to seconds with round()
- # current_time = int(round(time.time()))
- # # Check for redundancy and expiry
- # response = table.query(
- # KeyConditionExpression=Key("ObjectKey").eq(key),
- # ScanIndexForward=False, # Sort by ProcessTimestamp in descending order
- # Limit=1, # We only need the latest record
- # )
-
- # # If the object is found and has not expired, skip processing
- # if (
- # response["Items"]
- # and response["Items"][0]["ExpiryTimestamp"] > current_time
- # ):
- # logger.info(f"Object {key} has not expired yet and will be skipped.")
- # continue
-
- # # Record the processing of the S3 object with an updated expiry timestamp, and each job only update single object in table. TODO, current assume the object will be handled successfully
- # expiry_timestamp = current_time + OBJECT_EXPIRY_TIME
- # try:
- # table.put_item(
- # Item={
- # "ObjectKey": key,
- # "ProcessTimestamp": current_time,
- # "Bucket": bucket,
- # "Prefix": "/".join(key.split("/")[:-1]),
- # "ExpiryTimestamp": expiry_timestamp,
- # }
- # )
- # except Exception as e:
- # logger.error(f"Error recording processed of S3 object {key}: {e}")
file_type = key.split(".")[-1].lower() # Extract file extension
response = s3.get_object(Bucket=bucket, Key=key)
@@ -244,7 +203,7 @@ def batch_generator(generator, batch_size: int):
def aos_injection(
content: List[Document],
- embeddingModelEndpointList: List[str],
+ embeddingModelEndpoint: str,
aosEndpoint: str,
index_name: str,
file_type: str,
@@ -270,8 +229,8 @@ def aos_injection(
Note:
"""
- embedding_list = sm_utils.create_embedding_with_multiple_model(
- embeddingModelEndpointList, region, file_type
+ embeddings = sm_utils.create_embeddings_with_single_model(
+ embeddingModelEndpoint, region, file_type
)
def chunk_generator(
@@ -330,43 +289,38 @@ def chunk_generator(
document.metadata["complete_heading"] + " " + document.page_content
)
else:
- document.page_content = (document.page_content)
+ document.page_content = document.page_content
+
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
)
def _aos_injection(document: Document) -> Document:
- # If user customize the index, use the customized index as high priority, NOTE the custom index will be created with default AOS mapping in LangChain, use API to create the index with customized mapping before running the job if you want to customize the mapping
- if aos_custom_index:
- index_name = aos_custom_index
- for embedding in embedding_list.values():
- document.metadata["embedding_endpoint_name"] = (
- embedding.endpoint_name
- )
- docsearch = OpenSearchVectorSearch(
- index_name=index_name,
- embedding_function=embedding,
- opensearch_url="https://{}".format(aosEndpoint),
- http_auth=awsauth,
- use_ssl=True,
- verify_certs=True,
- connection_class=RequestsHttpConnection,
- )
+ document.metadata["embedding_endpoint_name"] = embeddingModelEndpoint
+ docsearch = OpenSearchVectorSearch(
+ index_name=index_name,
+ embedding_function=embeddings,
+ opensearch_url="https://{}".format(aosEndpoint),
+ http_auth=awsauth,
+ use_ssl=True,
+ verify_certs=True,
+ connection_class=RequestsHttpConnection,
+ )
+ logger.info(
+ "Adding documents %s to OpenSearch with index %s",
+ document,
+ index_name,
+ )
+ # TODO: add endpoint name as a metadata of document
+ try:
+ # TODO, consider the max retry and initial backoff inside helper.bulk operation instead of using original LangChain
+ docsearch.add_documents(documents=[document])
+ except Exception as e:
logger.info(
- "Adding documents %s to OpenSearch with index %s",
- document,
- index_name,
+ f"Catch exception when adding document to OpenSearch: {e}"
)
- # TODO: add endpoint name as a metadata of document
- try:
- # TODO, consider the max retry and initial backoff inside helper.bulk operation instead of using original LangChain
- docsearch.add_documents(documents=[document])
- except Exception as e:
- logger.info(
- f"Catch exception when adding document to OpenSearch: {e}"
- )
- logger.info("Retry statistics: %s", _aos_injection.retry.statistics)
+ logger.info("Retry statistics: %s", _aos_injection.retry.statistics)
# logger.info("Adding documents %s to OpenSearch with index %s", document, index_name)
save_content_to_s3(s3, document, res_bucket, SplittingType.CHUNK.value)
@@ -376,96 +330,91 @@ def _aos_injection(document: Document) -> Document:
# Main function to be called by Glue job script
def main():
logger.info("Starting Glue job with passing arguments: %s", args)
- # Check if offline mode
- if offline == "true" or offline == "false":
- logger.info("Running in offline mode with consideration for large file size...")
- for file_type, file_content, kwargs in iterate_s3_files(s3_bucket, s3_prefix):
- try:
- if file_type == "json":
- kwargs["embeddings_model_info_list"] = embeddings_model_info_list
- kwargs["aos_index"] = aos_index
- kwargs["aosEndpoint"] = aosEndpoint
- kwargs["region"] = region
- kwargs["awsauth"] = awsauth
- kwargs["content_type"] = content_type
- kwargs["max_os_docs_per_put"] = MAX_OS_DOCS_PER_PUT
- res = cb_process_object(s3, file_type, file_content, **kwargs)
+ logger.info("Running in offline mode with consideration for large file size...")
+
+ embeddings_model_provider, embeddings_model_name, embeddings_model_dimensions = (
+ get_embedding_info(embeddingModelEndpoint)
+ )
+
+ for file_type, file_content, kwargs in iterate_s3_files(s3_bucket, s3_prefix):
+ try:
+ res = cb_process_object(s3, file_type, file_content, **kwargs)
+ for document in res:
+ save_content_to_s3(
+ s3, document, res_bucket, SplittingType.SEMANTIC.value
+ )
+
+ # the res is unified to list[Doucment] type, store the res to S3 for observation
+ # TODO, parse the metadata to embed with different index
+ if res:
+ logger.info("Result: %s", res)
+
+ workspace_id, aos_index = workspaces_manager.update_workspace_open_search(
+ workspace_name,
+ embeddings_model_provider,
+ embeddings_model_name,
+ embeddings_model_dimensions,
+ ["zh"],
+ [file_type],
+ )
+
+ gen_chunk_flag = False if file_type == "csv" else True
+ if file_type in supported_file_types:
+ aos_injection(
+ res,
+ embeddingModelEndpoint,
+ aosEndpoint,
+ aos_index,
+ file_type,
+ gen_chunk=gen_chunk_flag,
+ )
+
+ if qa_enhancement == "true":
+ enhanced_prompt_list = []
+ # iterate the document to get the QA pairs
for document in res:
- save_content_to_s3(
- s3, document, res_bucket, SplittingType.SEMANTIC.value
+ # Define your prompt or else it uses default prompt
+ prompt = ""
+ # Make sure the document is Document object
+ logger.info(
+ "Enhancing document type: {} and content: {}".format(
+ type(document), document
+ )
)
-
- # the res is unified to list[Doucment] type, store the res to S3 for observation
- # TODO, parse the metadata to embed with different index
- if res:
- logger.info("Result: %s", res)
- if file_type == "csv":
- # CSV page document has been splited into chunk, no more spliting is needed
- aos_injection(
- res,
- embeddingModelEndpointList,
- aosEndpoint,
- aos_index,
- file_type,
- gen_chunk=False,
+ ewb = EnhanceWithBedrock(prompt, document)
+ # This is should be optional for the user to choose the chunk size
+ document_list = ewb.SplitDocumentByTokenNum(
+ document, ENHANCE_CHUNK_SIZE
)
- elif file_type in ["pdf", "txt", "doc", "md", "html", "json", "jsonl"]:
+ for document in document_list:
+ enhanced_prompt_list = ewb.EnhanceWithClaude(
+ prompt, document, enhanced_prompt_list
+ )
+ logger.info(f"Enhanced prompt: {enhanced_prompt_list}")
+
+ if len(enhanced_prompt_list) > 0:
+ for document in enhanced_prompt_list:
+ save_content_to_s3(
+ s3,
+ document,
+ res_bucket,
+ SplittingType.QA_ENHANCEMENT.value,
+ )
aos_injection(
- res,
- embeddingModelEndpointList,
+ enhanced_prompt_list,
+ embeddingModelEndpoint,
aosEndpoint,
aos_index,
- file_type,
+ "qa",
)
- if qa_enhancement == "true":
- enhanced_prompt_list = []
- # iterate the document to get the QA pairs
- for document in res:
- # Define your prompt or else it uses default prompt
- prompt = ""
- # Make sure the document is Document object
- logger.info(
- "Enhancing document type: {} and content: {}".format(
- type(document), document
- )
- )
- ewb = EnhanceWithBedrock(prompt, document)
- # This is should be optional for the user to choose the chunk size
- document_list = ewb.SplitDocumentByTokenNum(
- document, ENHANCE_CHUNK_SIZE
- )
- for document in document_list:
- enhanced_prompt_list = ewb.EnhanceWithClaude(
- prompt, document, enhanced_prompt_list
- )
- logger.info(f"Enhanced prompt: {enhanced_prompt_list}")
-
- if len(enhanced_prompt_list) > 0:
- for document in enhanced_prompt_list:
- save_content_to_s3(
- s3,
- document,
- res_bucket,
- SplittingType.QA_ENHANCEMENT.value,
- )
- aos_injection(
- enhanced_prompt_list,
- embeddingModelEndpointList,
- aosEndpoint,
- aos_index,
- "qa",
- )
- except Exception as e:
- logger.error(
- "Error processing object %s: %s",
- kwargs["bucket"] + "/" + kwargs["key"],
- e,
- )
- traceback.print_exc()
-
- else:
- logger.info("Running in online mode, assume file number is small...")
+ except Exception as e:
+ logger.error(
+ "Error processing object %s: %s",
+ kwargs["bucket"] + "/" + kwargs["key"],
+ e,
+ )
+ traceback.print_exc()
if __name__ == "__main__":
diff --git a/source/lambda/job/glue-job-script.py b/source/lambda/job/glue-job-script.py
index 56889d3a..5ed3a5b2 100644
--- a/source/lambda/job/glue-job-script.py
+++ b/source/lambda/job/glue-job-script.py
@@ -6,18 +6,52 @@
import sys
import time
import traceback
+import functools
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple
import boto3
import chardet
import nltk
-from awsglue.utils import getResolvedOptions
+
+logger = logging.getLogger()
+logger.setLevel(logging.INFO)
+
+try:
+ from awsglue.utils import getResolvedOptions
+ args = getResolvedOptions(
+ sys.argv,
+ [
+ "JOB_NAME",
+ "S3_BUCKET",
+ "S3_PREFIX",
+ "AOS_ENDPOINT",
+ "EMBEDDING_MODEL_ENDPOINT",
+ "ETL_MODEL_ENDPOINT",
+ "REGION",
+ "RES_BUCKET",
+ "OFFLINE",
+ "QA_ENHANCEMENT",
+ "BATCH_INDICE",
+ "ProcessedObjectsTable",
+ "WORKSPACE_ID",
+ "WORKSPACE_TABLE",
+ ],
+ )
+except Exception as e:
+ logger.warning("Running locally")
+ sys.path.append("dep")
+ args = json.load(open(sys.argv[1]))
+ args["BATCH_INDICE"] = sys.argv[2]
+
from boto3.dynamodb.conditions import Attr, Key
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import OpenSearchVectorSearch
+
from llm_bot_dep import sm_utils
from llm_bot_dep.constant import SplittingType
+from llm_bot_dep.ddb_utils import WorkspaceManager
+from llm_bot_dep.embeddings import get_embedding_info
from llm_bot_dep.enhance_utils import EnhanceWithBedrock
from llm_bot_dep.loaders.auto import cb_process_object
from llm_bot_dep.storage_utils import save_content_to_s3
@@ -25,8 +59,6 @@
from requests_aws4auth import AWS4Auth
from tenacity import retry, stop_after_attempt, wait_exponential
-logger = logging.getLogger()
-logger.setLevel(logging.INFO)
# Adaption to allow nougat to run in AWS Glue with writable /tmp
os.environ["TRANSFORMERS_CACHE"] = "/tmp/transformers_cache"
@@ -34,28 +66,7 @@
os.environ["NLTK_DATA"] = "/tmp/nltk_data"
# Parse arguments
-args = getResolvedOptions(
- sys.argv,
- [
- "JOB_NAME",
- "S3_BUCKET",
- "S3_PREFIX",
- "AOS_ENDPOINT",
- "EMBEDDING_MODEL_ENDPOINT",
- "ETL_MODEL_ENDPOINT",
- "REGION",
- "RES_BUCKET",
- "OFFLINE",
- "QA_ENHANCEMENT",
- "BATCH_INDICE",
- "ProcessedObjectsTable",
- "DOC_INDEX_TABLE",
- "AOS_INDEX",
- "CONTENT_TYPE",
- "EMBEDDING_TYPE",
- "EMBEDDING_LANG",
- ],
-)
+
# Online process triggered by S3 Object create event does not have batch indice
# Set default value for BATCH_INDICE if it doesn't exist
@@ -64,10 +75,8 @@
s3_bucket = args["S3_BUCKET"]
s3_prefix = args["S3_PREFIX"]
aosEndpoint = args["AOS_ENDPOINT"]
-aos_index = args["DOC_INDEX_TABLE"]
-# This index is used for the AOS injection, to allow user customize the index, otherwise default value is "chatbot-index" or set in CloudFormation parameter
-aos_custom_index = args["AOS_INDEX"]
-embeddingModelEndpointList = args["EMBEDDING_MODEL_ENDPOINT"].split(",")
+
+embeddingModelEndpoint = args["EMBEDDING_MODEL_ENDPOINT"]
etlModelEndpoint = args["ETL_MODEL_ENDPOINT"]
region = args["REGION"]
res_bucket = args["RES_BUCKET"]
@@ -76,22 +85,15 @@
# TODO, pass the bucket and prefix need to handle in current job directly
batchIndice = args["BATCH_INDICE"]
processedObjectsTable = args["ProcessedObjectsTable"]
-content_type = args["CONTENT_TYPE"]
-_embedding_endpoint_name_list = args["EMBEDDING_MODEL_ENDPOINT"].split(",")
-_embedding_lang_list = args["EMBEDDING_LANG"].split(",")
-_embedding_type_list = args["EMBEDDING_TYPE"].split(",")
-embeddings_model_info_list = []
-for endpoint_name, lang, endpoint_type in zip(
- _embedding_endpoint_name_list, _embedding_lang_list, _embedding_type_list
-):
- embeddings_model_info_list.append(
- {"endpoint_name": endpoint_name, "lang": lang, "type": endpoint_type}
- )
+workspace_id = args["WORKSPACE_ID"]
+workspace_table = args["WORKSPACE_TABLE"]
s3 = boto3.client("s3")
smr_client = boto3.client("sagemaker-runtime")
dynamodb = boto3.resource("dynamodb")
table = dynamodb.Table(processedObjectsTable)
+workspace_table = dynamodb.Table(workspace_table)
+workspace_manager = WorkspaceManager(workspace_table)
ENHANCE_CHUNK_SIZE = 25000
# Make it 3600s for debugging purpose
@@ -110,6 +112,8 @@
# Set the NLTK data path to the /tmp directory for AWS Glue jobs
nltk.data.path.append("/tmp/nltk_data")
+supported_file_types = ["pdf", "txt", "doc", "md", "html", "json", "jsonl", "csv"]
+
def decode_file_content(content: str, default_encoding: str = "utf-8"):
"""Decode the file content and auto detect the content encoding.
@@ -136,6 +140,7 @@ def iterate_s3_files(bucket: str, prefix: str) -> Generator:
currentIndice = 0
for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
for obj in page.get("Contents", []):
+ currentIndice += 1
key = obj["Key"]
# skip the prefix with slash, which is the folder name
if key.endswith("/"):
@@ -145,49 +150,18 @@ def iterate_s3_files(bucket: str, prefix: str) -> Generator:
currentIndice, bucket, key
)
)
- if currentIndice != int(batchIndice):
+ if (currentIndice-1) // 100 != int(batchIndice):
logger.info(
"currentIndice: {}, batchIndice: {}, skip file: {}".format(
currentIndice, batchIndice, key
)
)
- currentIndice += 1
continue
- """
- WHY this code block is commented out? we used to record the processed object in DynamoDB in case of redundant operation for the same object
- """
- # # Truncate to seconds with round()
- # current_time = int(round(time.time()))
- # # Check for redundancy and expiry
- # response = table.query(
- # KeyConditionExpression=Key("ObjectKey").eq(key),
- # ScanIndexForward=False, # Sort by ProcessTimestamp in descending order
- # Limit=1, # We only need the latest record
- # )
-
- # # If the object is found and has not expired, skip processing
- # if (
- # response["Items"]
- # and response["Items"][0]["ExpiryTimestamp"] > current_time
- # ):
- # logger.info(f"Object {key} has not expired yet and will be skipped.")
- # continue
-
- # # Record the processing of the S3 object with an updated expiry timestamp, and each job only update single object in table. TODO, current assume the object will be handled successfully
- # expiry_timestamp = current_time + OBJECT_EXPIRY_TIME
- # try:
- # table.put_item(
- # Item={
- # "ObjectKey": key,
- # "ProcessTimestamp": current_time,
- # "Bucket": bucket,
- # "Prefix": "/".join(key.split("/")[:-1]),
- # "ExpiryTimestamp": expiry_timestamp,
- # }
- # )
- # except Exception as e:
- # logger.error(f"Error recording processed of S3 object {key}: {e}")
-
+ logger.info(
+ "Processing {} doc in {} batch, key: {}".format(
+ currentIndice, batchIndice, key
+ )
+ )
file_type = key.split(".")[-1].lower() # Extract file extension
response = s3.get_object(Bucket=bucket, Key=key)
file_content = response["Body"].read()
@@ -202,37 +176,27 @@ def iterate_s3_files(bucket: str, prefix: str) -> Generator:
if file_type == "txt":
yield "txt", decode_file_content(file_content), kwargs
- break
elif file_type == "csv":
# Update row count here, the default row count is 1
kwargs["csv_row_count"] = 1
yield "csv", decode_file_content(file_content), kwargs
- break
elif file_type == "html":
yield "html", decode_file_content(file_content), kwargs
- break
elif file_type in ["pdf"]:
yield "pdf", file_content, kwargs
- break
elif file_type in ["jpg", "png"]:
yield "image", file_content, kwargs
- break
elif file_type in ["docx", "doc"]:
yield "doc", file_content, kwargs
- break
elif file_type == "md":
yield "md", decode_file_content(file_content), kwargs
- break
elif file_type == "json":
yield "json", decode_file_content(file_content), kwargs
- break
elif file_type == "jsonl":
yield "jsonl", file_content, kwargs
- break
else:
logger.info(f"Unknown file type: {file_type}")
-
def batch_generator(generator, batch_size: int):
iterator = iter(generator)
while True:
@@ -244,7 +208,7 @@ def batch_generator(generator, batch_size: int):
def aos_injection(
content: List[Document],
- embeddingModelEndpointList: List[str],
+ embeddingModelEndpoint: str,
aosEndpoint: str,
index_name: str,
file_type: str,
@@ -270,8 +234,9 @@ def aos_injection(
Note:
"""
- embedding_list = sm_utils.create_embedding_with_multiple_model(
- embeddingModelEndpointList, region, file_type
+ print(f"embeddingModelEndpoint: {embeddingModelEndpoint}")
+ embeddings = sm_utils.create_embeddings_with_m3_model(
+ embeddingModelEndpoint, region, file_type
)
def chunk_generator(
@@ -337,37 +302,53 @@ def chunk_generator(
wait=wait_exponential(multiplier=1, min=4, max=10),
)
def _aos_injection(document: Document) -> Document:
- # If user customize the index, use the customized index as high priority, NOTE the custom index will be created with default AOS mapping in LangChain, use API to create the index with customized mapping before running the job if you want to customize the mapping
- if aos_custom_index:
- index_name = aos_custom_index
- for embedding in embedding_list.values():
- document.metadata["embedding_endpoint_name"] = (
- embedding.endpoint_name
- )
- docsearch = OpenSearchVectorSearch(
- index_name=index_name,
- embedding_function=embedding,
- opensearch_url="https://{}".format(aosEndpoint),
- http_auth=awsauth,
- use_ssl=True,
- verify_certs=True,
- connection_class=RequestsHttpConnection,
- )
- logger.info(
- "Adding documents %s to OpenSearch with index %s",
- document,
- index_name,
- )
- # TODO: add endpoint name as a metadata of document
- try:
- # TODO, consider the max retry and initial backoff inside helper.bulk operation instead of using original LangChain
- docsearch.add_documents(documents=[document])
- except Exception as e:
- logger.info(
- f"Catch exception when adding document to OpenSearch: {e}"
- )
- logger.info("Retry statistics: %s", _aos_injection.retry.statistics)
+ document.metadata["embedding_endpoint_name"] = embeddingModelEndpoint
+ docsearch = OpenSearchVectorSearch(
+ index_name=index_name,
+ embedding_function=embeddings,
+ opensearch_url="https://{}".format(aosEndpoint),
+ http_auth=awsauth,
+ use_ssl=True,
+ verify_certs=True,
+ connection_class=RequestsHttpConnection,
+ )
+ # TODO: validate and update for m3 ep
+ # def add_texts_update_metadata(self, texts, metadatas=None, ids=None, bulk_size=500):
+ # embeddings = self.embedding_function.embed_documents(list(texts))
+ # if isinstance(embeddings[0], dict):
+ # embeddings_list = []
+ # metadata_list = []
+ # for m3_dict,metadata in zip(embeddings, metadatas):
+ # lexical_weights = m3_dict["lexical_weights"]
+ # colbert_vecs = m3_dict["colbert_vecs"]
+ # embeddings_list.append(m3_dict["dense_vecs"])
+ # metadata_list.append(metadata.append({'additional_vecs':{'lexical_weights':lexical_weights, 'colbert_vecs':colbert_vecs}}))
+ # embeddings = embeddings_list
+ # metadatas = metadata_list
+ # return self.__add(
+ # texts,
+ # embeddings,
+ # metadatas=metadatas,
+ # ids=ids,
+ # bulk_size=bulk_size,
+ # **kwargs,
+ # )
+ # docsearch.add_texts = functools.partial(add_texts_update_metadata, docsearch)
+ logger.info(
+ "Adding documents %s to OpenSearch with index %s",
+ document,
+ index_name,
+ )
+ # TODO: add endpoint name as a metadata of document
+ # try:
+ # TODO, consider the max retry and initial backoff inside helper.bulk operation instead of using original LangChain
+ docsearch.add_documents(documents=[document])
+ # except Exception as e:
+ # logger.info(
+ # f"Catch exception when adding document to OpenSearch: {e}"
+ # )
+ logger.info("Retry statistics: %s", _aos_injection.retry.statistics)
# logger.info("Adding documents %s to OpenSearch with index %s", document, index_name)
save_content_to_s3(s3, document, res_bucket, SplittingType.CHUNK.value)
@@ -377,96 +358,92 @@ def _aos_injection(document: Document) -> Document:
# Main function to be called by Glue job script
def main():
logger.info("Starting Glue job with passing arguments: %s", args)
- # Check if offline mode
- if offline == "true" or offline == "false":
- logger.info("Running in offline mode with consideration for large file size...")
- for file_type, file_content, kwargs in iterate_s3_files(s3_bucket, s3_prefix):
- try:
- if file_type == "json":
- kwargs["embeddings_model_info_list"] = embeddings_model_info_list
- kwargs["aos_index"] = aos_index
- kwargs["aosEndpoint"] = aosEndpoint
- kwargs["region"] = region
- kwargs["awsauth"] = awsauth
- kwargs["content_type"] = content_type
- kwargs["max_os_docs_per_put"] = MAX_OS_DOCS_PER_PUT
- res = cb_process_object(s3, file_type, file_content, **kwargs)
+ logger.info("Running in offline mode with consideration for large file size...")
+
+ embeddings_model_provider, embeddings_model_name, embeddings_model_dimensions = (
+ get_embedding_info(embeddingModelEndpoint)
+ )
+
+ for file_type, file_content, kwargs in iterate_s3_files(s3_bucket, s3_prefix):
+ try:
+ res = cb_process_object(s3, file_type, file_content, **kwargs)
+ for document in res:
+ save_content_to_s3(
+ s3, document, res_bucket, SplittingType.SEMANTIC.value
+ )
+
+ # the res is unified to list[Doucment] type, store the res to S3 for observation
+ # TODO, parse the metadata to embed with different index
+ if res:
+ logger.info("Result: %s", res)
+
+ aos_index = workspace_manager.update_workspace_open_search(
+ workspace_id,
+ embeddingModelEndpoint,
+ embeddings_model_provider,
+ embeddings_model_name,
+ embeddings_model_dimensions,
+ ["zh"],
+ [file_type],
+ )
+
+ gen_chunk_flag = False if file_type == "csv" else True
+ if file_type in supported_file_types:
+ aos_injection(
+ res,
+ embeddingModelEndpoint,
+ aosEndpoint,
+ aos_index,
+ file_type,
+ gen_chunk=gen_chunk_flag,
+ )
+
+ if qa_enhancement == "true":
+ enhanced_prompt_list = []
+ # iterate the document to get the QA pairs
for document in res:
- save_content_to_s3(
- s3, document, res_bucket, SplittingType.SEMANTIC.value
+ # Define your prompt or else it uses default prompt
+ prompt = ""
+ # Make sure the document is Document object
+ logger.info(
+ "Enhancing document type: {} and content: {}".format(
+ type(document), document
+ )
)
-
- # the res is unified to list[Doucment] type, store the res to S3 for observation
- # TODO, parse the metadata to embed with different index
- if res:
- logger.info("Result: %s", res)
- if file_type == "csv":
- # CSV page document has been splited into chunk, no more spliting is needed
- aos_injection(
- res,
- embeddingModelEndpointList,
- aosEndpoint,
- aos_index,
- file_type,
- gen_chunk=False,
+ ewb = EnhanceWithBedrock(prompt, document)
+ # This is should be optional for the user to choose the chunk size
+ document_list = ewb.SplitDocumentByTokenNum(
+ document, ENHANCE_CHUNK_SIZE
)
- elif file_type in ["pdf", "txt", "doc", "md", "html", "json", "jsonl"]:
+ for document in document_list:
+ enhanced_prompt_list = ewb.EnhanceWithClaude(
+ prompt, document, enhanced_prompt_list
+ )
+ logger.info(f"Enhanced prompt: {enhanced_prompt_list}")
+
+ if len(enhanced_prompt_list) > 0:
+ for document in enhanced_prompt_list:
+ save_content_to_s3(
+ s3,
+ document,
+ res_bucket,
+ SplittingType.QA_ENHANCEMENT.value,
+ )
aos_injection(
- res,
- embeddingModelEndpointList,
+ enhanced_prompt_list,
+ embeddingModelEndpoint,
aosEndpoint,
aos_index,
- file_type,
+ "qa",
)
- if qa_enhancement == "true":
- enhanced_prompt_list = []
- # iterate the document to get the QA pairs
- for document in res:
- # Define your prompt or else it uses default prompt
- prompt = ""
- # Make sure the document is Document object
- logger.info(
- "Enhancing document type: {} and content: {}".format(
- type(document), document
- )
- )
- ewb = EnhanceWithBedrock(prompt, document)
- # This is should be optional for the user to choose the chunk size
- document_list = ewb.SplitDocumentByTokenNum(
- document, ENHANCE_CHUNK_SIZE
- )
- for document in document_list:
- enhanced_prompt_list = ewb.EnhanceWithClaude(
- prompt, document, enhanced_prompt_list
- )
- logger.info(f"Enhanced prompt: {enhanced_prompt_list}")
-
- if len(enhanced_prompt_list) > 0:
- for document in enhanced_prompt_list:
- save_content_to_s3(
- s3,
- document,
- res_bucket,
- SplittingType.QA_ENHANCEMENT.value,
- )
- aos_injection(
- enhanced_prompt_list,
- embeddingModelEndpointList,
- aosEndpoint,
- aos_index,
- "qa",
- )
-
- except Exception as e:
- logger.error(
- "Error processing object %s: %s",
- kwargs["bucket"] + "/" + kwargs["key"],
- e,
- )
- traceback.print_exc()
- else:
- logger.info("Running in online mode, assume file number is small...")
+ except Exception as e:
+ logger.error(
+ "Error processing object %s: %s",
+ kwargs["bucket"] + "/" + kwargs["key"],
+ e,
+ )
+ traceback.print_exc()
if __name__ == "__main__":
diff --git a/source/model/instruct/internlm2-chat-20b-global-lmdeploy/code/model.py b/source/model/instruct/internlm2-chat-20b-global-lmdeploy/code/model.py
new file mode 100644
index 00000000..e5c656cf
--- /dev/null
+++ b/source/model/instruct/internlm2-chat-20b-global-lmdeploy/code/model.py
@@ -0,0 +1,116 @@
+import os
+os.environ['PYTHONUNBUFFERED'] = "1"
+import traceback
+import sys
+import torch
+import gc
+from typing import List,Tuple
+try:
+ from transformers.generation.streamers import BaseStreamer
+except: # noqa # pylint: disable=bare-except
+ BaseStreamer = None
+import queue
+import threading
+import time
+from queue import Empty
+from djl_python import Input, Output
+import torch
+import json
+import logging
+from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
+# from transformers.generation.utils import GenerationConfig
+import traceback
+from lmdeploy import pipeline, TurbomindEngineConfig,GenerationConfig
+from lmdeploy.model import ChatTemplateConfig
+import lmdeploy
+logger = logging.getLogger("sagemaker-inference")
+request_lock = threading.Lock()
+
+
+pipe = None
+
+def get_model(properties):
+ model_dir = properties['model_dir']
+ model_path = os.path.join(model_dir, 'hf_model/')
+ if "model_id" in properties:
+ model_path = properties['model_id']
+ logger.info(f'properties: {properties}')
+ logger.info(f'model_path: {model_path}')
+ # local_rank = int(os.getenv('LOCAL_RANK', '0'))
+ engine_config = TurbomindEngineConfig(
+ model_format='awq',
+ rope_scaling_factor=2.0,
+ session_len=160000,
+ cache_max_entry_count=0.2
+ )
+ pipe = pipeline(
+ model_path,
+ model_name="internlm2-chat-20b",
+ backend_config=engine_config
+ )
+ return pipe
+
+def _default_stream_output_formatter(token_texts):
+ if isinstance(token_texts,Exception):
+ token_texts = {'error_msg':str(token_texts)}
+ else:
+ token_texts = {"outputs": token_texts}
+ json_encoded_str = json.dumps(token_texts) + "\n"
+ return bytearray(json_encoded_str.encode("utf-8"))
+
+def generate(pipe,**body):
+ query = body.pop('query')
+ stream = body.pop('stream',False)
+ stop_words = body.pop('stop_tokens',None)
+ if stop_words:
+ assert isinstance(stop_words,list), stop_words
+ body['stop_words'] = stop_words + ['<|im_end|>', '<|action_end|>']
+ # body.update({"do_preprocess": False})
+ timeout = body.pop('timeout',60)
+ gen_config = GenerationConfig(**body)
+
+ stream_generator = pipe.stream_infer([query],gen_config=gen_config,do_preprocess=False)
+
+ def _generator_helper(gen):
+ try:
+ for i in gen:
+ yield i.text
+ finally:
+ traceback.clear_frames(sys.exc_info()[2])
+ gc.collect()
+ torch.cuda.empty_cache()
+ stream_generator = _generator_helper(stream_generator)
+ if stream:
+ return stream_generator
+ r = ""
+ for i in stream_generator:
+ r += i
+ return r
+
+
+def _handle(inputs: Input) -> None:
+ torch.cuda.empty_cache()
+ global pipe
+ if pipe is None:
+ pipe = get_model(inputs.get_properties())
+
+ if inputs.is_empty():
+ # Model server makes an empty call to warmup the model on startup
+ return None
+ body = inputs.get_as_json()
+
+ logger.info(f'body: {body}')
+ stream = body.get('stream',False)
+ response = generate(pipe,**body)
+ if stream:
+ return Output().add_stream_content(response,output_formatter=_default_stream_output_formatter)
+ else:
+ return Output().add_as_json(response)
+
+
+def handle(inputs: Input) -> None:
+ task_request_time = time.time()
+ logger.info(f'recieve request task: {task_request_time},{inputs}')
+ with request_lock:
+ logger.info(f'executing request task, wait time: {time.time()-task_request_time}s')
+ return _handle(inputs)
diff --git a/source/model/instruct/internlm2-chat-20b-global-lmdeploy/code/requirements.txt b/source/model/instruct/internlm2-chat-20b-global-lmdeploy/code/requirements.txt
new file mode 100644
index 00000000..5a1f8139
--- /dev/null
+++ b/source/model/instruct/internlm2-chat-20b-global-lmdeploy/code/requirements.txt
@@ -0,0 +1,7 @@
+lmdeploy[all]==0.2.4
+torch==2.1.2
+sentencepiece==0.1.99
+accelerate==0.25.0
+bitsandbytes==0.41.1
+transformers==4.37.1
+einops==0.7.0
diff --git a/source/model/instruct/internlm2-chat-20b-global-lmdeploy/code/serving.properties b/source/model/instruct/internlm2-chat-20b-global-lmdeploy/code/serving.properties
new file mode 100644
index 00000000..4180d006
--- /dev/null
+++ b/source/model/instruct/internlm2-chat-20b-global-lmdeploy/code/serving.properties
@@ -0,0 +1,3 @@
+engine=Python
+option.enable_streaming=true
+option.s3url=s3://aws-gcr-csdc-atl-exp-us-west-2/aigc-llm-models/internlm2-chat-20b-4bits/model
diff --git a/source/model/instruct/internlm2-chat-20b-global-lmdeploy/prepare_model.sh b/source/model/instruct/internlm2-chat-20b-global-lmdeploy/prepare_model.sh
new file mode 100644
index 00000000..a505a126
--- /dev/null
+++ b/source/model/instruct/internlm2-chat-20b-global-lmdeploy/prepare_model.sh
@@ -0,0 +1,9 @@
+python ../prepare_model.py \
+ --hf_model_id 'internlm/internlm2-chat-20b-4bits' \
+ --hf_model_local_dir internlm2-chat-20b-4bits-lmdeploy \
+ --hf_model_revision "main" \
+ --model_artifact_dir code \
+ --model_artifact_tar_name llm_model.tar.gz \
+ --s3_bucket "aws-gcr-csdc-atl-exp-us-west-2" \
+ --hf_model_s3_prefix aigc-llm-models/internlm2-chat-20b-4bits-lmdeploy \
+ --model_artifact_s3_prefix aigc-llm-models/internlm2-chat-20b-4bits-lmdeploy_deploy_code \
\ No newline at end of file
diff --git a/source/model/instruct/internlm2-chat-7b-cn-exllamav2/code/model.py b/source/model/instruct/internlm2-chat-7b-cn-exllamav2/code/model.py
new file mode 100644
index 00000000..158e3fa4
--- /dev/null
+++ b/source/model/instruct/internlm2-chat-7b-cn-exllamav2/code/model.py
@@ -0,0 +1,168 @@
+import time
+import sys, os
+os.environ['PYTHONUNBUFFERED'] = "1"
+import traceback
+import sys
+import torch
+import gc
+from typing import List,Tuple
+import logging
+try:
+ from transformers.generation.streamers import BaseStreamer
+except: # noqa # pylint: disable=bare-except
+ BaseStreamer = None
+import queue
+import threading
+import time
+from queue import Empty
+from djl_python import Input, Output
+import torch
+import json
+import types
+import threading
+from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
+# from transformers.generation.utils import GenerationConfig
+import traceback
+from transformers import AutoTokenizer,GPTQConfig,AutoModelForCausalLM
+
+from exllamav2 import (
+ ExLlamaV2,
+ ExLlamaV2Config,
+ ExLlamaV2Cache,
+ ExLlamaV2Tokenizer,
+)
+
+from exllamav2.generator import (
+ ExLlamaV2StreamingGenerator,
+ ExLlamaV2Sampler
+)
+handle_lock = threading.Lock()
+logger = logging.getLogger("sagemaker-inference")
+logger.info(f'logger handlers: {logger.handlers}')
+
+generator = None
+tokenizer = None
+
+
+def new_decode(self, ids, decode_special_tokens = False):
+ ori_decode = tokenizer.decode
+ return ori_decode(ids, decode_special_tokens = True)
+
+def get_model(properties):
+ model_dir = properties['model_dir']
+ model_path = os.path.join(model_dir, 'hf_model/')
+ if "model_id" in properties:
+ model_path = properties['model_id']
+ logger.info(f'properties: {properties}')
+ logger.info(f'model_path: {model_path}')
+ # local_rank = int(os.getenv('LOCAL_RANK', '0'))
+ model_directory = model_path
+
+ config = ExLlamaV2Config()
+ config.model_dir = model_directory
+ config.prepare()
+
+ model = ExLlamaV2(config)
+ logger.info("Loading model: " + model_directory)
+
+ cache = ExLlamaV2Cache(model, lazy = True)
+ model.load_autosplit(cache)
+
+ tokenizer = ExLlamaV2Tokenizer(config)
+
+ generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)
+
+ return tokenizer,generator
+
+def _default_stream_output_formatter(token_texts):
+ if isinstance(token_texts,Exception):
+ token_texts = {'error_msg':str(token_texts)}
+ else:
+ token_texts = {"outputs": token_texts}
+ json_encoded_str = json.dumps(token_texts) + "\n"
+ return bytearray(json_encoded_str.encode("utf-8"))
+
+def generate(**body):
+ query = body.pop('query')
+ stream = body.pop('stream',False)
+ stop_words = body.pop('stop_tokens',None)
+
+ stop_token_ids = [
+ tokenizer.eos_token_id,
+ tokenizer.encode('<|im_end|>',encode_special_tokens=True).tolist()[0][0]
+ ]
+
+ if stop_words:
+ assert isinstance(stop_words,list), stop_words
+ for stop_word in stop_words:
+ stop_token_ids.append(tokenizer.encode(stop_word,encode_special_tokens=True).tolist()[0][0])
+
+ # body.update({"do_preprocess": False})
+ timeout = body.pop('timeout',60)
+ settings = ExLlamaV2Sampler.Settings()
+ settings.temperature = body.get('temperature',0.1)
+ settings.top_k = body.get('top_k',50)
+ settings.top_p = body.get('top_p',0.8)
+ settings.top_a = body.get('top_a',0.0)
+ settings.token_repetition_penalty = 1.0
+ # tokenizer.convert_tokens_to_ids(["<|im_end|>"])[0]
+ # settings.disallow_tokens(tokenizer, [tokenizer.eos_token_id])
+ max_new_tokens = body.get('max_new_tokens',500)
+ input_ids = tokenizer.encode(query,encode_special_tokens=True,add_bos=True)
+ prompt_tokens = input_ids.shape[-1]
+ generator.warmup()
+ generator.set_stop_conditions(stop_token_ids)
+
+ generator.begin_stream(input_ids, settings)
+
+ def _generator_helper():
+ try:
+ generated_tokens = 0
+ while True:
+ chunk, eos, _ = generator.stream()
+ generated_tokens += 1
+ yield chunk
+ if eos or generated_tokens == max_new_tokens: break
+
+ finally:
+ traceback.clear_frames(sys.exc_info()[2])
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ stream_generator = _generator_helper()
+ if stream:
+ return stream_generator
+ r = ""
+ for i in stream_generator:
+ r += i
+ return r
+
+
+def _handle(inputs) -> None:
+ if inputs.is_empty():
+ logger.info('inputs is empty')
+ # Model server makes an empty call to warmup the model on startup
+ return None
+ torch.cuda.empty_cache()
+ body = inputs.get_as_json()
+ stream = body.get('stream',False)
+ logger.info(f'body: {body}')
+ response = generate(**body)
+ if stream:
+ return Output().add_stream_content(response,output_formatter=_default_stream_output_formatter)
+ else:
+ return Output().add_as_json(response)
+
+
+def handle(inputs: Input) -> None:
+ task_request_time = time.time()
+ logger.info(f'recieve request task: {task_request_time},{inputs}')
+ with handle_lock:
+ global generator,tokenizer
+ if generator is None:
+ tokenizer, generator = get_model(inputs.get_properties())
+ tokenizer.decode = types.MethodType(new_decode, tokenizer)
+ logger.info(f'executing request task, wait time: {time.time()-task_request_time}s')
+ return _handle(inputs)
+
+
diff --git a/source/model/instruct/internlm2-chat-7b-cn-exllamav2/code/requirements.txt b/source/model/instruct/internlm2-chat-7b-cn-exllamav2/code/requirements.txt
new file mode 100644
index 00000000..d3604067
--- /dev/null
+++ b/source/model/instruct/internlm2-chat-7b-cn-exllamav2/code/requirements.txt
@@ -0,0 +1,7 @@
+exllamav2==0.0.14
+torch==2.1.2
+sentencepiece==0.1.99
+accelerate==0.25.0
+bitsandbytes==0.41.1
+transformers==4.37.1
+einops==0.7.0
diff --git a/source/model/instruct/internlm2-chat-7b-cn-exllamav2/code/serving.properties b/source/model/instruct/internlm2-chat-7b-cn-exllamav2/code/serving.properties
new file mode 100644
index 00000000..b529a0b6
--- /dev/null
+++ b/source/model/instruct/internlm2-chat-7b-cn-exllamav2/code/serving.properties
@@ -0,0 +1,3 @@
+engine=Python
+option.enable_streaming=true
+option.s3url=s3://sagemaker-cn-north-1-256374081253/aigc-llm-models/internlm2-chat-7b-llama-exl2/model
diff --git a/source/model/instruct/internlm2-chat-7b-cn-exllamav2/prepare_model.sh b/source/model/instruct/internlm2-chat-7b-cn-exllamav2/prepare_model.sh
new file mode 100644
index 00000000..cebfb669
--- /dev/null
+++ b/source/model/instruct/internlm2-chat-7b-cn-exllamav2/prepare_model.sh
@@ -0,0 +1,9 @@
+python ../prepare_model.py \
+ --hf_model_id bartowski/internlm2-chat-7b-llama-exl2 \
+ --hf_model_local_dir internlm2-chat-7b-llama-exl2 \
+ --hf_model_revision "4_25" \
+ --model_artifact_dir code \
+ --model_artifact_tar_name llm_model.tar.gz \
+ --s3_bucket "sagemaker-cn-north-1-256374081253" \
+ --hf_model_s3_prefix aigc-llm-models/internlm2-chat-7b-llama-exl2 \
+ --model_artifact_s3_prefix aigc-llm-models/internlm2-chat-7b-llama-exl2_deploy_code \
\ No newline at end of file
diff --git a/source/model/instruct/prepare_model.py b/source/model/instruct/prepare_model.py
new file mode 100644
index 00000000..4dad7979
--- /dev/null
+++ b/source/model/instruct/prepare_model.py
@@ -0,0 +1,55 @@
+import argparse
+import os
+from huggingface_hub import snapshot_download
+
+def parse_args():
+ parser = argparse.ArgumentParser(description='prepare model for djl serving')
+ parser.add_argument('--hf_model_id')
+ parser.add_argument("--hf_model_local_dir",help='to save hf model')
+ parser.add_argument("--hf_model_revision",default="main",help='to save hf model')
+ parser.add_argument("--model_artifact_dir",help='path to model artifacts')
+ parser.add_argument("--model_artifact_tar_name",default="model.tar.gz")
+ parser.add_argument("--s3_bucket")
+ parser.add_argument("--hf_model_s3_prefix")
+ parser.add_argument("--model_artifact_s3_prefix")
+ return parser.parse_args()
+
+def download_hf_model_to_local(args):
+ hf_model_id = args.hf_model_id
+ hf_model_local_dir = args.hf_model_local_dir
+ hf_model_revision = args.hf_model_revision
+ snapshot_download(
+ hf_model_id,
+ local_dir=hf_model_local_dir,
+ local_dir_use_symlinks=False,
+ revision=hf_model_revision
+ )
+
+def tar_model_artifact(args):
+ print(f'tar {args.model_artifact_dir} to {args.model_artifact_tar_name}')
+ os.system(f'tar czvf {args.model_artifact_tar_name} {args.model_artifact_dir}')
+
+def push_to_s3(args):
+ # push hf model
+ print(f'push {args.hf_model_local_dir} to s3://{args.s3_bucket}/{args.hf_model_s3_prefix}')
+ os.system(f"aws s3 cp --recursive {args.hf_model_local_dir} s3://{args.s3_bucket}/{args.hf_model_s3_prefix}")
+ # push model artifacts
+ print(f'push {args.model_artifact_tar_name} to s3://{args.s3_bucket}/{args.model_artifact_s3_prefix}/')
+ os.system(f"aws s3 cp {args.model_artifact_tar_name} s3://{args.s3_bucket}/{args.model_artifact_s3_prefix}/")
+
+def main():
+ args = parse_args()
+ print(f'args: {args}')
+ # download hf model to local
+ download_hf_model_to_local(args)
+ # tar model artifacts
+ tar_model_artifact(args)
+ # push to s3
+ push_to_s3(args)
+
+
+if __name__ == "__main__":
+ main()
+
+
+
diff --git a/source/panel/evaluator/benchmark.py b/source/panel/evaluator/benchmark.py
index eb9c6138..92916429 100644
--- a/source/panel/evaluator/benchmark.py
+++ b/source/panel/evaluator/benchmark.py
@@ -89,6 +89,8 @@
# model_id = "amazon.titan-text-express-v1", region_name="us-east-1"
)
+_openai_embedding = OpenAIEmbeddings()
+
def csdc_markdown_loader(file_path: str) -> List[Document]:
# read content from file_path
with open(file_path, "r") as f:
@@ -150,7 +152,7 @@ def langchain_unstructured_loader(file_path: str) -> List[Document]:
"""
loader = UnstructuredFileLoader(file_path, mode="elements")
docs = loader.load()
- logger.debug("unstructured load data: {}".format(docs))
+ logger.info("unstructured load data: {} type: {}".format(docs, type(docs)))
return docs
def parse_log_to_document_list(log_content: str) -> List[Document]:
@@ -380,7 +382,8 @@ def bedrock_embedding(index: str, docs: List[Document]) -> List[List[str]]:
opensearch_vector_search = OpenSearchVectorSearch(
opensearch_url="https://localhost:9200",
index_name=index,
- embedding_function=_bedrock_embedding,
+ # embedding_function=_bedrock_embedding,
+ embedding_function=_openai_embedding,
http_auth=("admin", "admin"),
use_ssl = False,
verify_certs = False,
@@ -393,7 +396,8 @@ def bedrock_embedding(index: str, docs: List[Document]) -> List[List[str]]:
for batch in batches:
for doc in batch:
res = opensearch_vector_search.add_embeddings(
- text_embeddings = [(doc.page_content, _bedrock_embedding.embed_documents([doc.page_content])[0])],
+ # text_embeddings = [(doc.page_content, _bedrock_embedding.embed_documents([doc.page_content])[0])],
+ text_embeddings = [(doc.page_content, _openai_embedding.embed_documents([doc.page_content])[0])],
metadatas = None,
ids = None,
bulk_size = 1024,
@@ -505,7 +509,8 @@ def local_aos_retriever(index: str, query: str, size: int = 10) -> List[Tuple[Do
opensearch_vector_search = OpenSearchVectorSearch(
opensearch_url="https://localhost:9200",
index_name=index,
- embedding_function=_bedrock_embedding,
+ # embedding_function=_bedrock_embedding,
+ embedding_function=_openai_embedding,
http_auth=("admin", "admin"),
use_ssl = False,
verify_certs = False,
@@ -609,9 +614,10 @@ def testdata_generate(doc: Document, llm: str = "bedrock", embedding: str = "bed
chat_qa=chat_qa,
)
+ logger.info("doc input to testdata_generate: {} with type: {}".format(doc, type(doc)))
testset = test_generator.generate(doc, test_size=test_size)
test_df = testset.to_pandas()
- logger.debug("testdata head: {}".format(test_df.head()))
+ logger.info("testdata head: {}".format(test_df.head()))
# Saving to a csv and txt file for debugging purpose
test_df.to_csv('test_data.csv', index=False)
@@ -859,42 +865,43 @@ def batch_generator(generator, batch_size: int):
9. average similarity score of retrival
10. average time of retrival
"""
+ loader_res = langchain_unstructured_loader("pdf-sample-01.pdf")
+ # loader_res = csdc_unstructured_loader("pdf-sample-01.pdf")
+ question_list, question_type_list = testdata_generate(loader_res, llm="openai", embedding="openai", test_size=10)
+
# initialization of workflow executor
- legacy = WorkflowExecutor()
- legacy.update_component('loaders', langchain_unstructured_loader, 'add')
- legacy.update_component('splitters', langchain_recursive_splitter, 'add')
- legacy.update_component('embedders', bedrock_embedding, 'add')
- legacy.update_component('retrievers', local_aos_retriever, 'add')
- legacy.update_component('evaluators', langchain_evaluator, 'add')
+ # legacy = WorkflowExecutor()
+ # legacy.update_component('loaders', langchain_unstructured_loader, 'add')
+ # legacy.update_component('splitters', langchain_recursive_splitter, 'add')
+ # legacy.update_component('embedders', bedrock_embedding, 'add')
+ # legacy.update_component('retrievers', local_aos_retriever, 'add')
+ # legacy.update_component('evaluators', langchain_evaluator, 'add')
# response = legacy.execute_workflow("pdf-sample-01.pdf", "请介绍什么是kindle以及它的主要功能?")
- # loader_res = langchain_unstructured_loader("pdf-sample-01.pdf")
- # question_list, question_type_list = testdata_generate(loader_res, llm="openai", embedding="openai", test_size=10)
-
# workaround for inconsistent network issue, if you have already generated the test data as sample schema below
"""
question,ground_truth_context,ground_truth,question_type,episode_done
How can you navigate to the next page on the screen?,['- 要翻到下一页,请用手指在屏幕上从右往左滑动。\n- 要翻到上一页,请用手指在屏幕上从左往右滑动。\n- 您还可以使用屏幕一侧的控件来翻页。'],"['To navigate to the next page on the screen, you can swipe your finger from right to left on the screen.']",simple,True
What are the steps for restarting an unresponsive or non-turning on Kindle device?,['- 如果您的 Kindle 无法开机或使用过程中停止响应而需要重启,请按住电源开关 7 秒,直 至【电源】对话框出现,然后选择【重新启动】。\n- 如果【电源】对话框不出现,请按住电 源开关 40 秒或直至 LED 灯停止闪烁。'],"['The steps for restarting an unresponsive or non-turning on Kindle device are as follows:\n1. Press and hold the power button for 7 seconds until the ""Power"" dialog box appears.\n2. Select ""Restart"" from the dialog box.\n3. If the ""Power"" dialog box does not appear, press and hold the power button for 40 seconds or until the LED light stops flashing.']",conditional,True
"""
- question_list = []
- with open('test_data.csv', 'r') as f:
- reader = csv.reader(f)
- for row in reader:
- question_list.append(row[0])
-
- # iterate the question list to execute the workflow
- response_list = []
- for question in question_list:
- response = legacy.execute_workflow("pdf-sample-01.pdf", question, skip=False)
- logger.info("test of legacy workflow: {}".format(response))
- response_list.append(response)
-
- logger.info("response_list: {}".format(response_list))
-
- # visualize the summary
- legacy.summary_viz(response_list)
+ # question_list = []
+ # with open('test_data.csv', 'r') as f:
+ # reader = csv.reader(f)
+ # for row in reader:
+ # question_list.append(row[0])
+
+ # # iterate the question list to execute the workflow
+ # response_list = []
+ # for question in question_list:
+ # response = legacy.execute_workflow("pdf-sample-01.pdf", question, skip=True)
+ # logger.info("test of legacy workflow: {}".format(response))
+ # response_list.append(response)
+
+ # logger.info("response_list: {}".format(response_list))
+
+ # # visualize the summary
+ # legacy.summary_viz(response_list)
# initialization of workflow executor
# csdc = WorkflowExecutor()
diff --git a/test/websocket_api_test.py b/test/websocket_api_test.py
index e5edd6b8..426d86ca 100644
--- a/test/websocket_api_test.py
+++ b/test/websocket_api_test.py
@@ -9,8 +9,9 @@
import json
# find ws_url from api gateway
-# ws_url = "wss://omjou492fe.execute-api.us-west-2.amazonaws.com/prod/"
ws_url = "wss://2ogbgobue2.execute-api.us-west-2.amazonaws.com/prod/"
+# ws_url = "wss://2ogbgobue2.execute-api.us-west-2.amazonaws.com/v1"
+# wss://2ogbgobue2.execute-api.us-west-2.amazonaws.com/prod/
ws = create_connection(ws_url)
question_library = [
@@ -25,43 +26,75 @@
"Amazon EC2 提供了哪些功能来支持不同区域之间的数据恢复?"
]
+
+# endpoint_name = 'internlm2-chat-20b-4bits-2024-03-04-06-32-53-653'
+# model_id = "internlm2-chat-20b"
+entry_type = "market_chain"
+workspace_ids = ["aos_index_mkt_faq_qq","aos_index_acts_qd"]
+
body = {
- "action": "sendMessage",
- "model": "knowledge_qa",
- # "messages": [{"role": "user","content": question_library[-1]}],
- # "messages": [{"role": "user","content": question_library[-1]}],
- "messages": [{"role": "user","content": '什么是Bedrock?'}],
- "temperature": 0.7,
- "type" : "market_chain",
+ "get_contexts": True,
+ "type" : entry_type,
"retriever_config":{
- "using_whole_doc": False,
- "chunk_num": 2,
- },
- # "enable_q_q_match": True,
- # "enable_debug": False,
- # "llm_model_id":'anthropic.claude-v2:1',
- "get_contexts":True,
- "generator_llm_config":{
- "model_kwargs":{
- "max_new_tokens": 1000,
- "temperature": 0.01,
- "top_p": 0.9,
- "timeout":120
- },
- "model_id": "internlm2-chat-7b",
- # "endpoint_name": "instruct-internlm2-chat-7b-f7dc2",
- "endpoint_name": "internlm2-chat-7b-4bits-2024-02-28-07-08-57-839",#"baichuan2-13b-chat-4bits-2024-01-28-15-46-43-013",
- "context_num": 1
- }
- # "session_id":f"test_{int(time.time())}"
+ "qq_config": {
+ "q_q_match_threshold": 0.8,
+ },
+ "qd_config":{
+ "qd_match_threshold": 2,
+ # "using_whole_doc": True
+ },
+ "workspace_ids": workspace_ids
+ }
}
+# body = {
+# "get_contexts": True,
+# "model": "knowledge_qa",
+# # "messages": [{"role": "user","content": question_library[-1]}],
+# # "messages": [{"role": "user","content": question_library[-1]}],
+# "messages": [{"role": "user","content": '什么是Bedrock?', "custom_message_id": f"test_dashboard_{time.time()}"}],
+# # "temperature": 0.7,
+# "type" : "market_chain",
+# "retriever_config":{
+# "using_whole_doc": False,
+# "chunk_num": 2,
+# },
+# # "enable_q_q_match": True,
+# # "enable_debug": False,
+# # "llm_model_id":'anthropic.claude-v2:1',
+# "get_contexts":True,
+# "generator_llm_config":{
+# "model_kwargs":{
+# "max_new_tokens": 1000,
+# "temperature": 0.01,
+# "top_p": 0.9,
+# "timeout":120
+# },
+# "llm_model_id": "internlm2-chat-7b",
+# # "endpoint_name": "instruct-internlm2-chat-7b-f7dc2",
+# "llm_model_endpoint_name": "internlm2-chat-20b-4bits-2024-03-04-06-32-53-653",#"baichuan2-13b-chat-4bits-2024-01-28-15-46-43-013",
+# "context_num": 1
+# },
+# "model_kwargs":{
+# "max_new_tokens": 1000,
+# "temperature": 0.01,
+# "top_p": 0.9,
+# "timeout":120
+# },
+# "llm_model_id": "internlm2-chat-20b",
+# # "endpoint_name": "instruct-internlm2-chat-7b-f7dc2",
+# "llm_model_endpoint_name": "internlm2-chat-20b-4bits-2024-03-04-06-32-53-653",#"baichuan2-13b-chat-4bits-2024-01-28-15-46-43-013",
+# "context_num": 1,
+# "custom_message_id": f"test_dashboard_{int(time.time())}"
+# # "session_id":f"test_{int(time.time())}"
+# }
+
-body.update({"retriever_top_k": 1,
- "chunk_num": 2,
- "using_whole_doc": False,
- "reranker_top_k": 10,
- "enable_reranker": True})
+# body.update({"retriever_top_k": 1,
+# "chunk_num": 2,
+# "using_whole_doc": False,
+# "reranker_top_k": 10,
+# "reranker_type": "no_reranker"})
# body = {