Skip to content

Commit

Permalink
Allow superuser to override prompt and attributes
Browse files Browse the repository at this point in the history
Allow any request to override index name and k value

Make sure chatWebsocketEndpoint uses the API dependency layer

Fix up attribute filtering

Adds debug mode to chat handler

Simplify the prompt template, now that we can override it in development easily

Remove full_text from LLM prompt
Bump chatFunction memory to 1GB

Allow for overriding parameters to the LLM with default configuration in place

- Large refactor of configuration handling, adds the ability to override many more parameters via websocket messages
- Tests passing in the dev environment using the Makefile and GitHub Actions
- Allow for skipping Weaviate setup in GitHub Actions via an environment variable

Temporarily removes full_text from vector searches
  • Loading branch information
mbklein committed Feb 26, 2024
1 parent a93b095 commit b1246b6
Show file tree
Hide file tree
Showing 28 changed files with 884 additions and 391 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/deploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ on:
paths:
- ".github/workflows/deploy.yml"
- "node/**"
- "python/**"
- "chat/**"
- "template.yaml"
workflow_dispatch:
concurrency:
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/test-python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ jobs:
env:
AWS_ACCESS_KEY_ID: ci
AWS_SECRET_ACCESS_KEY: ci
SKIP_WEAVIATE_SETUP: 'True'
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ lerna-debug.log*

### Python ###
.coverage
htmlcov
__pycache__/
*.py[cod]
*$py.class
Expand Down
120 changes: 62 additions & 58 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,58 +1,62 @@
ifndef VERBOSE
.SILENT:
endif
ENV=dev

help:
echo "make build | build the SAM project"
echo "make serve | run the SAM server locally"
echo "make clean | remove all installed dependencies and build artifacts"
echo "make deps | install all dependencies"
echo "make link | create hard links to allow for hot reloading of a built project"
echo "make secrets | symlink secrets files from ../tfvars"
echo "make style | run all style checks"
echo "make test | run all tests"
echo "make cover | run all tests with coverage"
echo "make env ENV=[env] | activate env.\$$ENV.json file (default: dev)"
echo "make deps-node | install node dependencies"
echo "make deps-python | install python dependencies"
echo "make style-node | run node code style check"
echo "make style-python | run python code style check"
echo "make test-node | run node tests"
echo "make test-python | run python tests"
echo "make cover-node | run node tests with coverage"
echo "make cover-python | run python tests with coverage"
.aws-sam/build.toml: ./template.yaml node/package-lock.json node/src/package-lock.json python/requirements.txt python/src/requirements.txt
sam build --cached --parallel
deps-node:
cd node && npm ci
cover-node:
cd node && npm run test:coverage
style-node:
cd node && npm run prettier
test-node:
cd node && npm run test
deps-python:
cd chat/src && pip install -r requirements.txt
cover-python:
cd chat/src && coverage run --include='src/**/*' -m unittest -v && coverage report
style-python:
cd chat && ruff check .
test-python:
cd chat && python -m unittest -v
build: .aws-sam/build.toml
link: build
cd chat/src && for src in *.py **/*.py; do for target in $$(find ../../.aws-sam/build -maxdepth 1 -type d); do if [[ -f $$target/$$src ]]; then ln -f $$src $$target/$$src; fi; done; done
cd node/src && for src in *.js *.json **/*.js **/*.json; do for target in $$(find ../../.aws-sam/build -maxdepth 1 -type d); do if [[ -f $$target/$$src ]]; then ln -f $$src $$target/$$src; fi; done; done
serve: link
sam local start-api --host 0.0.0.0 --log-file dc-api.log
deps: deps-node deps-python
style: style-node style-python
test: test-node test-python
cover: cover-node cover-python
env:
ln -fs ./env.${ENV}.json ./env.json
secrets:
ln -s ../tfvars/dc-api/* .
clean:
rm -rf .aws-sam node/node_modules node/src/node_modules python/**/__pycache__ python/.coverage python/.ruff_cache
ifndef VERBOSE
.SILENT:
endif
ENV=dev

help:
echo "make build | build the SAM project"
echo "make serve | run the SAM server locally"
echo "make clean | remove all installed dependencies and build artifacts"
echo "make deps | install all dependencies"
echo "make link | create hard links to allow for hot reloading of a built project"
echo "make secrets | symlink secrets files from ../tfvars"
echo "make style | run all style checks"
echo "make test | run all tests"
echo "make cover | run all tests with coverage"
echo "make env ENV=[env] | activate env.\$$ENV.json file (default: dev)"
echo "make deps-node | install node dependencies"
echo "make deps-python | install python dependencies"
echo "make style-node | run node code style check"
echo "make style-python | run python code style check"
echo "make test-node | run node tests"
echo "make test-python | run python tests"
echo "make cover-node | run node tests with coverage"
echo "make cover-python | run python tests with coverage"
.aws-sam/build.toml: ./template.yaml node/package-lock.json node/src/package-lock.json chat/dependencies/requirements.txt chat/src/requirements.txt
sam build --cached --parallel
deps-node:
cd node && npm ci
cover-node:
cd node && npm run test:coverage
style-node:
cd node && npm run prettier
test-node:
cd node && npm run test
deps-python:
cd chat/src && pip install -r requirements.txt
cover-python: deps-python
cd chat && export SKIP_WEAVIATE_SETUP=True && coverage run --source=src -m unittest -v && coverage report --skip-empty
cover-html-python: deps-python
cd chat && export SKIP_WEAVIATE_SETUP=True && coverage run --source=src -m unittest -v && coverage html --skip-empty
style-python: deps-python
cd chat && ruff check .
test-python: deps-python
cd chat && export SKIP_WEAVIATE_SETUP=True && PYTHONPATH=src:test && python -m unittest discover -v
python-version:
cd chat && python --version
build: .aws-sam/build.toml
link: build
cd chat/src && for src in *.py **/*.py; do for target in $$(find ../../.aws-sam/build -maxdepth 1 -type d); do if [[ -f $$target/$$src ]]; then ln -f $$src $$target/$$src; fi; done; done
cd node/src && for src in *.js *.json **/*.js **/*.json; do for target in $$(find ../../.aws-sam/build -maxdepth 1 -type d); do if [[ -f $$target/$$src ]]; then ln -f $$src $$target/$$src; fi; done; done
serve: link
sam local start-api --host 0.0.0.0 --log-file dc-api.log
deps: deps-node deps-python
style: style-node style-python
test: test-node test-python
cover: cover-node cover-python
env:
ln -fs ./env.${ENV}.json ./env.json
secrets:
ln -s ../tfvars/dc-api/* .
clean:
rm -rf .aws-sam node/node_modules node/src/node_modules python/**/__pycache__ python/.coverage python/.ruff_cache
3 changes: 1 addition & 2 deletions chat/dependencies/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
boto3~=1.34.13
langchain~=0.0.208
nbformat~=5.9.0
openai~=0.27.8
pandas~=2.0.2
pyjwt~=2.6.0
python-dotenv~=1.0.0
tiktoken~=0.4.0
Expand Down
216 changes: 216 additions & 0 deletions chat/src/event_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
import os
import json

from dataclasses import dataclass, field
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.prompts import PromptTemplate
from setup import (
weaviate_client,
weaviate_vector_store,
openai_chat_client,
)
from typing import List
from handlers.streaming_socket_callback_handler import StreamingSocketCallbackHandler
from helpers.apitoken import ApiToken
from helpers.prompts import document_template, prompt_template
from websocket import Websocket


# Default LLM/chain configuration. Each value (except MAX_K) can be
# overridden per request via the websocket payload by a superuser.
CHAIN_TYPE = "stuff"  # langchain QA-with-sources chain type
DOCUMENT_VARIABLE_NAME = "context"  # prompt variable that receives the retrieved documents
INDEX_NAME = "DCWork"  # default Weaviate index name
K_VALUE = 10  # default number of documents to retrieve
MAX_K = 100  # hard upper bound on k, applied even to superuser overrides
TEMPERATURE = 0.2  # default LLM sampling temperature
TEXT_KEY = "title"  # default Weaviate text property used as page_content
VERSION = "2023-07-01-preview"  # default Azure OpenAI API version


@dataclass
class EventConfig:
    """
    Configuration for a single chat websocket event.

    Defaults are defined at module level (INDEX_NAME, K_VALUE, etc.); most
    of them can be overridden per request via the websocket payload, but
    only when the request's API token belongs to a superuser (see
    _get_payload_value_with_superuser_check).
    """

    api_token: ApiToken = field(init=False)
    attributes: List[str] = field(init=False)
    azure_endpoint: str = field(init=False)
    azure_resource_name: str = field(init=False)
    debug_mode: bool = field(init=False)
    deployment_name: str = field(init=False)
    document_prompt: PromptTemplate = field(init=False)
    event: dict = field(default_factory=dict)
    index_name: str = field(init=False)
    is_logged_in: bool = field(init=False)
    k: int = field(init=False)
    openai_api_version: str = field(init=False)
    payload: dict = field(default_factory=dict)
    prompt_text: str = field(init=False)
    prompt: PromptTemplate = field(init=False)
    question: str = field(init=False)
    ref: str = field(init=False)
    request_context: dict = field(init=False)
    temperature: float = field(init=False)
    socket: Websocket = field(init=False, default=None)
    text_key: str = field(init=False)

    def __post_init__(self):
        """Derive every field from the event payload, applying superuser overrides."""
        self.payload = json.loads(self.event.get("body", "{}"))
        self.api_token = ApiToken(signed_token=self.payload.get("auth"))
        self.azure_resource_name = self._get_azure_resource_name()
        self.azure_endpoint = self._get_azure_endpoint()
        self.debug_mode = self._is_debug_mode_enabled()
        self.deployment_name = self._get_deployment_name()
        self.index_name = self._get_index_name()
        self.is_logged_in = self.api_token.is_logged_in()
        self.k = self._get_k()
        self.openai_api_version = self._get_openai_api_version()
        self.prompt_text = self._get_prompt_text()
        self.request_context = self.event.get("requestContext", {})
        self.question = self.payload.get("question")
        self.ref = self.payload.get("ref")
        self.temperature = self._get_temperature()
        self.text_key = self._get_text_key()
        # attributes excludes text_key, so it is derived after text_key is known
        self.attributes = self._get_attributes()
        self.document_prompt = self._get_document_prompt()
        self.prompt = PromptTemplate(template=self.prompt_text, input_variables=["question", "context"])

    def _get_payload_value_with_superuser_check(self, key, default):
        """Return payload[key] for superusers; everyone else gets the default."""
        if self.api_token.is_superuser():
            return self.payload.get(key, default)
        else:
            return default

    def _get_azure_endpoint(self):
        """Azure OpenAI endpoint URL, defaulting to the resource-name-based URL."""
        default = f"https://{self._get_azure_resource_name()}.openai.azure.com/"
        return self._get_payload_value_with_superuser_check("azure_endpoint", default)

    def _get_azure_resource_name(self):
        """Azure resource name from payload or AZURE_OPENAI_RESOURCE_NAME; required."""
        azure_resource_name = self._get_payload_value_with_superuser_check("azure_resource_name", os.environ.get("AZURE_OPENAI_RESOURCE_NAME"))
        if not azure_resource_name:
            raise EnvironmentError(
                "Either payload must contain 'azure_resource_name' or environment variable 'AZURE_OPENAI_RESOURCE_NAME' must be set"
            )
        return azure_resource_name

    def _get_deployment_name(self):
        """Azure OpenAI deployment id, overridable by superusers."""
        return self._get_payload_value_with_superuser_check("deployment_name", os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT_ID"))

    def _get_index_name(self):
        """Weaviate index to search, overridable by superusers."""
        return self._get_payload_value_with_superuser_check("index", INDEX_NAME)

    def _get_k(self):
        """Number of documents to retrieve, capped at MAX_K."""
        value = self._get_payload_value_with_superuser_check("k", K_VALUE)
        # coerce first so a string payload value (e.g. "5") can't raise a
        # TypeError when compared against the integer cap
        return min(int(value), MAX_K)

    def _get_openai_api_version(self):
        """Azure OpenAI API version, overridable by superusers."""
        return self._get_payload_value_with_superuser_check("openai_api_version", VERSION)

    def _get_prompt_text(self):
        """Raw prompt template text, overridable by superusers."""
        return self._get_payload_value_with_superuser_check("prompt", prompt_template())

    def _get_temperature(self):
        """LLM sampling temperature, overridable by superusers."""
        return self._get_payload_value_with_superuser_check("temperature", TEMPERATURE)

    def _get_text_key(self):
        """Weaviate property used as the document text, overridable by superusers."""
        return self._get_payload_value_with_superuser_check("text_key", TEXT_KEY)

    def _get_attributes(self):
        """Requested attributes minus text_key, 'source', and 'full_text'."""
        attributes = [
            item
            for item in self._get_request_attributes()
            if item not in [self._get_text_key(), "source", "full_text"]
        ]
        return attributes

    def _get_request_attributes(self):
        """Attributes from the payload, or every property of the index's schema."""
        if os.getenv("SKIP_WEAVIATE_SETUP"):
            # CI runs without a Weaviate instance; don't try to read the schema
            return []

        attributes = self._get_payload_value_with_superuser_check("attributes", [])
        if attributes:
            return attributes
        else:
            client = weaviate_client()
            schema = client.schema.get(self._get_index_name())
            # default to [] so a schema without "properties" doesn't raise TypeError
            names = [prop["name"] for prop in schema.get("properties", [])]
            return names

    def _get_document_prompt(self):
        """PromptTemplate used to render each retrieved document."""
        return PromptTemplate(
            template=document_template(self.attributes),
            input_variables=["page_content", "source"] + self.attributes,
        )

    def debug_message(self):
        """Return the effective configuration as a websocket debug payload."""
        return {
            "type": "debug",
            "message": {
                "attributes": self.attributes,
                "azure_endpoint": self.azure_endpoint,
                "deployment_name": self.deployment_name,
                "index": self.index_name,
                "k": self.k,
                "openai_api_version": self.openai_api_version,
                "prompt": self.prompt_text,
                "question": self.question,
                "ref": self.ref,
                "temperature": self.temperature,
                "text_key": self.text_key,
            },
        }

    def setup_websocket(self, socket=None):
        """Attach a Websocket, building one from the request context if not given."""
        if socket is None:
            connection_id = self.request_context.get("connectionId")
            endpoint_url = f'https://{self.request_context.get("domainName")}/{self.request_context.get("stage")}'
            self.socket = Websocket(endpoint_url=endpoint_url, connection_id=connection_id, ref=self.ref)
        else:
            self.socket = socket
        return self.socket

    def setup_llm_request(self):
        """Wire up the vector store, chat client, and QA chain for this request."""
        self._setup_vector_store()
        self._setup_chat_client()
        self._setup_chain()

    def _setup_vector_store(self):
        # "source" is always retrieved in addition to the filtered attributes
        self.weaviate = weaviate_vector_store(
            index_name=self.index_name,
            text_key=self.text_key,
            attributes=self.attributes + ["source"],
        )

    def _setup_chat_client(self):
        self.client = openai_chat_client(
            deployment_name=self.deployment_name,
            openai_api_base=self.azure_endpoint,
            openai_api_version=self.openai_api_version,
            callbacks=[StreamingSocketCallbackHandler(self.socket, self.debug_mode)],
            streaming=True,
        )

    def _setup_chain(self):
        self.chain = load_qa_with_sources_chain(
            self.client,
            chain_type=CHAIN_TYPE,
            prompt=self.prompt,
            document_prompt=self.document_prompt,
            document_variable_name=DOCUMENT_VARIABLE_NAME,
            verbose=self._to_bool(os.getenv("VERBOSE")),
        )

    def _is_debug_mode_enabled(self):
        """Debug mode requires both the payload flag and superuser status."""
        debug = self.payload.get("debug", False)
        return debug and self.api_token.is_superuser()

    def _to_bool(self, val):
        """Converts a value to boolean. If the value is a string, it considers
        "", "no", "false", "0" as False. Otherwise, it returns the boolean of the value.
        """
        if isinstance(val, str):
            return val.lower() not in ["", "no", "false", "0"]
        return bool(val)
Empty file added chat/src/handlers/__init__.py
Empty file.
Loading

0 comments on commit b1246b6

Please sign in to comment.