From 9e84b576028b443f9b2bbd91d735669d294aecdb Mon Sep 17 00:00:00 2001
From: Henry
Date: Fri, 30 Jun 2023 20:23:15 -0400
Subject: [PATCH] add metadata in similarity search

---
Reviewer notes (this text sits between the `---` separator and the diff, so it
is not part of the commit message):

* get_text_chunks: `metadata_input` is now extended with `+=` for every page.
  It used to be reassigned on each iteration, so the returned list only held
  the last page's metadata and was not index-aligned with `text_chunks`, which
  is accumulated with `+=`. Both lists are passed together to
  FAISS.from_texts(texts=..., metadatas=...).
* The FAISS.from_texts(...) addition is a single `+` line; this is what the
  hunk header (+10,42) and the diffstat (34 insertions) account for.
* Line breaks and leading indentation were restored from a whitespace-collapsed
  copy of this patch. Verify the context lines against the target app.py
  before applying.
* The post-image blob id on the `index` line predates the `+=` fix.

 app.py | 55 ++++++++++++++++++++++++++++++++++---------------------
 1 file changed, 34 insertions(+), 21 deletions(-)

diff --git a/app.py b/app.py
index cde7c608..e60c8c0a 100644
--- a/app.py
+++ b/app.py
@@ -10,30 +10,42 @@ from htmlTemplates import css, bot_template, user_template
 from langchain.llms import HuggingFaceHub
 
 
-def get_pdf_text(pdf_docs):
-    text = ""
-    for pdf in pdf_docs:
-        pdf_reader = PdfReader(pdf)
-        for page in pdf_reader.pages:
-            text += page.extract_text()
-    return text
+def get_pdf_text (pdf_docs) :
+    text_by_loaders_page = []
+    metadatas = []
+    for reader in pdf_docs:
+        pdf_reader = PdfReader(reader)
+        for j, page in enumerate (pdf_reader.pages) :
+            text = page.extract_text()
+            text_by_loaders_page. append (text)
+            metadatas. append ({'source': reader.name, 'page' : j+1})
+    return text_by_loaders_page, metadatas
+
+
+def get_text_chunks (text_by_loaders_page, metadatas) :
 
 
-def get_text_chunks(text):
     text_splitter = CharacterTextSplitter(
-        separator="\n",
-        chunk_size=1000,
-        chunk_overlap=200,
-        length_function=len
-    )
-    chunks = text_splitter.split_text(text)
-    return chunks
+    separator="\n",
+    chunk_size=1000,
+    chunk_overlap=200,
+    length_function=len)
+
+    text_chunks = []
+    metadata_input = []
+
+    for i, text in enumerate(text_by_loaders_page) :
+
+        texts_temp = text_splitter.split_text(text)
+        metadata_input += [metadatas[i]]*len(texts_temp)
+        text_chunks += texts_temp
+    return text_chunks, metadata_input
 
 
-def get_vectorstore(text_chunks):
+def get_vectorstore(text_chunks,metadata_input):
     embeddings = OpenAIEmbeddings()
     # embeddings = HuggingFaceInstructEmbeddings(model_name="hkunlp/instructor-xl")
-    vectorstore = FAISS.from_texts(texts=text_chunks, embedding=embeddings)
+    vectorstore = FAISS.from_texts(texts=text_chunks, embedding=embeddings,metadatas=metadata_input)
     return vectorstore
 
 
@@ -79,6 +91,7 @@ def main():
     user_question = st.text_input("Ask a question about your documents:")
     if user_question:
         handle_userinput(user_question)
+        st.write(st.session_state.vectorstore.similarity_search(user_question))
 
     with st.sidebar:
         st.subheader("Your documents")
@@ -87,17 +100,17 @@ def main():
         if st.button("Process"):
             with st.spinner("Processing"):
                 # get pdf text
-                raw_text = get_pdf_text(pdf_docs)
+                raw_text, metadatas = get_pdf_text(pdf_docs)
 
                 # get the text chunks
-                text_chunks = get_text_chunks(raw_text)
+                text_chunks, metadata_input = get_text_chunks(raw_text, metadatas)
 
                 # create vector store
-                vectorstore = get_vectorstore(text_chunks)
+                st.session_state.vectorstore = get_vectorstore(text_chunks,metadata_input)
 
                 # create conversation chain
                 st.session_state.conversation = get_conversation_chain(
-                    vectorstore)
+                    st.session_state.vectorstore)
 
 
 if __name__ == '__main__':