Free usage cap added #24

Open · wants to merge 2 commits into main
33 changes: 33 additions & 0 deletions rag_demo/free_use_manager.py
@@ -0,0 +1,33 @@
# Manages the number of free questions permitted before the user must supply their own OpenAI key

import streamlit as st

def _setup_free_questions_count():
    if "FREE_QUESTIONS_REMAINING" not in st.session_state:
        try:
            st.session_state["FREE_QUESTIONS_REMAINING"] = st.secrets["FREE_QUESTIONS_PER_SESSION"]
        except (KeyError, FileNotFoundError):
            # Secret not configured; fall back to a default of 3 free questions
            st.session_state["FREE_QUESTIONS_REMAINING"] = 3

def free_questions_exhausted() -> bool:
    _setup_free_questions_count()
    remaining = st.session_state["FREE_QUESTIONS_REMAINING"]
    return remaining <= 0

def user_supplied_openai_key_unavailable() -> bool:
    if "USER_OPENAI_KEY" not in st.session_state:
        return True
    uok = st.session_state["USER_OPENAI_KEY"]
    if uok is None or uok == "":
        return True
    return False

def decrement_free_questions():
    _setup_free_questions_count()
    remaining = st.session_state["FREE_QUESTIONS_REMAINING"]
    if remaining > 0:
        st.session_state["FREE_QUESTIONS_REMAINING"] = remaining - 1
7 changes: 5 additions & 2 deletions rag_demo/graph_cypher_chain.py
@@ -56,8 +56,11 @@
 url = st.secrets["NEO4J_URI"]
 username = st.secrets["NEO4J_USERNAME"]
 password = st.secrets["NEO4J_PASSWORD"]
-openai_key = st.secrets["OPENAI_API_KEY"]
+llm_key = st.secrets["OPENAI_API_KEY"]

+if "USER_OPENAI_KEY" in st.session_state:
+    openai_key = st.session_state["USER_OPENAI_KEY"]
+else:
+    openai_key = st.secrets["OPENAI_API_KEY"]

 graph = Neo4jGraph(
     url=url,
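Note that sidebar.py stores the key under "USER_OPENAI_KEY", so every chain module must read that same session-state name. Since this session-key-else-secrets fallback is repeated in three modules, one possible consolidation is a single helper — a sketch, not part of this PR:

# Sketch, not part of this PR: one helper for the session-key-else-secrets
# fallback, so every chain module reads the same session-state name.
import streamlit as st

def get_openai_key() -> str:
    """Return the user-supplied OpenAI key if set, else the app's own key."""
    user_key = st.session_state.get("USER_OPENAI_KEY", "")
    return user_key if user_key else st.secrets["OPENAI_API_KEY"]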
8 changes: 8 additions & 0 deletions rag_demo/main.py
@@ -1,4 +1,5 @@
 from analytics import track
+from free_use_manager import free_questions_exhausted, user_supplied_openai_key_unavailable, decrement_free_questions
 from langchain.globals import set_llm_cache
 from langchain.cache import InMemoryCache
 from langchain_community.callbacks import HumanApprovalCallbackHandler
@@ -12,6 +13,7 @@

 # Anonymous Session Analytics
 if "SESSION_ID" not in st.session_state:
+    # The track method creates and adds a session id to state on its first run
     track(
         "rag_demo",
         "appStarted",
@@ -44,6 +46,10 @@
         st.markdown(message["content"], unsafe_allow_html=True)

 # User input - switch between the sidebar sample quick select and actual user input. Clunky, but works.
+if free_questions_exhausted() and user_supplied_openai_key_unavailable():
+    st.warning("Thank you for trying out the Neo4j RAG Demo. Please enter your OpenAI key in the sidebar to continue asking questions.")
+    st.stop()
+
 if "sample" in st.session_state and st.session_state["sample"] is not None:
     user_input = st.session_state["sample"]
 else:
@@ -86,6 +92,8 @@
new_message = {"role": "ai", "content": content}
st.session_state.messages.append(new_message)

decrement_free_questions()

message_placeholder.markdown(content)

# Reinsert user chat input if sample quick select was previously used.
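One refinement worth considering (a sketch, not part of this PR): as written, decrement_free_questions() runs even when the user has supplied their own key, so self-funded questions still burn down the free cap. Guarding the decrement reserves the cap for demo-funded calls only:

# Sketch, not in this PR: only consume a free question when the demo's own
# OpenAI key paid for the answer.
if user_supplied_openai_key_unavailable():
    decrement_free_questions()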
8 changes: 8 additions & 0 deletions rag_demo/sidebar.py
@@ -13,6 +13,14 @@ def ChangeButtonColour(wgt_txt, wch_hex_colour = '12px'):

 def sidebar():
     with st.sidebar:
+
+        with st.expander("OpenAI Key"):
+            new_oak = st.text_input("Your OpenAI API Key")
+            st.session_state["USER_OPENAI_KEY"] = new_oak
+
         st.markdown(f"""This is the schema in which the EDGAR filings are stored in Neo4j: \n <img style="width: 70%; height: auto;" src="{SCHEMA_IMG_PATH}"/>""", unsafe_allow_html=True)

         st.markdown(f"""This is how the Chatbot flow goes: \n <img style="width: 70%; height: auto;" src="{LANGCHAIN_IMG_PATH}"/>""", unsafe_allow_html=True)
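A possible hardening of the sidebar input (a sketch, not part of this PR): masking the field keeps the key out of plain sight, and skipping empty values avoids clobbering a previously entered key with an empty string on rerun:

# Sketch, not in this PR: password-mask the input and only store non-empty values.
with st.expander("OpenAI Key"):
    new_oak = st.text_input("Your OpenAI API Key", type="password")
    if new_oak:
        st.session_state["USER_OPENAI_KEY"] = new_oak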
11 changes: 8 additions & 3 deletions rag_demo/vector_chain.py
@@ -32,7 +32,12 @@
input_variables=["input","context"], template=VECTOR_PROMPT_TEMPLATE
)

EMBEDDING_MODEL = OpenAIEmbeddings()
if "USER_OPENAI_API_KEY" in st.session_state:
openai_key = st.session_state["USER_OPENAI_API_KEY"]
else:
openai_key = st.secrets["OPENAI_API_KEY"]

EMBEDDING_MODEL = OpenAIEmbeddings(openai_api_key=openai_key)
MEMORY = ConversationBufferMemory(memory_key="chat_history", input_key='question', output_key='answer', return_messages=True)

index_name = "form_10k_chunks"
@@ -78,7 +83,7 @@
 vector_retriever = vector_store.as_retriever()

 vector_chain = RetrievalQAWithSourcesChain.from_chain_type(
-    ChatOpenAI(temperature=0),
+    ChatOpenAI(temperature=0, openai_api_key=openai_key),
     chain_type="stuff",
     retriever=vector_retriever,
     memory=MEMORY,
@@ -119,7 +124,7 @@ def get_results(question)-> str:

     return result

-# Using the vector store directly. But this will blow out the token count
+# Using the vector store directly. But this could blow out the token count
 # @retry(tries=5, delay=5)
 # def get_results(question)-> str:
 #     """Generate response using Neo4jVector using vector index only
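One caveat with the module-level key selection above: Python imports run once per process, so the key chosen when vector_chain.py is first imported sticks even if the user enters a key in the sidebar afterwards. Building the client lazily would avoid that — a sketch, not part of this PR, reusing the hypothetical get_openai_key helper from earlier:

# Sketch, not in this PR: construct the embedding model per call so a key
# entered after first load still takes effect (get_openai_key is the
# hypothetical helper sketched above).
def get_embedding_model() -> OpenAIEmbeddings:
    return OpenAIEmbeddings(openai_api_key=get_openai_key())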
10 changes: 7 additions & 3 deletions rag_demo/vector_graph_chain.py
@@ -1,4 +1,3 @@
-from json import loads, dumps
 from langchain.prompts.prompt import PromptTemplate
 from langchain.vectorstores.neo4j_vector import Neo4jVector
 from langchain.chains import RetrievalQAWithSourcesChain
@@ -22,7 +21,12 @@
input_variables=["question"], template=VECTOR_GRAPH_PROMPT_TEMPLATE
)

EMBEDDING_MODEL = OpenAIEmbeddings()
if "USER_OPENAI_API_KEY" in st.session_state:
openai_key = st.session_state["USER_OPENAI_API_KEY"]
else:
openai_key = st.secrets["OPENAI_API_KEY"]

EMBEDDING_MODEL = OpenAIEmbeddings(openai_api_key=openai_key)
MEMORY = ConversationBufferMemory(memory_key="chat_history", input_key='question', output_key='answer', return_messages=True)

index_name = "form_10k_chunks"
@@ -93,7 +97,7 @@
 vector_graph_retriever = vector_store.as_retriever()

 vector_graph_chain = RetrievalQAWithSourcesChain.from_chain_type(
-    ChatOpenAI(temperature=0),
+    ChatOpenAI(temperature=0, openai_api_key=openai_key),
     chain_type="stuff",
     retriever=vector_graph_retriever,
     memory=MEMORY,