From 941ccde2b08a56df38620ad9fc0ee7835869661f Mon Sep 17 00:00:00 2001 From: Karen Shaw Date: Fri, 23 Feb 2024 22:19:26 +0000 Subject: [PATCH] Update chat handler for Opensearch Update permissions on chat websocket function Add AWS4Auth to opensearch client Tweak EventConfig to make chat work with OpenSearch --- .gitignore | 2 + Makefile | 2 + chat/dependencies/requirements.txt | 7 ++- chat/src/content_handler.py | 36 ++++++++++++ chat/src/event_config.py | 91 ++++++++++++++---------------- chat/src/handlers/chat.py | 13 +++-- chat/src/helpers/prompts.py | 4 +- chat/src/helpers/response.py | 23 ++++++-- chat/src/requirements.txt | 7 ++- chat/src/setup.py | 73 ++++++++++++------------ chat/src/websocket.py | 6 +- chat/template.yaml | 24 ++++++-- chat/test/handlers/test_chat.py | 7 ++- chat/test/helpers/test_prompts.py | 6 +- chat/test/test_event_config.py | 23 +------- template.yaml | 13 ++--- 16 files changed, 194 insertions(+), 143 deletions(-) create mode 100644 chat/src/content_handler.py diff --git a/.gitignore b/.gitignore index 3c9b74f5..da679d9f 100644 --- a/.gitignore +++ b/.gitignore @@ -221,6 +221,8 @@ $RECYCLE.BIN/ /docs/docs/spec/openapi.json /docs/site +.venv + .vscode /samconfig.toml /samconfig.yaml diff --git a/Makefile b/Makefile index dded3f4f..6c3a5a06 100644 --- a/Makefile +++ b/Makefile @@ -40,6 +40,8 @@ cover-html-python: deps-python cd chat && export SKIP_WEAVIATE_SETUP=True && coverage run --source=src -m unittest -v && coverage html --skip-empty style-python: deps-python cd chat && ruff check . +style-python-fix: deps-python + cd chat && ruff check --fix . test-python: deps-python cd chat && export SKIP_WEAVIATE_SETUP=True && PYTHONPATH=src:test && python -m unittest discover -v python-version: diff --git a/chat/dependencies/requirements.txt b/chat/dependencies/requirements.txt index 6bee442a..f80af593 100644 --- a/chat/dependencies/requirements.txt +++ b/chat/dependencies/requirements.txt @@ -1,8 +1,11 @@ boto3~=1.34.13 -langchain~=0.0.208 +langchain~=0.1.8 +langchain-community openai~=0.27.8 +opensearch-py pyjwt~=2.6.0 python-dotenv~=1.0.0 +requests +requests-aws4auth tiktoken~=0.4.0 -weaviate-client~=3.19.2 wheel~=0.40.0 \ No newline at end of file diff --git a/chat/src/content_handler.py b/chat/src/content_handler.py new file mode 100644 index 00000000..b75f98b9 --- /dev/null +++ b/chat/src/content_handler.py @@ -0,0 +1,36 @@ +import json +from typing import Dict, List +from langchain_community.embeddings.sagemaker_endpoint import EmbeddingsContentHandler + +class ContentHandler(EmbeddingsContentHandler): + content_type = "application/json" + accepts = "application/json" + + def transform_input(self, inputs: list[str], model_kwargs: Dict) -> bytes: + """ + Transforms the input into bytes that can be consumed by SageMaker endpoint. + Args: + inputs: List of input strings. + model_kwargs: Additional keyword arguments to be passed to the endpoint. + Returns: + The transformed bytes input. + """ + # Example: inference.py expects a JSON string with a "inputs" key: + input_str = json.dumps({"inputs": inputs, **model_kwargs}) + return input_str.encode("utf-8") + + def transform_output(self, output: bytes) -> List[List[float]]: + """ + Transforms the bytes output from the endpoint into a list of embeddings. + Args: + output: The bytes output from SageMaker endpoint. + Returns: + The transformed output - list of embeddings + Note: + The length of the outer list is the number of input strings. + The length of the inner lists is the embedding dimension. + """ + # Example: inference.py returns a JSON string with the list of + # embeddings in a "vectors" key: + response_json = json.loads(output.read().decode("utf-8")) + return [response_json["embedding"]] \ No newline at end of file diff --git a/chat/src/event_config.py b/chat/src/event_config.py index 5c7762b3..e20148f9 100644 --- a/chat/src/event_config.py +++ b/chat/src/event_config.py @@ -5,8 +5,8 @@ from langchain.chains.qa_with_sources import load_qa_with_sources_chain from langchain.prompts import PromptTemplate from setup import ( - weaviate_client, - weaviate_vector_store, + opensearch_client, + opensearch_vector_store, openai_chat_client, ) from typing import List @@ -15,17 +15,14 @@ from helpers.prompts import document_template, prompt_template from websocket import Websocket - CHAIN_TYPE = "stuff" DOCUMENT_VARIABLE_NAME = "context" -INDEX_NAME = "DCWork" -K_VALUE = 10 +K_VALUE = 5 MAX_K = 100 TEMPERATURE = 0.2 TEXT_KEY = "title" VERSION = "2023-07-01-preview" - @dataclass class EventConfig: """ @@ -33,6 +30,12 @@ class EventConfig: Default values are set for the following properties which can be overridden in the payload message. """ + DEFAULT_ATTRIBUTES = ["accession_number", "alternate_title", "api_link", "canonical_link", "caption", "collection", + "contributor", "date_created", "date_created_edtf", "description", "genre", "id", "identifier", + "keywords", "language", "notes", "physical_description_material", "physical_description_size", + "provenance", "publisher", "rights_statement", "subject", "table_of_contents", "thumbnail", + "title", "visibility", "work_type"] + api_token: ApiToken = field(init=False) attributes: List[str] = field(init=False) azure_endpoint: str = field(init=False) @@ -41,7 +44,6 @@ class EventConfig: deployment_name: str = field(init=False) document_prompt: PromptTemplate = field(init=False) event: dict = field(default_factory=dict) - index_name: str = field(init=False) is_logged_in: bool = field(init=False) k: int = field(init=False) openai_api_version: str = field(init=False) @@ -54,7 +56,7 @@ class EventConfig: temperature: float = field(init=False) socket: Websocket = field(init=False, default=None) text_key: str = field(init=False) - + def __post_init__(self): self.payload = json.loads(self.event.get("body", "{}")) self.api_token = ApiToken(signed_token=self.payload.get("auth")) @@ -64,7 +66,6 @@ def __post_init__(self): self.azure_endpoint = self._get_azure_endpoint() self.debug_mode = self._is_debug_mode_enabled() self.deployment_name = self._get_deployment_name() - self.index_name = self._get_index_name() self.is_logged_in = self.api_token.is_logged_in() self.k = self._get_k() self.openai_api_version = self._get_openai_api_version() @@ -74,9 +75,10 @@ def __post_init__(self): self.ref = self.payload.get("ref") self.temperature = self._get_temperature() self.text_key = self._get_text_key() - self.attributes = self._get_attributes() self.document_prompt = self._get_document_prompt() - self.prompt = PromptTemplate(template=self.prompt_text, input_variables=["question", "context"]) + self.prompt = PromptTemplate( + template=self.prompt_text, input_variables=["question", "context"] + ) def _get_payload_value_with_superuser_check(self, key, default): if self.api_token.is_superuser(): @@ -84,65 +86,59 @@ def _get_payload_value_with_superuser_check(self, key, default): else: return default + def _get_attributes_function(self): + try: + opensearch = opensearch_client() + mapping = opensearch.indices.get_mapping(index="dc-v2-work") + return list(next(iter(mapping.values()))['mappings']['properties'].keys()) + except StopIteration: + return [] + + def _get_attributes(self): + return self._get_payload_value_with_superuser_check("attributes", self.DEFAULT_ATTRIBUTES) + # return self._get_payload_value_with_superuser_check("attributes", self._get_attributes_function()) + def _get_azure_endpoint(self): default = f"https://{self._get_azure_resource_name()}.openai.azure.com/" return self._get_payload_value_with_superuser_check("azure_endpoint", default) def _get_azure_resource_name(self): - azure_resource_name = self._get_payload_value_with_superuser_check("azure_resource_name", os.environ.get("AZURE_OPENAI_RESOURCE_NAME")) + azure_resource_name = self._get_payload_value_with_superuser_check( + "azure_resource_name", os.environ.get("AZURE_OPENAI_RESOURCE_NAME") + ) if not azure_resource_name: raise EnvironmentError( "Either payload must contain 'azure_resource_name' or environment variable 'AZURE_OPENAI_RESOURCE_NAME' must be set" ) return azure_resource_name - + def _get_deployment_name(self): - return self._get_payload_value_with_superuser_check("deployment_name", os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT_ID")) - - def _get_index_name(self): - return self._get_payload_value_with_superuser_check("index", INDEX_NAME) + return self._get_payload_value_with_superuser_check( + "deployment_name", os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT_ID") + ) def _get_k(self): value = self._get_payload_value_with_superuser_check("k", K_VALUE) return min(value, MAX_K) def _get_openai_api_version(self): - return self._get_payload_value_with_superuser_check("openai_api_version", VERSION) - + return self._get_payload_value_with_superuser_check( + "openai_api_version", VERSION + ) + def _get_prompt_text(self): return self._get_payload_value_with_superuser_check("prompt", prompt_template()) - + def _get_temperature(self): return self._get_payload_value_with_superuser_check("temperature", TEMPERATURE) def _get_text_key(self): return self._get_payload_value_with_superuser_check("text_key", TEXT_KEY) - def _get_attributes(self): - attributes = [ - item - for item in self._get_request_attributes() - if item not in [self._get_text_key(), "source", "full_text"] - ] - return attributes - - def _get_request_attributes(self): - if os.getenv("SKIP_WEAVIATE_SETUP"): - return [] - - attributes = self._get_payload_value_with_superuser_check("attributes", []) - if attributes: - return attributes - else: - client = weaviate_client() - schema = client.schema.get(self._get_index_name()) - names = [prop["name"] for prop in schema.get("properties")] - return names - def _get_document_prompt(self): return PromptTemplate( template=document_template(self.attributes), - input_variables=["page_content", "source"] + self.attributes, + input_variables=["title", "id"] + self.attributes, ) def debug_message(self): @@ -152,7 +148,6 @@ def debug_message(self): "attributes": self.attributes, "azure_endpoint": self.azure_endpoint, "deployment_name": self.deployment_name, - "index": self.index_name, "k": self.k, "openai_api_version": self.openai_api_version, "prompt": self.prompt_text, @@ -167,7 +162,9 @@ def setup_websocket(self, socket=None): if socket is None: connection_id = self.request_context.get("connectionId") endpoint_url = f'https://{self.request_context.get("domainName")}/{self.request_context.get("stage")}' - self.socket = Websocket(endpoint_url=endpoint_url, connection_id=connection_id, ref=self.ref) + self.socket = Websocket( + endpoint_url=endpoint_url, connection_id=connection_id, ref=self.ref + ) else: self.socket = socket return self.socket @@ -178,11 +175,7 @@ def setup_llm_request(self): self._setup_chain() def _setup_vector_store(self): - self.weaviate = weaviate_vector_store( - index_name=self.index_name, - text_key=self.text_key, - attributes=self.attributes + ["source"], - ) + self.opensearch = opensearch_vector_store() def _setup_chat_client(self): self.client = openai_chat_client( diff --git a/chat/src/handlers/chat.py b/chat/src/handlers/chat.py index aa19ff79..8757b286 100644 --- a/chat/src/handlers/chat.py +++ b/chat/src/handlers/chat.py @@ -1,4 +1,6 @@ import os +import sys +import traceback from event_config import EventConfig from helpers.response import prepare_response @@ -21,9 +23,8 @@ def handler(event, _context): config.socket.send(final_response) return {"statusCode": 200} - except Exception as err: - if err.__class__.__name__ == "PayloadTooLargeException": - config.socket.send({"type": "error", "message": "Payload too large"}) - return {"statusCode": 413, "body": "Payload too large"} - else: - raise err + except Exception: + exc_info = sys.exc_info() + err_text = ''.join(traceback.format_exception(*exc_info)) + print(err_text) + return {"statusCode": 500, "body": f'Unhandled error:\n{err_text}'} diff --git a/chat/src/helpers/prompts.py b/chat/src/helpers/prompts.py index 32ffbc46..397b7005 100644 --- a/chat/src/helpers/prompts.py +++ b/chat/src/helpers/prompts.py @@ -16,8 +16,8 @@ def document_template(attributes: Optional[List[str]] = None) -> str: if attributes is None: attributes = [] lines = ( - ["Content: {page_content}", "Metadata:"] + ["Content: {title}", "Metadata:"] + [f" {attribute}: {{{attribute}}}" for attribute in attributes] - + ["Source: {source}"] + + ["Source: {id}"] ) return "\n".join(lines) diff --git a/chat/src/helpers/response.py b/chat/src/helpers/response.py index 42b4e4ed..a3b946d4 100644 --- a/chat/src/helpers/response.py +++ b/chat/src/helpers/response.py @@ -1,7 +1,6 @@ from helpers.metrics import token_usage from openai.error import InvalidRequestError - def base_response(config, response): return {"answer": response["output_text"], "ref": config.ref} @@ -12,7 +11,6 @@ def debug_response(config, response, original_question): "attributes": config.attributes, "azure_endpoint": config.azure_endpoint, "deployment_name": config.deployment_name, - "index": config.index_name, "is_superuser": config.api_token.is_superuser(), "k": config.k, "openai_api_version": config.openai_api_version, @@ -26,7 +24,13 @@ def debug_response(config, response, original_question): def get_and_send_original_question(config, docs): - doc_response = [doc.__dict__ for doc in docs] + doc_response = [] + for doc in docs: + doc_dict = doc.__dict__ + metadata = doc_dict.get('metadata', {}) + new_doc = {key: extract_prompt_value(metadata.get(key)) for key in config.attributes if key in metadata} + doc_response.append(new_doc) + original_question = { "question": config.question, "source_documents": doc_response, @@ -34,11 +38,18 @@ def get_and_send_original_question(config, docs): config.socket.send(original_question) return original_question - +def extract_prompt_value(v): + if isinstance(v, list): + return [extract_prompt_value(item) for item in v] + elif isinstance(v, dict) and 'label' in v: + return [v.get('label')] + else: + return v + def prepare_response(config): try: - docs = config.weaviate.similarity_search( - config.question, k=config.k, additional="certainty" + docs = config.opensearch.similarity_search( + config.question, k=config.k, vector_field="embedding", text_field="id" ) original_question = get_and_send_original_question(config, docs) response = config.chain({"question": config.question, "input_documents": docs}) diff --git a/chat/src/requirements.txt b/chat/src/requirements.txt index 8cb0270e..04100144 100644 --- a/chat/src/requirements.txt +++ b/chat/src/requirements.txt @@ -1,11 +1,14 @@ # Runtime Dependencies boto3~=1.34.13 -langchain~=0.0.208 +langchain~=0.1.8 +langchain-community openai~=0.27.8 +opensearch-py pyjwt~=2.6.0 python-dotenv~=1.0.0 +requests +requests-aws4auth tiktoken~=0.4.0 -weaviate-client~=3.19.2 wheel~=0.40.0 # Dev/Test Dependencies diff --git a/chat/src/setup.py b/chat/src/setup.py index cc70c653..0362020a 100644 --- a/chat/src/setup.py +++ b/chat/src/setup.py @@ -1,10 +1,16 @@ +from content_handler import ContentHandler from langchain.chat_models import AzureChatOpenAI -from langchain.vectorstores import Weaviate -from typing import List +from langchain_community.embeddings import SagemakerEndpointEmbeddings +from langchain_community.vectorstores import OpenSearchVectorSearch +from opensearchpy import OpenSearch, RequestsHttpConnection +from requests_aws4auth import AWS4Auth import os -import weaviate import boto3 +def prefix(value): + env_prefix = os.getenv("ENV_PREFIX") + env_prefix = None if env_prefix == "" else env_prefix + return '-'.join(filter(None, [env_prefix, value])) def openai_chat_client(**kwargs): return AzureChatOpenAI( @@ -12,42 +18,39 @@ def openai_chat_client(**kwargs): **kwargs, ) +def opensearch_client(region_name=os.getenv("AWS_REGION")): + print(region_name) + session = boto3.Session(region_name=region_name) + awsauth = AWS4Auth(region=region_name, service="es", refreshable_credentials=session.get_credentials()) + endpoint = os.getenv("ELASTICSEARCH_ENDPOINT") + + return OpenSearch( + hosts=[{'host': endpoint, 'port': 443}], + use_ssl = True, + connection_class=RequestsHttpConnection, + http_auth=awsauth, + ) -def weaviate_client(): - if os.getenv("SKIP_WEAVIATE_SETUP"): - return None - - weaviate_url = os.environ.get("WEAVIATE_URL") - try: - if weaviate_url is None: - raise EnvironmentError( - "WEAVIATE_URL is not set in the environment variables" - ) - - weaviate_api_key = os.environ.get("WEAVIATE_API_KEY") - if weaviate_api_key is None: - raise EnvironmentError( - "WEAVIATE_API_KEY is not set in the environment variables" - ) - - auth_config = weaviate.AuthApiKey(api_key=weaviate_api_key) - - client = weaviate.Client(url=weaviate_url, auth_client_secret=auth_config) - except Exception as e: - print(f"An error occurred: {e}") - client = None - return client - +def opensearch_vector_store(region_name=os.getenv("AWS_REGION")): + session = boto3.Session(region_name=region_name) + awsauth = AWS4Auth(region=region_name, service="es", refreshable_credentials=session.get_credentials()) -def weaviate_vector_store(index_name: str, text_key: str, attributes: List[str] = []): - if os.getenv("SKIP_WEAVIATE_SETUP"): - return None - - client = weaviate_client() + sagemaker_client = session.client(service_name="sagemaker-runtime", region_name=session.region_name) + embeddings = SagemakerEndpointEmbeddings( + client=sagemaker_client, + region_name=session.region_name, + endpoint_name=os.getenv("EMBEDDING_ENDPOINT"), + content_handler=ContentHandler() + ) - return Weaviate( - client=client, index_name=index_name, text_key=text_key, attributes=attributes + docsearch = OpenSearchVectorSearch( + index_name=prefix("dc-v2-work"), + embedding_function=embeddings, + opensearch_url="https://" + os.getenv("ELASTICSEARCH_ENDPOINT"), + connection_class=RequestsHttpConnection, + http_auth=awsauth, ) + return docsearch def websocket_client(endpoint_url: str): diff --git a/chat/src/websocket.py b/chat/src/websocket.py index dc81179a..ea682b0a 100644 --- a/chat/src/websocket.py +++ b/chat/src/websocket.py @@ -12,5 +12,9 @@ def send(self, data): data = {"message": data} data["ref"] = self.ref data_as_bytes = bytes(json.dumps(data), "utf-8") - self.client.post_to_connection(Data=data_as_bytes, ConnectionId=self.connection_id) + + if self.connection_id == "debug": + print(data) + else: + self.client.post_to_connection(Data=data_as_bytes, ConnectionId=self.connection_id) return data diff --git a/chat/template.yaml b/chat/template.yaml index 24c89b7d..d7696246 100644 --- a/chat/template.yaml +++ b/chat/template.yaml @@ -17,12 +17,12 @@ Parameters: AzureOpenaiResourceName: Type: String Description: Azure OpenAI Resource Name - WeaviateApiKey: + ElasticsearchEndpoint: Type: String - Description: Weaviate API Key - WeaviateUrl: + Description: Elasticsearch URL + EmbeddingEndpoint: Type: String - Description: Weaviate URL + Description: Sagemaker Inference Endpoint Resources: ApiGwAccountConfig: Type: "AWS::ApiGateway::Account" @@ -202,8 +202,8 @@ Resources: AZURE_OPENAI_EMBEDDING_DEPLOYMENT_ID: !Ref AzureOpenaiEmbeddingDeploymentId AZURE_OPENAI_LLM_DEPLOYMENT_ID: !Ref AzureOpenaiLlmDeploymentId AZURE_OPENAI_RESOURCE_NAME: !Ref AzureOpenaiResourceName - WEAVIATE_API_KEY: !Ref WeaviateApiKey - WEAVIATE_URL: !Ref WeaviateUrl + ELASTICSEARCH_ENDPOINT: !Ref ElasticsearchEndpoint + EMBEDDING_ENDPOINT: !Ref EmbeddingEndpoint Policies: - Statement: - Effect: Allow @@ -211,6 +211,18 @@ Resources: - 'execute-api:ManageConnections' Resource: - !Sub 'arn:aws:execute-api:${AWS::Region}:${AWS::AccountId}:${ChatWebSocket}/*' + - Statement: + - Effect: Allow + Action: + - 'es:ESHttpGet' + - 'es:ESHttpPost' + Resource: '*' + - Statement: + - Effect: Allow + Action: + - 'sagemaker:InvokeEndpoint' + - 'sagemaker:InvokeEndpointAsync' + Resource: !Sub 'arn:aws:sagemaker:${AWS::Region}:${AWS::AccountId}:endpoint/${EmbeddingEndpoint}' Metadata: BuildMethod: nodejs18.x Deployment: diff --git a/chat/test/handlers/test_chat.py b/chat/test/handlers/test_chat.py index 21c9b643..6a0f32c9 100644 --- a/chat/test/handlers/test_chat.py +++ b/chat/test/handlers/test_chat.py @@ -70,5 +70,8 @@ def test_handler_debug_mode_for_superusers_only(self, mock_is_debug_enabled, moc @patch.object(EventConfig, 'setup_websocket') def test_error_handling(self, mock_event): mock_event.side_effect = Exception("Some error occurred") - with self.assertRaises(Exception): - handler({}, {}) \ No newline at end of file + response = handler({}, {}) + self.assertEqual(response['statusCode'], 500) + body_lines = response['body'].strip().split('\n') + self.assertEquals(body_lines[0], 'Unhandled error:') + self.assertEquals(body_lines[-1], 'Exception: Some error occurred') diff --git a/chat/test/helpers/test_prompts.py b/chat/test/helpers/test_prompts.py index 9508f32a..b9a7d950 100644 --- a/chat/test/helpers/test_prompts.py +++ b/chat/test/helpers/test_prompts.py @@ -17,17 +17,17 @@ class TestDocumentTemplate(TestCase): def test_empty_attributes(self): self.assertEqual( document_template(), - "Content: {page_content}\nMetadata:\nSource: {source}", + "Content: {title}\nMetadata:\nSource: {id}", ) def test_single_attribute(self): self.assertEqual( document_template(["title"]), - "Content: {page_content}\nMetadata:\n title: {title}\nSource: {source}", + "Content: {title}\nMetadata:\n title: {title}\nSource: {id}", ) def test_multiple_attributes(self): self.assertEqual( document_template(["title", "author", "subject", "description"]), - "Content: {page_content}\nMetadata:\n title: {title}\n author: {author}\n subject: {subject}\n description: {description}\nSource: {source}", + "Content: {title}\nMetadata:\n title: {title}\n author: {author}\n subject: {subject}\n description: {description}\nSource: {id}", ) diff --git a/chat/test/test_event_config.py b/chat/test/test_event_config.py index 8d8c02c1..55f8381d 100644 --- a/chat/test/test_event_config.py +++ b/chat/test/test_event_config.py @@ -50,10 +50,9 @@ def test_attempt_override_without_superuser_status(self): } ) expected_output = { - "attributes": [], + "attributes": EventConfig.DEFAULT_ATTRIBUTES, "azure_endpoint": "https://test.openai.azure.com/", - "index_name": "DCWork", - "k": 10, + "k": 5, "openai_api_version": "2023-07-01-preview", "question": "test question", "ref": "test ref", @@ -61,7 +60,6 @@ def test_attempt_override_without_superuser_status(self): "text_key": "title", } self.assertEqual(actual.azure_endpoint, expected_output["azure_endpoint"]) - self.assertEqual(actual.index_name, expected_output["index_name"]) self.assertEqual(actual.attributes, expected_output["attributes"]) self.assertEqual(actual.k, expected_output["k"]) self.assertEqual( @@ -72,23 +70,6 @@ def test_attempt_override_without_superuser_status(self): self.assertEqual(actual.temperature, expected_output["temperature"]) self.assertEqual(actual.text_key, expected_output["text_key"]) - def test_text_key_removed_from_attributes_list(self): - actual = EventConfig( - event={ - "body": json.dumps( - { - "attributes": ["title", "description"], - "text_key": "description", - } - ) - } - ) - self.assertNotIn(actual.text_key, actual.attributes) - - def test_source_removed_from_attributes_list(self): - actual = EventConfig(event={"body": json.dumps({"attributes": ["source"]})}) - self.assertNotIn("source", actual.attributes) - def test_debug_message(self): self.assertEqual( EventConfig( diff --git a/template.yaml b/template.yaml index 3256c546..5234a30a 100644 --- a/template.yaml +++ b/template.yaml @@ -62,6 +62,9 @@ Parameters: ElasticsearchEndpoint: Type: String Description: Elasticsearch url + EmbeddingEndpoint: + Type: String + Description: Sagemaker Inference Endpoint EnvironmentPrefix: Type: String Description: Index Prefix @@ -112,12 +115,6 @@ Parameters: StreamingBucket: Type: String Description: Meadow streaming bucket - WeaviateApiKey: - Type: String - Description: Weaviate API Key - WeaviateUrl: - Type: String - Description: Weaviate URL Resources: apiDependencies: Type: AWS::Serverless::LayerVersion @@ -662,8 +659,8 @@ Resources: AzureOpenaiEmbeddingDeploymentId: !Ref AzureOpenaiEmbeddingDeploymentId AzureOpenaiLlmDeploymentId: !Ref AzureOpenaiLlmDeploymentId AzureOpenaiResourceName: !Ref AzureOpenaiResourceName - WeaviateApiKey: !Ref WeaviateApiKey - WeaviateUrl: !Ref WeaviateUrl + ElasticsearchEndpoint: !Ref ElasticsearchEndpoint + EmbeddingEndpoint: !Ref EmbeddingEndpoint chatWebsocketEndpoint: Type: AWS::Serverless::Function Properties: