diff --git a/Home.py b/Home.py
index 0582b77..0357d74 100644
--- a/Home.py
+++ b/Home.py
@@ -11,7 +11,7 @@
 from template.feed_page import feed_template
 from template.login_page import login_template, signup_template
 from template.summary_page import summary_template
-from utils import load_activities, second_to_text, signout
+from utils import load_activities, second_to_text, signout, load_local_embedding_model
 
 st.set_page_config(page_title="", page_icon="👋", layout="wide")
 
@@ -71,6 +71,7 @@
 key_dict = json.loads(st.secrets["textkey"])
 creds = service_account.Credentials.from_service_account_info(key_dict)
 st.session_state["firestore_db"] = firestore.Client(credentials=creds)
+load_local_embedding_model()
 # st.session_state["firestore_db"] = firestore.Client.from_service_account_json("assets/newsgpt_firebase_serviceAccount.json")
 
 if st.session_state.get("error", None):
diff --git a/requirements.txt b/requirements.txt
index a49b5d7..fe62fa5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
 firebase_admin==6.2.0
-llama_index==0.8.16
+llama_index==0.8.51
 pandas==2.0.3
 protobuf==4.24.3
 readtime==3.0.0
diff --git a/template/chat_page.py b/template/chat_page.py
index a1d47c7..4215a8e 100644
--- a/template/chat_page.py
+++ b/template/chat_page.py
@@ -1,7 +1,9 @@
+import time
 import readtime
 import streamlit as st
 from streamlit_extras.row import row
 from streamlit_extras.switch_page_button import switch_page
+from streamlit_extras.streaming_write import write
 
 from utils import (
     generate_anno_text,
@@ -191,23 +193,33 @@ def chat_template():
         with st.chat_message(
             "assistant",
         ):
+            # start = time.time()
             with st.spinner("Thinking..."):
                 # response = st.session_state["chat_engine"].chat(prompt)
                 try:
                     response = st.session_state["chat_engine"].query(prompt)
-                    response = response.response
-                except:
-                    response = "Unable to retrieve the result due to some unexpected reason."
-                st.write(response)
-                message = {"role": "assistant", "content": response}
-                st.session_state.reading_time += readtime.of_text(
-                    response
-                ).seconds
-                st.session_state.messages.append(
-                    message
+                except Exception as e:
+                    response = f"Unable to retrieve the result due to some unexpected reason, {e}"
+                    print(response)
+            # st.write(response)
+            if not isinstance(response, str):
+                with st.spinner("Answering..."):
+                    buffer = ""
+                    def stream_example():
+                        nonlocal buffer
+                        for word in response.response_gen:
+                            buffer += word
+                            yield word
+                    write(stream_example)
+                    response = buffer
+            message = {"role": "assistant", "content": response}
+            st.session_state.reading_time += readtime.of_text(
+                response
+            ).seconds
+            st.session_state.messages.append(
+                message
                 )  # Add response to message history
-
-            print(st.session_state.reading_time)
+            # print(f"think time: {time.time() - start}")
     predefine_prompt_row = row(
         [0.05, 0.15, 0.15, 0.3, 0.15, 0.15, 0.05],
         vertical_align="center",
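
Note on the streaming change above: because the query engine is now created with streaming=True (see the utils.py diff below), query() returns a llama_index StreamingResponse rather than a plain string, and its response_gen attribute yields text chunks as they are generated. The nested generator forwards each chunk to streamlit_extras' streaming_write for incremental rendering while also accumulating it in buffer, so the complete answer can still be stored in the message history and counted toward reading_time. A minimal sketch of the same pattern in isolation (the helper name render_streaming is illustrative, not part of this PR):

    from streamlit_extras.streaming_write import write

    def render_streaming(response):
        # Accumulate streamed chunks while re-yielding them, so the UI
        # renders incrementally and the full text is available afterwards.
        buffer = ""

        def token_stream():
            nonlocal buffer
            for chunk in response.response_gen:  # llama_index StreamingResponse
                buffer += chunk
                yield chunk

        write(token_stream)  # render each yielded chunk as it arrives
        return buffer
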
diff --git a/utils.py b/utils.py
index 0608b93..fc515b0 100644
--- a/utils.py
+++ b/utils.py
@@ -18,9 +18,39 @@
 from streamlit_extras.customize_running import center_running
 from streamlit_extras.row import row
 from streamlit_extras.switch_page_button import switch_page
-
+from llama_index.text_splitter import TokenTextSplitter
+from llama_index.node_parser import SimpleNodeParser
 from config import NEWS_CATEGORIES
 
+@st.cache_resource
+def load_local_embedding_model():
+    # print("start load_local_embedding_model")
+    # from llama_index.embeddings import TextEmbeddingsInference
+
+    # st.session_state["local_embed_model"] = TextEmbeddingsInference(
+    #     model_name="BAAI/bge-large-en-v1.5",
+    #     base_url="http://127.0.0.1:5007",
+    #     timeout=60,  # timeout in seconds
+    #     embed_batch_size=10,  # batch size for embedding
+    # )
+    # print(st.session_state["local_embed_model"])
+    if "service_context" not in st.session_state:
+        st.session_state["service_context"] = ServiceContext.from_defaults(
+            llm=OpenAI(
+                model="gpt-3.5-turbo",
+                temperature=0.2,
+                chunk_size=1024,
+                chunk_overlap=100,
+                system_prompt="As an expert current affairs commentator and analyst,\
+                    your task is to summarize the articles and answer the questions from the user related to the news articles",
+            ),
+            # callback_manager=callback_manager
+            # embed_model=st.session_state["local_embed_model"],
+            chunk_size=256,
+            chunk_overlap=20
+        )
+    # print("finish load_local_embedding_model")
+
 def hash_text(text: str):
     hash_object = hashlib.sha256(text.encode())
     return hash_object.hexdigest()
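
Note on load_local_embedding_model: the name is left over from the commented-out experiment with a local TextEmbeddingsInference server; what the function actually does now is build the shared ServiceContext once. Because st.cache_resource memoizes the call across reruns and sessions, the body is skipped on a cache hit, so a new session whose session_state lacks the key must clear the cache before calling the function again; that is exactly what run_chat does in the next hunk. A simplified sketch of this interplay, with generic names rather than the PR's exact code:

    import streamlit as st

    @st.cache_resource
    def load_shared_resource():
        # Runs only on a cache miss, so this session_state write is
        # skipped for any later session that hits the cache.
        if "resource" not in st.session_state:
            st.session_state["resource"] = object()  # stand-in for ServiceContext

    if "resource" not in st.session_state:
        load_shared_resource.clear()  # drop the stale cache entry
        load_shared_resource()        # re-run the body for this session
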
@@ -661,6 +691,11 @@ def summary_layout_template(
 # TODO: Enhance the performance of retriever.
 # Current retriever sometimes can't get the answer correctly (Change to other retriever)
 def run_chat(payload, query_embed, ori_article_id, compare_num=5):
+
+    if "service_context" not in st.session_state:  # "local_embed_model" not in st.session_state:
+        load_local_embedding_model.clear()
+        load_local_embedding_model()
+
     st.session_state.messages = [
         {"role": "assistant", "content": f"Ask me a question about {payload['title']}"}
     ]
@@ -719,21 +754,36 @@
     # from llama_index.callbacks import CallbackManager, LlamaDebugHandler
     # llama_debug = LlamaDebugHandler(print_trace_on_end=True)
     # callback_manager = CallbackManager([llama_debug])
-    if "service_context" not in st.session_state:
-        st.session_state["service_context"] = ServiceContext.from_defaults(
-            llm=OpenAI(
-                model="gpt-3.5-turbo",
-                temperature=0.2,
-                chunk_size=1024,
-                chunk_overlap=100,
-                system_prompt="As an expert current affairs commentator and analyst,\
-                    your task is to summarize the articles and answer the questions from the user related to the news articles",
-            ),
-            # callback_manager=callback_manager
-        )
-    st.session_state["chat_engine"] = VectorStoreIndex.from_documents(
-        documents, use_async=True, service_context=st.session_state.service_context
-    ).as_query_engine()
+    text_splitter = TokenTextSplitter(separator=" ", chunk_size=256, chunk_overlap=20)
+    # create node parser to parse nodes from document
+    node_parser = SimpleNodeParser(text_splitter=text_splitter)
+
+    # if "service_context" not in st.session_state:
+    #     st.session_state["service_context"] = ServiceContext.from_defaults(
+    #         llm=OpenAI(
+    #             model="gpt-3.5-turbo",
+    #             temperature=0.2,
+    #             chunk_size=1024,
+    #             chunk_overlap=100,
+    #             system_prompt="As an expert current affairs commentator and analyst,\
+    #                 your task is to summarize the articles and answer the questions from the user related to the news articles",
+    #         ),
+    #         # callback_manager=callback_manager
+    #         embed_model=st.session_state["local_embed_model"],
+    #         chunk_size=256,
+    #         chunk_overlap=20
+    #     )
+
+    nodes = node_parser.get_nodes_from_documents(documents)
+    print(f"loaded nodes with {len(nodes)} nodes")
+    index = VectorStoreIndex(
+        nodes=nodes,
+        service_context=st.session_state["service_context"]
+    )
+    st.session_state["chat_engine"] = index.as_query_engine(streaming=True)
+    # st.session_state["chat_engine"] = VectorStoreIndex.from_documents(
+    #     documents, use_async=True, service_context=st.session_state.service_context
+    # ).as_query_engine()
     # print("Prepare summary index: {}".format(time.time()-start))
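
Note tying the utils.py changes together: VectorStoreIndex.from_documents is replaced with explicit node parsing, so documents are split into 256-token chunks with 20-token overlap by TokenTextSplitter (via SimpleNodeParser), the resulting nodes are indexed directly, and the query engine is built with streaming=True so chat_page.py can consume response_gen. A usage sketch against the llama_index 0.8.x API pinned in requirements.txt, assuming documents (a list of llama_index Document objects) and the session-scoped service_context already exist:

    from llama_index import VectorStoreIndex
    from llama_index.node_parser import SimpleNodeParser
    from llama_index.text_splitter import TokenTextSplitter

    splitter = TokenTextSplitter(separator=" ", chunk_size=256, chunk_overlap=20)
    parser = SimpleNodeParser(text_splitter=splitter)

    nodes = parser.get_nodes_from_documents(documents)  # one node per 256-token chunk
    index = VectorStoreIndex(nodes=nodes, service_context=service_context)

    engine = index.as_query_engine(streaming=True)  # .query() returns a StreamingResponse
    streaming_response = engine.query("Summarize the article.")
    for chunk in streaming_response.response_gen:   # chunks arrive incrementally
        print(chunk, end="", flush=True)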