Skip to content

Commit

Permalink
Fix up attribute filtering
Browse files Browse the repository at this point in the history
  • Loading branch information
bmquinn committed Oct 17, 2023
1 parent 72de7a0 commit 92116e3
Showing 1 changed file with 112 additions and 100 deletions.
212 changes: 112 additions & 100 deletions chat/src/handlers/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,114 +14,126 @@
DEFAULT_K = 10
MAX_K = 100


class Websocket:
    """Sends JSON messages to a single API Gateway websocket connection.

    Every outgoing message is tagged with the ``ref`` supplied at construction
    so the receiver can correlate it with the request that produced it.
    """

    def __init__(self, endpoint_url, connection_id, ref):
        """Create the management-API client for one connection.

        Args:
            endpoint_url: URL passed to boto3 as the client's ``endpoint_url``.
            connection_id: Identifier of the connection to post messages to.
            ref: Value attached to every outgoing message under the "ref" key.
        """
        self.client = boto3.client("apigatewaymanagementapi", endpoint_url=endpoint_url)
        self.connection_id = connection_id
        self.ref = ref

    def send(self, data):
        """Serialize ``data`` (plus this socket's ``ref``) and post it.

        Args:
            data: JSON-serializable dict to send. It is not modified; the
                ``ref`` key is added to a copy.
        """
        # Build a new dict rather than writing "ref" into the caller's object.
        message = {**data, "ref": self.ref}
        data_as_bytes = json.dumps(message).encode("utf-8")
        self.client.post_to_connection(
            Data=data_as_bytes, ConnectionId=self.connection_id
        )

class StreamingSocketCallbackHandler(BaseCallbackHandler):
    """Callback handler that forwards each newly generated token to a websocket."""

    def __init__(self, socket: Websocket):
        """Remember the websocket that tokens will be sent to.

        Args:
            socket: The ``Websocket`` used to deliver tokens.
        """
        self.socket = socket

    def on_llm_new_token(self, token: str, **kwargs):
        """Send one token to the socket as ``{"token": token}``.

        Args:
            token: The token text to forward.
            **kwargs: Accepted for interface compatibility and ignored.
        """
        self.socket.send({"token": token})


def handler(event, context):
    """Answer a question over a websocket using retrieved source documents.

    Reads a JSON payload from ``event["body"]``, checks the signed token in its
    "auth" field, runs a similarity search against the requested index, sends
    the matching documents to the websocket, then runs the question-answering
    chain and sends the answer. Tokens are streamed to the same websocket by
    ``StreamingSocketCallbackHandler`` while the chain runs.

    NOTE(review): the ``(event, context)`` signature and ``requestContext``
    fields suggest an AWS Lambda handler behind an API Gateway websocket
    route -- confirm against the deployment configuration.

    Args:
        event: Dict holding "body" (a JSON string) and "requestContext"
            (with "connectionId", "domainName" and "stage").
        context: Unused.

    Returns:
        ``{"statusCode": 401, "body": "Unauthorized"}`` when the token is not
        logged in, otherwise ``{"statusCode": 200}``.

    Raises:
        Exception: Any unexpected error is re-raised after the event is
            printed. ``InvalidRequestError`` raised while answering is not
            re-raised; its text is sent to the client as the answer instead.
    """
    try:
        # `or "{}"` also covers a "body" key that is present but None or empty.
        payload = json.loads(event.get("body") or "{}")

        request_context = event.get("requestContext", {})
        connection_id = request_context.get("connectionId")
        endpoint_url = f'https://{request_context.get("domainName")}/{request_context.get("stage")}'
        ref = payload.get("ref")
        socket = Websocket(
            connection_id=connection_id, endpoint_url=endpoint_url, ref=ref
        )

        api_token = ApiToken(signed_token=payload.get("auth"))
        if not api_token.is_logged_in():
            socket.send({"statusCode": 401, "body": "Unauthorized"})
            return {"statusCode": 401, "body": "Unauthorized"}

        # Only superusers may override the attribute list and the prompt.
        is_superuser = api_token.is_superuser()

        question = payload.get("question")
        index_name = payload.get("index", DEFAULT_INDEX)
        print(f"Searching index {index_name}")
        text_key = payload.get("text_key", DEFAULT_KEY)

        # The text key and "source" are handled separately from the other
        # attributes, so drop them here to avoid listing them twice below.
        attributes = [
            item
            for item in get_attributes(index_name, payload if is_superuser else {})
            if item not in [text_key, "source"]
        ]

        weaviate = setup.weaviate_vector_store(
            index_name=index_name, text_key=text_key, attributes=attributes + ["source"]
        )

        client = setup.openai_chat_client(
            callbacks=[StreamingSocketCallbackHandler(socket)], streaming=True
        )

        prompt_text = (
            payload.get("prompt", prompt_template())
            if is_superuser
            else prompt_template()
        )
        prompt = PromptTemplate(
            template=prompt_text, input_variables=["question", "context"]
        )

        document_prompt = PromptTemplate(
            template=document_template(attributes),
            input_variables=["page_content", "source"] + attributes,
        )

        # A non-numeric "k" falls back to the default instead of raising
        # TypeError inside min(); the result is kept within 1..MAX_K.
        try:
            k = int(payload.get("k", DEFAULT_K))
        except (TypeError, ValueError):
            k = DEFAULT_K
        k = max(1, min(k, MAX_K))

        docs = weaviate.similarity_search(question, k=k, additional="certainty")
        chain = load_qa_with_sources_chain(
            client,
            chain_type="stuff",
            prompt=prompt,
            document_prompt=document_prompt,
            document_variable_name="context",
            verbose=to_bool(os.getenv("VERBOSE")),
        )

        try:
            # Send the retrieved documents first, then the generated answer.
            doc_response = [doc.__dict__ for doc in docs]
            socket.send({"question": question, "source_documents": doc_response})
            response = chain({"question": question, "input_documents": docs})
            response = {
                "answer": response["output_text"],
            }
            socket.send(response)
        except InvalidRequestError as err:
            # Report the request error to the client as the answer text.
            response = {
                "question": question,
                "answer": str(err),
                "source_documents": [],
            }
            socket.send(response)

        return {"statusCode": 200}
    except Exception:
        # NOTE(review): the event body carries the signed "auth" token, so this
        # print writes it to the logs -- consider redacting before logging.
        print(event)
        raise


def get_attributes(index, payload):
    """Return the attribute names to use for ``index``.

    If ``payload`` has a non-None "attributes" entry it wins: a string is
    treated as a comma-separated list (names are trimmed, empty names are
    dropped) and any other iterable is copied into a list. Otherwise the
    names are read from the index's schema via the weaviate client.

    Args:
        index: Name of the index whose schema is consulted as the fallback.
        payload: Dict that may carry an "attributes" override.

    Returns:
        A list of attribute names.
    """
    request_attributes = payload.get("attributes", None)
    if request_attributes is not None:
        if isinstance(request_attributes, str):
            # str.split is called on the value, with "," as the separator.
            # Returning the raw string would make callers iterate characters.
            return [
                name.strip()
                for name in request_attributes.split(",")
                if name.strip()
            ]
        return list(request_attributes)

    client = setup.weaviate_client()
    schema = client.schema.get(index)
    # Tolerate a schema without a "properties" entry.
    names = [prop["name"] for prop in (schema.get("properties") or [])]
    print(f"Retrieved attributes: {names}")
    return names


def to_bool(val):
    """Interpret ``val`` as a boolean flag.

    Strings are false when, after trimming whitespace and lowercasing, they
    are one of "", "no", "false" or "0"; every other string is true.
    Non-string values use Python's normal truthiness.

    Args:
        val: A string or any other value (including None).

    Returns:
        True or False.
    """
    if isinstance(val, str):
        # Strip first so padded values such as " false " are not read as true.
        return val.strip().lower() not in ("", "no", "false", "0")
    return bool(val)

0 comments on commit 92116e3

Please sign in to comment.