support chat streaming

timho102003 committed Oct 26, 2023
1 parent f0c0b09 commit 5c10b10

Showing 4 changed files with 93 additions and 30 deletions.
3 changes: 2 additions & 1 deletion Home.py
@@ -11,7 +11,7 @@
from template.feed_page import feed_template
from template.login_page import login_template, signup_template
from template.summary_page import summary_template
-from utils import load_activities, second_to_text, signout
+from utils import load_activities, second_to_text, signout, load_local_embedding_model

st.set_page_config(page_title="", page_icon="👋", layout="wide")

@@ -71,6 +71,7 @@
    key_dict = json.loads(st.secrets["textkey"])
    creds = service_account.Credentials.from_service_account_info(key_dict)
    st.session_state["firestore_db"] = firestore.Client(credentials=creds)
+    load_local_embedding_model()
    # st.session_state["firestore_db"] = firestore.Client.from_service_account_json("assets/newsgpt_firebase_serviceAccount.json")

if st.session_state.get("error", None):
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,5 +1,5 @@
firebase_admin==6.2.0
-llama_index==0.8.16
+llama_index==0.8.51
pandas==2.0.3
protobuf==4.24.3
readtime==3.0.0
36 changes: 24 additions & 12 deletions template/chat_page.py
@@ -1,7 +1,9 @@
import time
import readtime
import streamlit as st
from streamlit_extras.row import row
from streamlit_extras.switch_page_button import switch_page
+from streamlit_extras.streaming_write import write

from utils import (
    generate_anno_text,
@@ -191,23 +193,33 @@ def chat_template():
        with st.chat_message(
            "assistant",
        ):
            # start = time.time()
            with st.spinner("Thinking..."):
                # response = st.session_state["chat_engine"].chat(prompt)
                try:
                    response = st.session_state["chat_engine"].query(prompt)
-                    response = response.response
-                except:
-                    response = "Unable to retrieve the result due to some unexpected reason."
-                st.write(response)
-                message = {"role": "assistant", "content": response}
-                st.session_state.reading_time += readtime.of_text(
-                    response
-                ).seconds
-                st.session_state.messages.append(
-                    message
+                except Exception as e:
+                    response = f"Unable to retrieve the result due to some unexpected reason, {e}"
+                    print(response)
+                # st.write(response)
+                if not isinstance(response, str):
+                    with st.spinner("Answering..."):
+                        buffer = ""
+                        def stream_example():
+                            nonlocal buffer
+                            for word in response.response_gen:
+                                buffer += word
+                                yield word
+                        write(stream_example)
+                        response = buffer
+                message = {"role": "assistant", "content": response}
+                st.session_state.reading_time += readtime.of_text(
+                    response
+                ).seconds
+                st.session_state.messages.append(
+                    message
                ) # Add response to message history

        print(st.session_state.reading_time)
        # print(f"think time: {time.time() - start}")
    predefine_prompt_row = row(
        [0.05, 0.15, 0.15, 0.3, 0.15, 0.15, 0.05],
        vertical_align="center",
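Note: the new branch streams the answer chunk-by-chunk instead of writing one finished string. A minimal sketch of the same pattern, assuming a llama_index query engine built with streaming=True (so query() returns a StreamingResponse exposing response_gen); the helper name render_streaming_answer is hypothetical:

    import streamlit as st
    from streamlit_extras.streaming_write import write

    def render_streaming_answer(streaming_response) -> str:
        # Stream chunks into the chat UI while buffering the full text,
        # so the final string can still feed message history and readtime.
        buffer = ""

        def token_stream():
            nonlocal buffer
            for chunk in streaming_response.response_gen:  # yields text pieces as the LLM emits them
                buffer += chunk
                yield chunk

        write(token_stream)  # pass the generator function; chunks render as they arrive
        return buffer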
82 changes: 66 additions & 16 deletions utils.py
@@ -18,9 +18,39 @@
from streamlit_extras.customize_running import center_running
from streamlit_extras.row import row
from streamlit_extras.switch_page_button import switch_page

+from llama_index.text_splitter import TokenTextSplitter
+from llama_index.node_parser import SimpleNodeParser
from config import NEWS_CATEGORIES

+@st.cache_resource
+def load_local_embedding_model():
+    # print("start load_local_embedding_model")
+    # from llama_index.embeddings import TextEmbeddingsInference
+
+    # st.session_state["local_embed_model"] = TextEmbeddingsInference(
+    #     model_name="BAAI/bge-large-en-v1.5",
+    #     base_url = "http://127.0.0.1:5007",
+    #     timeout=60, # timeout in seconds
+    #     embed_batch_size=10, # batch size for embedding
+    # )
+    # print(st.session_state["local_embed_model"])
+    if "service_context" not in st.session_state:
+        st.session_state["service_context"] = ServiceContext.from_defaults(
+            llm=OpenAI(
+                model="gpt-3.5-turbo",
+                temperature=0.2,
+                chunk_size=1024,
+                chunk_overlap=100,
+                system_prompt="As an expert current affairs commentator and analyst,\
+                    your task is to summarize the articles and answer the questions from the user related to the news articles",
+            ),
+            # callback_manager=callback_manager
+            # embed_model=st.session_state["local_embed_model"],
+            chunk_size=256,
+            chunk_overlap=20
+        )
+    # print("finish load_local_embedding_model")
+

def hash_text(text: str):
    hash_object = hashlib.sha256(text.encode())
    return hash_object.hexdigest()
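Note: because load_local_embedding_model is wrapped in @st.cache_resource, the ServiceContext is built once per process and shared across reruns and sessions. A condensed sketch of the equivalent setup, assuming llama_index 0.8.x; chunk_size/chunk_overlap are conventionally ServiceContext arguments, which is also where this helper sets them, and build_service_context is a hypothetical name:

    import streamlit as st
    from llama_index import ServiceContext
    from llama_index.llms import OpenAI

    @st.cache_resource
    def build_service_context() -> ServiceContext:
        # Constructed on the first call, then returned from
        # Streamlit's resource cache on every later rerun.
        return ServiceContext.from_defaults(
            llm=OpenAI(model="gpt-3.5-turbo", temperature=0.2),
            chunk_size=256,    # node size used when indexing
            chunk_overlap=20,  # token overlap between adjacent nodes
        )

    st.session_state["service_context"] = build_service_context()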
@@ -661,6 +691,11 @@ def summary_layout_template(

#TODO: Enhance the performance of retriever. Current retriever sometimes can't get the answer correctly (Change to other retriever)
def run_chat(payload, query_embed, ori_article_id, compare_num=5):
+
+    if "service_context" not in st.session_state:# "local_embed_model" not in st.session_state:
+        load_local_embedding_model.clear()
+        load_local_embedding_model()
+
    st.session_state.messages = [
        {"role": "assistant", "content": f"Ask me a question about {payload['title']}"}
    ]
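Note: the guard above leans on st.cache_resource semantics: .clear() evicts the function's cached entries, so the call that follows re-executes the body instead of returning a stale object. A minimal illustration with a hypothetical expensive_setup:

    import streamlit as st

    @st.cache_resource
    def expensive_setup():
        print("building resource")  # runs only on a cache miss
        return {"ready": True}

    expensive_setup()        # first call: body runs, result is cached
    expensive_setup()        # cache hit: body skipped
    expensive_setup.clear()  # evict this function's cached entries
    expensive_setup()        # body runs again, fresh object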
@@ -719,21 +754,36 @@ def run_chat(payload, query_embed, ori_article_id, compare_num=5):
    # from llama_index.callbacks import CallbackManager, LlamaDebugHandler
    # llama_debug = LlamaDebugHandler(print_trace_on_end=True)
    # callback_manager = CallbackManager([llama_debug])
-    if "service_context" not in st.session_state:
-        st.session_state["service_context"] = ServiceContext.from_defaults(
-            llm=OpenAI(
-                model="gpt-3.5-turbo",
-                temperature=0.2,
-                chunk_size=1024,
-                chunk_overlap=100,
-                system_prompt="As an expert current affairs commentator and analyst,\
-                    your task is to summarize the articles and answer the questions from the user related to the news articles",
-            ),
-            # callback_manager=callback_manager
-        )
-    st.session_state["chat_engine"] = VectorStoreIndex.from_documents(
-        documents, use_async=True, service_context=st.session_state.service_context
-    ).as_query_engine()
+    text_splitter = TokenTextSplitter(separator=" ", chunk_size=256, chunk_overlap=20)
+    #create node parser to parse nodes from document
+    node_parser = SimpleNodeParser(text_splitter=text_splitter)
+
+    # if "service_context" not in st.session_state:
+    #     st.session_state["service_context"] = ServiceContext.from_defaults(
+    #         llm=OpenAI(
+    #             model="gpt-3.5-turbo",
+    #             temperature=0.2,
+    #             chunk_size=1024,
+    #             chunk_overlap=100,
+    #             system_prompt="As an expert current affairs commentator and analyst,\
+    #                 your task is to summarize the articles and answer the questions from the user related to the news articles",
+    #         ),
+    #         # callback_manager=callback_manager
+    #         embed_model=st.session_state["local_embed_model"],
+    #         chunk_size=256,
+    #         chunk_overlap=20
+    #     )
+
+    nodes = node_parser.get_nodes_from_documents(documents)
+    print(f"loaded nodes with {len(nodes)} nodes")
+    index = VectorStoreIndex(
+        nodes=nodes,
+        service_context=st.session_state["service_context"]
+    )
+    st.session_state["chat_engine"] = index.as_query_engine(streaming=True)
+    # st.session_state["chat_engine"] = VectorStoreIndex.from_documents(
+    #     documents, use_async=True, service_context=st.session_state.service_context
+    # ).as_query_engine()

    # print("Prepare summary index: {}".format(time.time()-start))
