-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathChatBot.py
130 lines (99 loc) · 5.18 KB
/
ChatBot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import utils
import openai
from dotenv import load_dotenv
import os
from langchain.embeddings import OpenAIEmbeddings
import streamlit as st
from utilities.sidebar import sidebar
from streaming import StreamHandler
import uuid
from langchain.callbacks.base import BaseCallbackHandler
# Import required libraries for different functionalities
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain.chat_models import ChatOpenAI
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
# --- Page setup -------------------------------------------------------------
# NOTE: st.set_page_config must remain the first Streamlit call in the script.
_PAGE_CONFIG = {
    "page_title": 'Ask Docs AI',
    "page_icon": '🤖',
    "layout": 'wide',
    "initial_sidebar_state": 'expanded',
}
st.set_page_config(**_PAGE_CONFIG)
st.title("Ask Docs AI 🤖")

# Per-session transcript of (question, answer) pairs, created once per session.
if "session_chat_history" not in st.session_state:
    st.session_state.session_chat_history = []

# Pull OPENAI_API_KEY (and any other settings) from a local .env file.
load_dotenv()
embeddings = OpenAIEmbeddings()
class PrintRetrievalHandler(BaseCallbackHandler):
    """LangChain callback that streams retrieval progress into a Streamlit
    status widget: the query is shown when retrieval starts, and each
    retrieved document's content is listed when it finishes."""

    def __init__(self, container):
        # Collapsible status panel inside the supplied Streamlit container.
        self.status = container.status("**Context Retrieval**")

    def on_retriever_start(self, serialized: dict, query: str, **kwargs):
        """Echo the user's question and relabel the status panel with it."""
        self.status.write(f"**Question:** {query}")
        self.status.update(label=f"**Context Retrieval:** {query}")

    def on_retriever_end(self, documents, **kwargs):
        """Render every retrieved document, then mark the panel complete."""
        idx = 0
        for doc in documents:
            self.status.write(f"**Document: {idx}**")
            self.status.markdown(doc.page_content)
            idx += 1
        self.status.update(state="complete")
class CustomDataChatbot:
    """Streamlit chat UI that answers questions over a pre-built knowledgebase.

    Expects ``st.session_state.Knowledgebase`` to hold a vector store
    (built elsewhere in the app) that supports ``as_retriever``.
    """

    def __init__(self):
        # Backing list for chat messages; created once per session so reruns
        # do not wipe the history.
        if "chat_messages" not in st.session_state:
            st.session_state.chat_messages = []
        openai.api_key = os.getenv("OPENAI_API_KEY")

    def create_qa_chain(self):
        """Build and return a RetrievalQA chain over the session knowledgebase.

        Returns:
            A "stuff"-type RetrievalQA chain with a streaming ChatOpenAI LLM,
            a multi-query retriever (k=6 per sub-query), source documents
            included in the result, and a custom friendly-answer prompt.
        """
        prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, politely say that you don't know, don't try to make up an answer. Always end your answer by asking the user if he needs more help.
---------------------------
{context}
---------------------------
Question: {question}
Friendly Answer:"""
        PROMPT = PromptTemplate(
            template=prompt_template, input_variables=["context", "question"]
        )
        chain_type_kwargs = {"prompt": PROMPT}

        # Vector store built earlier in the app session (FAISS + OpenAI
        # embeddings, per the sidebar flow — confirm against the builder page).
        vectorstore = st.session_state.Knowledgebase

        # Local import: only needed when a chain is actually constructed.
        from langchain.retrievers.multi_query import MultiQueryRetriever

        llm = ChatOpenAI(temperature=0)
        retriever_from_llm = MultiQueryRetriever.from_llm(
            retriever=vectorstore.as_retriever(search_kwargs={"k": 6}), llm=llm
        )
        return RetrievalQA.from_chain_type(
            llm=ChatOpenAI(streaming=True),
            chain_type="stuff",
            retriever=retriever_from_llm,
            return_source_documents=True,
            chain_type_kwargs=chain_type_kwargs,
        )

    @utils.enable_chat_history
    def main(self):
        """Render one chat turn: read a question, answer it, show sources."""
        user_query = st.chat_input(placeholder="Ask me anything!")
        if not user_query:
            # Nothing typed this rerun; leave the existing transcript as-is.
            return

        utils.display_msg(user_query, 'user')
        with st.chat_message("assistant", avatar="https://e7.pngegg.com/pngimages/139/563/png-clipart-virtual-assistant-computer-icons-business-assistant-face-service-thumbnail.png"):
            retrieval_handler = PrintRetrievalHandler(st.container())
            st_callback = StreamHandler(st.empty())
            qa = self.create_qa_chain()
            result = qa({"query": user_query}, callbacks=[retrieval_handler, st_callback])

            with st.expander("See sources"):
                for doc in result['source_documents']:
                    st.success(f"Filename: {doc.metadata['source']}")
                    st.info(f"\nPage Content: {doc.page_content}")
                    st.json(doc.metadata, expanded=False)
                    # Offer the raw file only when its bytes are still held in
                    # session (presumably the flag marks a knowledgebase
                    # re-loaded without the originals — TODO confirm).
                    if not st.session_state.get("is_Knowledgebase_loaded"):
                        st.download_button(
                            "Download Original File",
                            st.session_state.files_for_download[doc.metadata['source']],
                            file_name=doc.metadata["source"],
                            mime="application/octet-stream",
                            key=uuid.uuid4(),
                            use_container_width=True,
                        )

            st.session_state.messages.append(
                {"role": "assistant", "content": result['result'], "matching_docs": result['source_documents']})
            st.session_state.session_chat_history.append(
                (user_query, result["result"]))
if __name__ == "__main__":
    # The chat UI is only usable once a knowledgebase exists in this session.
    if "Knowledgebase" not in st.session_state:
        st.warning("Please create a Knowledgebase first!")
    else:
        chatbot = CustomDataChatbot()
        chatbot.main()
        sidebar()