Big configuration refactor, adds overrides
bmquinn committed Nov 17, 2023
1 parent f59f6a2 commit b18681f
Showing 9 changed files with 286 additions and 247 deletions.
212 changes: 9 additions & 203 deletions chat/src/handlers/chat.py
@@ -1,133 +1,6 @@
import json
import os

import boto3
import tiktoken
from helpers.apitoken import ApiToken
from helpers.prompts import document_template, prompt_template
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.prompts import PromptTemplate
from openai.error import InvalidRequestError
from setup import (
weaviate_client,
weaviate_vector_store,
openai_chat_client,
)


DEFAULT_INDEX = "DCWork"
DEFAULT_KEY = "title"
DEFAULT_K = 10
MAX_K = 100


class Websocket:
def __init__(self, endpoint_url, connection_id, ref):
self.client = boto3.client("apigatewaymanagementapi", endpoint_url=endpoint_url)
self.connection_id = connection_id
self.ref = ref

def send(self, data):
data["ref"] = self.ref
data_as_bytes = bytes(json.dumps(data), "utf-8")
self.client.post_to_connection(
Data=data_as_bytes, ConnectionId=self.connection_id
)


class StreamingSocketCallbackHandler(BaseCallbackHandler):
def __init__(self, socket: Websocket, debug_mode: bool):
self.socket = socket
self.debug_mode = debug_mode

def on_llm_new_token(self, token: str, **kwargs):
if not self.debug_mode:
self.socket.send({"token": token})


class EventConfig:
def __init__(self, event):
self.payload = json.loads(event.get("body", "{}"))
self.api_token = ApiToken(signed_token=self.payload.get("auth"))
self.debug_mode = self._get_debug_mode()
self.index_name = self.payload.get(
"index", self.payload.get("index", DEFAULT_INDEX)
)
self.is_logged_in = self.api_token.is_logged_in()
self.k = min(self.payload.get("k", DEFAULT_K), MAX_K)
self.prompt_text = (
self.payload.get("prompt", prompt_template())
if self.api_token.is_superuser()
else prompt_template()
)
self.request_context = event.get("requestContext", {})
self.question = self.payload.get("question")
self.ref = self.payload.get("ref")
self.text_key = self.payload.get("text_key", DEFAULT_KEY)

self.attributes = [
item
for item in self._get_attributes()
if item not in [self.text_key, "source"]
]
self.document_prompt = PromptTemplate(
template=document_template(self.attributes),
input_variables=["page_content", "source"] + self.attributes,
)
self.prompt = PromptTemplate(
template=self.prompt_text, input_variables=["question", "context"]
)

def setup_websocket(self):
connection_id = self.request_context.get("connectionId")
endpoint_url = f'https://{self.request_context.get("domainName")}/{self.request_context.get("stage")}'
self.socket = Websocket(
connection_id=connection_id, endpoint_url=endpoint_url, ref=self.ref
)

def setup_llm_request(self):
self._setup_vector_store()
self._setup_chat_client()
self._setup_chain()

def _setup_vector_store(self):
self.weaviate = weaviate_vector_store(
index_name=self.index_name,
text_key=self.text_key,
attributes=self.attributes + ["source"],
)

def _setup_chat_client(self):
self.client = openai_chat_client(
callbacks=[StreamingSocketCallbackHandler(self.socket, self.debug_mode)],
streaming=True,
)

def _setup_chain(self):
self.chain = load_qa_with_sources_chain(
self.client,
chain_type="stuff",
prompt=self.prompt,
document_prompt=self.document_prompt,
document_variable_name="context",
verbose=to_bool(os.getenv("VERBOSE")),
)

def _get_debug_mode(self):
debug = self.payload.get("debug", False)
return debug and self.api_token.is_superuser()

def _get_attributes(self):
request_attributes = self.payload.get("attributes", None)
if request_attributes is not None:
return request_attributes

client = weaviate_client()
schema = client.schema.get(self.index_name)
names = [prop["name"] for prop in schema.get("properties")]
print(f"Retrieved attributes: {names}")
return names
import traceback
from handlers.event_config import EventConfig
from helpers.response import prepare_response


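# Lambda entry point: each invocation handles a single websocket message.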
def handler(event, _context):
@@ -136,85 +9,18 @@ def handler(event, _context):
config.setup_websocket()

if not config.is_logged_in:
config.socket.send({"statusCode": 401, "body": "Unauthorized"})
config.socket.send({"type": "error", "message": "Unauthorized"})
return {"statusCode": 401, "body": "Unauthorized"}

if config.debug_mode:
config.socket.send(config.debug_message())

config.setup_llm_request()
final_response = prepare_response(config)
config.socket.send(final_response)
return {"statusCode": 200}
    except Exception:
        error_message = traceback.format_exc()
        # Websocket.send expects a dict payload, so wrap the traceback
        config.socket.send({"type": "error", "message": error_message})
        print(event)
        raise


def get_and_send_original_question(config, docs):
doc_response = [doc.__dict__ for doc in docs]
original_question = {
"question": config.question,
"source_documents": doc_response,
}
config.socket.send(original_question)
return original_question


def token_usage(config, response, original_question):
return {
"question": count_tokens(config.question),
"answer": count_tokens(response["output_text"]),
"prompt": count_tokens(config.prompt_text),
"source_documents": count_tokens(original_question["source_documents"])
}


def prepare_debug_response(config, response, original_question):
return {
"answer": response["output_text"],
"attributes": config.attributes,
"is_superuser": config.api_token.is_superuser(),
"prompt": config.prompt_text,
"ref": config.ref,
"k": config.k,
"original_question": original_question,
"token_counts": token_usage(config, response, original_question),
}


def prepare_normal_response(config, response):
return {"answer": response["output_text"], "ref": config.ref}


def prepare_response(config):
try:
docs = config.weaviate.similarity_search(
config.question, k=config.k, additional="certainty"
)
original_question = get_and_send_original_question(config, docs)
response = config.chain({"question": config.question, "input_documents": docs})
if config.debug_mode:
prepared_response = prepare_debug_response(
config, response, original_question
)
else:
prepared_response = prepare_normal_response(config, response)
except InvalidRequestError as err:
prepared_response = {
"question": config.question,
"error": str(err),
"source_documents": [],
}
return prepared_response


def count_tokens(val):
encoding = tiktoken.encoding_for_model("gpt-4")
token_integers = encoding.encode(str(val))
num_tokens = len(token_integers)

return num_tokens


def to_bool(val):
if isinstance(val, str):
return val.lower() not in ["", "no", "false", "0"]
return bool(val)
163 changes: 163 additions & 0 deletions chat/src/handlers/event_config.py
@@ -0,0 +1,163 @@
import os
import json

from handlers.streaming_socket_callback_handler import StreamingSocketCallbackHandler
from helpers.apitoken import ApiToken
from helpers.prompts import document_template, prompt_template
from helpers.utils import to_bool
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.prompts import PromptTemplate
from setup import (
weaviate_client,
weaviate_vector_store,
openai_chat_client,
)
from websocket import Websocket

CHAIN_TYPE = "stuff"
DOCUMENT_VARIABLE_NAME = "context"
INDEX_NAME = "DCWork"
K_VALUE = 10
TEXT_KEY = "title"
MAX_K = 100
TEMPERATURE = 0.2
VERSION = "2023-07-01-preview"


class EventConfig:
"""
    The EventConfig class represents the configuration for an event. Defaults are set for the following properties, each of which can be overridden in the payload message:
- attributes: Document attributes sent to the LLM.
- auth: Authentication token for the connection.
- azure_endpoint: Full URI for the Azure OpenAI endpoint.
- debug: Debug mode status (requires a superuser token).
- deployment_name: Name of the Azure AI deployment.
- index_name: Name of the vector database index.
    - k: The number of documents retrieved in vector database searches.
- message: Type of socket communication.
- openai_api_version: Version of the Azure AI model.
    - prompt_text*: System prompt template (the string must contain both the {question} and {context} placeholders).
- question: User prompt typically sent via frontend input.
- ref: Reference for uniquely identifying the request.
- text_key: Attribute used to name each document.
* requires debug mode to be enabled
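
    Example override payload (illustrative values only; any subset of keys may be sent):
        {
            "auth": "<signed-token>",
            "question": "What works mention Chicago?",
            "ref": "abc123",
            "index": "DCWork",
            "k": 5,
            "temperature": 0.5,
            "debug": true
        }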
"""

def __init__(self, event):
self.payload = json.loads(event.get("body", "{}"))
self.api_token = ApiToken(signed_token=self.payload.get("auth"))
self.azure_endpoint = self.payload.get(
"azure_endpoint",
f"https://{os.getenv('AZURE_OPENAI_RESOURCE_NAME')}.openai.azure.com/",
)
self.debug_mode = self._is_debug_mode_enabled()
self.deployment_name = self.payload.get(
"deployment_name", os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT_ID")
)
        self.index_name = self.payload.get("index", INDEX_NAME)
self.is_logged_in = self.api_token.is_logged_in()
self.k = min(self.payload.get("k", K_VALUE), MAX_K)
self.openai_api_version = self.payload.get("openai_api_version", VERSION)
self.prompt_text = (
self.payload.get("prompt", prompt_template())
if self.api_token.is_superuser()
else prompt_template()
)
self.request_context = event.get("requestContext", {})
self.question = self.payload.get("question")
self.ref = self.payload.get("ref")
self.temperature = self.payload.get("temperature", TEMPERATURE)
self.text_key = self.payload.get("text_key", TEXT_KEY)

self.attributes = [
item
for item in self._get_request_attributes()
if item not in [self.text_key, "source"]
]
self.document_prompt = PromptTemplate(
template=document_template(self.attributes),
input_variables=["page_content", "source"] + self.attributes,
)
self.prompt = PromptTemplate(
template=self.prompt_text, input_variables=["question", "context"]
)

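    # Resolved configuration, echoed back over the socket when debug mode is enabled.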
def debug_message(self):
return {
"type": "debug",
"message": {
"azure_endpoint": self.azure_endpoint,
"deployment_name": self.deployment_name,
"index": self.index_name,
"k": self.k,
"openai_api_version": self.openai_api_version,
"prompt": self.prompt_text,
"question": self.question,
"ref": self.ref,
"temperature": self.temperature,
"text_key": self.text_key,
},
}

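    # Build the API Gateway Management endpoint URL from the connection's request context.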
def setup_websocket(self):
connection_id = self.request_context.get("connectionId")
endpoint_url = f'https://{self.request_context.get("domainName")}/{self.request_context.get("stage")}'
self.socket = Websocket(
connection_id=connection_id, endpoint_url=endpoint_url, ref=self.ref
)

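    # Wire up the vector store, chat client, and QA chain, in that order.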
def setup_llm_request(self):
self._setup_vector_store()
self._setup_chat_client()
self._setup_chain()

def _setup_vector_store(self):
self.weaviate = weaviate_vector_store(
index_name=self.index_name,
text_key=self.text_key,
attributes=self.attributes + ["source"],
)

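    # Response tokens stream back to the client through the websocket callback handler.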
def _setup_chat_client(self):
self.client = openai_chat_client(
deployment_name=self.deployment_name,
openai_api_base=self.azure_endpoint,
openai_api_version=self.openai_api_version,
callbacks=[StreamingSocketCallbackHandler(self.socket, self.debug_mode)],
streaming=True,
)

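    # The "stuff" chain type packs all retrieved documents into a single prompt context.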
def _setup_chain(self):
self.chain = load_qa_with_sources_chain(
self.client,
chain_type=CHAIN_TYPE,
prompt=self.prompt,
document_prompt=self.document_prompt,
document_variable_name=DOCUMENT_VARIABLE_NAME,
verbose=self._to_bool(os.getenv("VERBOSE")),
)

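    # Debug mode is honored only when the token belongs to a superuser.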
def _is_debug_mode_enabled(self):
debug = self.payload.get("debug", False)
return debug and self.api_token.is_superuser()

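    # Prefer attributes supplied in the payload; otherwise fall back to the index's Weaviate schema.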
def _get_request_attributes(self):
request_attributes = self.payload.get("attributes", None)
if request_attributes is not None:
return request_attributes

client = weaviate_client()
schema = client.schema.get(self.index_name)
names = [prop["name"] for prop in schema.get("properties")]
return names

def _to_bool(self, val):
"""Converts a value to boolean. If the value is a string, it considers
"", "no", "false", "0" as False. Otherwise, it returns the boolean of the value.
"""
if isinstance(val, str):
return val.lower() not in ["", "no", "false", "0"]
return bool(val)
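
# A minimal sketch of how the handler drives this class (hypothetical event values):
#
#     event = {
#         "body": json.dumps({"auth": "<signed-token>", "question": "...", "ref": "1"}),
#         "requestContext": {"connectionId": "abc", "domainName": "example.com", "stage": "dev"},
#     }
#     config = EventConfig(event)
#     config.setup_websocket()
#     config.setup_llm_request()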