Skip to content

Commit

Permalink
Refactor app.py and frontend.py for improved code organization and re…
Browse files Browse the repository at this point in the history
…adability
  • Loading branch information
samadpls committed Oct 31, 2024
1 parent 474d5c4 commit 6f8a7b3
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 146 deletions.
185 changes: 75 additions & 110 deletions src/app.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
import streamlit as st
from deta import Deta
import sys
import os
import json
from backend import (
configure_page_styles,
create_oauth2_component,
display_github_badge,
handle_google_login_if_needed,
hide_main_menu_and_footer,
)
from frontend import (
Expand All @@ -19,124 +17,91 @@
handle_new_chat,
)
from model import create_huggingface_hub


sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from src.auth import *
from src.constant import *

def format_chat_history(messages):
"""Format the chat history as a structured JSON string."""
history = []
for msg in messages[1:]:
content = msg['content']
if '```sql' in content:
content = content.replace('```sql\n', '').replace('\n```', '').strip()

history.append({
"role": msg['role'],
"query" if msg['role'] == 'user' else "response": content
})

formatted_history = json.dumps(history, indent=2)
print("Formatted history:", formatted_history)
return formatted_history

def extract_sql_code(response):
"""Extract clean SQL code from the response."""
sql_code_start = response.find("```sql")
if sql_code_start != -1:
sql_code_end = response.find("```", sql_code_start + 5)
if sql_code_end != -1:
sql_code = response[sql_code_start + 6:sql_code_end].strip()
return f"```sql\n{sql_code}\n```"
return response

def main():
"""Main function to configure and run the Querypls application."""
configure_page_styles("static/css/styles.css")
deta = Deta(DETA_PROJECT_KEY)

if "model" not in st.session_state:
llm = create_huggingface_hub()
st.session_state["model"] = llm
db = deta.Base("users")
oauth2 = create_oauth2_component()

if "code" not in st.session_state or not st.session_state.code:
st.session_state.code = False

if "code" not in st.session_state:
st.session_state.code = False


if "messages" not in st.session_state:
create_message()

hide_main_menu_and_footer()
if st.session_state.code == False:
col1, col2, col3 = st.columns(3)
with col1:
pass
with col2:
with st.container():

display_github_badge()
display_logo_and_heading()

st.markdown("`Made with 🤍`")
if "token" not in st.session_state:
result = oauth2.authorize_button(
"Connect with Google",
REDIRECT_URI,
SCOPE,
icon="data:image/svg+xml;charset=utf-8,%3Csvg \
xmlns='http://www.w3.org/2000/svg' \
xmlns:xlink='http://www.w3.org/1999/xlink' \
viewBox='0 0 48 48'%3E%3Cdefs%3E%3Cpath id='a' \
d='M44.5 20H24v8.5h11.8C34.7 33.9 30.1 37 24 37c-7.2 \
0-13-5.8-13-13s5.8-13 13-13c3.1 0 5.9 1.1 8.1 \
2.9l6.4-6.4C34.6 4.1 29.6 2 24 2 11.8 2 2 11.8 2 \
24s9.8 22 22 22c11 0 21-8 21-22 \
0-1.3-.2-2.7-.5-4z'/%3E%3C/defs%3E%3CclipPath \
id='b'%3E%3Cuse xlink:href='%23a' \
overflow='visible'/%3E%3C/clipPath%3E%3Cpath \
clip-path='url(%23b)' fill='%23FBBC05' \
d='M0 37V11l17 13z'/%3E%3Cpath clip-path='url(%23b)' \
fill='%23EA4335' d='M0 11l17 13 7-6.1L48 \
14V0H0z'/%3E%3Cpath clip-path='url(%23b)' \
fill='%2334A853' d='M0 37l30-23 7.9 1L48 \
0v48H0z'/%3E%3Cpath clip-path='url(%23b)' \
fill='%234285F4' d='M48 48L17 24l-4-3 \
35-10z'/%3E%3C/svg%3E",
use_container_width=True,
)
handle_google_login_if_needed(result)
if st.session_state.code:
st.rerun()
with col3:
pass
else:
with st.sidebar:
display_github_badge()
display_logo_and_heading()
st.markdown("`Made with 🤍`")
if st.session_state.code:
handle_new_chat(db)
if st.session_state.code:
display_previous_chats(db)

if "messages" not in st.session_state:
create_message()
display_welcome_message()

for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.markdown(message["content"], unsafe_allow_html=True)

if prompt := st.chat_input(disabled=(st.session_state.code is False)):
st.session_state.messages.append(
{"role": "user", "content": prompt}
)
with st.chat_message("user"):
st.write(prompt)

prompt_template = PromptTemplate(
template=TEMPLATE, input_variables=["input"]
)

if "model" in st.session_state:
llm_chain = (
prompt_template
| st.session_state.model
| StrOutputParser()
)
if st.session_state.messages[-1]["role"] != "assistant":
with st.chat_message("assistant"):
with st.spinner("Generating..."):
response = llm_chain.invoke({"input": prompt})
index = response.find("```")
if index != -1:
st.markdown(response[index:])
else:
st.markdown(response)
message = {
"role": "assistant",
"content": response,
}
st.session_state.messages.append(message)



with st.sidebar:
display_github_badge()
display_logo_and_heading()
st.markdown("`Made with 🤍`")
handle_new_chat()

display_welcome_message()
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])

if prompt := st.chat_input():
st.session_state.messages.append({"role": "user", "content": prompt})
with st.chat_message("user"):
st.markdown(prompt)

conversation_history = format_chat_history(st.session_state.messages)
prompt_template = PromptTemplate(
template=TEMPLATE,
input_variables=["input", "conversation_history"]
)

if "model" in st.session_state:
llm_chain = prompt_template | st.session_state.model | StrOutputParser()

with st.chat_message("assistant"):
with st.spinner("Generating..."):
response = llm_chain.invoke({
"input": prompt,
"conversation_history": conversation_history
})

# Clean and format the response
formatted_response = extract_sql_code(response)
st.markdown(formatted_response)

# Add to chat history
st.session_state.messages.append({
"role": "assistant",
"content": formatted_response
})

if __name__ == "__main__":
main()
main()
74 changes: 38 additions & 36 deletions src/frontend.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import streamlit as st
from src.database import database, get_previous_chats


def display_logo_and_heading():
"""Displays the Querypls logo."""
Expand All @@ -14,77 +12,81 @@ def display_welcome_message():
st.markdown(f"#### Welcome to \n ## 🗃️💬Querypls - Prompt to SQL")


def handle_new_chat(db, max_chat_histories=5):
def handle_new_chat(max_chat_histories=5):
"""Handles the initiation of a new chat session.
Displays the remaining chat history count and provides a button to start a new chat.
Args:
db: Deta Base instance.
max_chat_histories (int, optional): Maximum number of chat histories to retain.
Returns:
None
"""
remaining_chats = max_chat_histories - len(
get_previous_chats(db, st.session_state.user_email)
)
remaining_chats = max_chat_histories - len(st.session_state.get("previous_chats", []))
st.markdown(
f" #### Remaining Chat Histories: \
`{remaining_chats}/{max_chat_histories}`"
f" #### Remaining Chat Histories: `{remaining_chats}/{max_chat_histories}`"
)
st.markdown(
"You can create up to 5 chat histories. Each history \
can contain unlimited messages."
"You can create up to 5 chat histories. Each history can contain unlimited messages."
)

if st.button("➕ New chat"):
database(db, previous_key=st.session_state.key)
save_chat_history() # Save current chat before creating a new one
create_message()


def display_previous_chats(db):
"""Displays previous chat records.
def display_previous_chats():
"""Displays previous chat records stored in session state.
Retrieves and displays a list of previous chat records for the user.
Allows the user to select a chat to view.
Args:
db: Deta Base instance.
Returns:
None
"""
previous_chats = get_previous_chats(db, st.session_state.user_email)
reversed_chats = reversed(previous_chats)
if "previous_chats" in st.session_state:
reversed_chats = reversed(st.session_state["previous_chats"])

for chat in reversed_chats:
if st.button(chat["title"], key=chat["key"]):
update_session_state(db, chat)
for chat in reversed_chats:
if st.button(chat["title"], key=chat["key"]):
update_session_state(chat)


def create_message():
"""Creates a default assistant message and initializes a session key."""

st.session_state["messages"] = [
{"role": "assistant", "content": "How may I help you?"}
]
st.session_state["key"] = "key"
return


def update_session_state(db, chat):
def update_session_state(chat):
"""Updates the session state with selected chat information.
Args:
db: Deta Base instance.
chat (dict): Selected chat information.
Returns:
None
"""
previous_chat = st.session_state["messages"]
previous_key = st.session_state["key"]
st.session_state["messages"] = chat["chat"]
st.session_state["key"] = chat["key"]
database(db, previous_key, previous_chat)


def save_chat_history():
"""Saves the current chat to session state if it contains messages."""
if "messages" in st.session_state and len(st.session_state["messages"]) > 1:
# Initialize previous chats list if it doesn't exist
if "previous_chats" not in st.session_state:
st.session_state["previous_chats"] = []

# Create a chat summary to store in session
title = st.session_state["messages"][1]["content"]
chat_summary = {
"title": title[:25] + "....." if len(title) > 25 else title,
"chat": st.session_state["messages"],
"key": f"chat_{len(st.session_state['previous_chats']) + 1}"
}

st.session_state["previous_chats"].append(chat_summary)

# Limit chat histories to a maximum number
if len(st.session_state["previous_chats"]) > 5:
st.session_state["previous_chats"].pop(0) # Remove oldest chat
st.warning(
f"The oldest chat history has been removed as you reached the limit of 5 chat histories."
)

0 comments on commit 6f8a7b3

Please sign in to comment.