Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Possibilities to Enhance Vector Store Retrieval to Minimize Token Usage #29

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 36 additions & 20 deletions rag_demo/vector_chain.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from json import loads, dumps
from langchain.prompts.prompt import PromptTemplate

from langchain_community.vectorstores import Neo4jVector
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.chains.conversation.memory import ConversationBufferMemory
Expand All @@ -11,7 +10,7 @@

VECTOR_PROMPT_TEMPLATE = """Human: You are a Financial expert with SEC filings who can answer questions only based on the context below.
* Answer the question STRICTLY based on the context provided in JSON below.
* Do not assume or retrieve any information outside of the context
* Do not assume or retrieve any information outside of the context
* Use three sentences maximum and keep the answer concise
* Think step by step before answering.
* Do not return helpful or extra text or apologies
Expand Down Expand Up @@ -52,10 +51,10 @@
username = st.secrets["NEO4J_USERNAME"]
password = st.secrets["NEO4J_PASSWORD"]


vector_store = None
try:
logging.debug(f"Attempting to retrieve existing vector index: {index_name}...")
logging.debug(
f"Attempting to retrieve existing vector index: {index_name}...")
vector_store = Neo4jVector.from_existing_index(
embedding=EMBEDDING_MODEL,
url=url,
Expand All @@ -82,7 +81,8 @@
)
logging.debug(f"Created new index: {index_name}")
except Exception as e:
logging.error(f"Failed to retrieve existing or to create a Neo4jVector: {e}")
logging.error(
f"Failed to retrieve existing or to create a Neo4jVector: {e}")

if vector_store is None:
logging.error(f"Failed to retrieve or create a Neo4jVector. Exiting.")
Expand Down Expand Up @@ -133,26 +133,42 @@ def get_results(question) -> str:

return result


# Using the vector store directly. But this could blow out the token count
# @retry(tries=5, delay=5)
# def get_results(question)-> str:
# """Generate response using Neo4jVector using vector index only

# Args:
# question (str): User query

# Returns:
# str: Formatted string answer with citations, if available.
# """
@retry(tries=2, delay=5)
def get_results_minimized_tokens(question) -> str:
"""Generate response using Neo4jVector with minimized token usage

# logging.info(f'Using Neo4j url: {url}')
Args:
question (str): User query

# # Returns a dict with keys: answer, sources
# vector_result = vector_store.similarity_search(question, k=3)
Returns:
str: Formatted string answer with citations, if available.
"""
logging.info(f"Using Neo4j url: {url}")

# logging.debug(f'chain_result: {vector_result}')
vector_result = vector_store.similarity_search(question, k=3)
context = {
"input": question,
"context": [doc.page_content for doc in vector_result]
}

# result = vector_result
chain_result = vector_chain.invoke(
{"question": question, "context": dumps(context)},
prompt=VECTOR_PROMPT,
return_only_outputs=True,
)

logging.debug(f"chain_result: {chain_result}")

# return result
result = chain_result["answer"]

# Cite sources, if any
sources = chain_result["sources"]
sources_split = sources.split(", ")
for source in sources_split:
if source != "" and source != "N/A" and source != "None":
result += f"\n - [{source}]({source})"

return result