Skip to content

Commit

Permalink
Added new files and functions for Hugging Face Hub integration, backe…
Browse files Browse the repository at this point in the history
…nd testing, and authentication

Signed-off-by: samadpls <[email protected]>
  • Loading branch information
samadpls committed Feb 12, 2024
1 parent 0b4aa7a commit 383266e
Show file tree
Hide file tree
Showing 10 changed files with 408 additions and 233 deletions.
Empty file added src/__init__.py
Empty file.
237 changes: 4 additions & 233 deletions src/app.py
Original file line number Diff line number Diff line change
@@ -1,248 +1,19 @@
from langchain.chains import LLMChain
from langchain.llms import HuggingFaceHub
from langchain.prompts import PromptTemplate
import streamlit as st
from streamlit_oauth import OAuth2Component
from deta import Deta
import sys
import time
import os
import json
import base64
from backend import configure_page_styles, create_oauth2_component, display_github_badge, handle_google_login_if_needed, hide_main_menu_and_footer
from frontend import create_message, display_logo_and_heading, display_previous_chats, display_welcome_message, handle_new_chat
from model import create_huggingface_hub


sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from src.auth import *
from src.constant import *


def configure_page_styles(file_name):
"""Configures Streamlit page styles for Querypls.
Sets page title, icon, and applies custom CSS styles.
Hides Streamlit main menu and footer for a cleaner interface.
Note:
Ensure 'static/css/styles.css' exists with desired styles.
"""
st.set_page_config(page_title="Querypls", page_icon="💬",layout="wide",)
with open(file_name) as f:
st.markdown('<style>{}</style>'.format(f.read()), unsafe_allow_html=True)

hide_streamlit_style = (
"""<style>#MainMenu {visibility: hidden;}footer {visibility: hidden;}</style>"""
)
st.markdown(hide_streamlit_style, unsafe_allow_html=True)


def display_logo_and_heading():
"""Displays the Querypls logo."""
st.image("static/image/logo.png")


def get_previous_chats(db, user_email):
"""Fetches previous chat records for a user from the database.
Args:
db: Deta Base instance.
user_email (str): User's email address.
Returns:
list: List of previous chat records.
"""
return db.fetch({"email": user_email}).items


def display_previous_chats(db):
"""Displays previous chat records.
Retrieves and displays a list of previous chat records for the user.
Allows the user to select a chat to view.
Args:
db: Deta Base instance.
Returns:
None
"""
previous_chats = get_previous_chats(db, st.session_state.user_email)
reversed_chats = reversed(previous_chats)

for chat in reversed_chats:
if st.button(chat["title"], key=chat["key"]):
update_session_state(db, chat)


def update_session_state(db, chat):
"""Updates the session state with selected chat information.
Args:
db: Deta Base instance.
chat (dict): Selected chat information.
Returns:
None
"""
previous_chat = st.session_state["messages"]
previous_key = st.session_state["key"]
st.session_state["messages"] = chat["chat"]
st.session_state["key"] = chat["key"]
database(db, previous_key, previous_chat)


def database(db, previous_key="key", previous_chat=None, max_chat_histories=5):
"""Manages user chat history in the database.
Updates, adds, or removes chat history based on user interaction.
Args:
db: Deta Base instance.
previous_key (str): Key for the previous chat in the database.
previous_chat (list, optional): Previous chat messages.
max_chat_histories (int, optional): Maximum number of chat histories to retain.
Returns:
None
"""
user_email = st.session_state.user_email
previous_chats = get_previous_chats(db, user_email)
existing_chat = db.get(previous_key) if previous_key != "key" else None
if (
previous_chat is not None
and existing_chat is not None
and previous_key != "key"
):
new_messages = [
message for message in previous_chat if message not in existing_chat["chat"]
]
existing_chat["chat"].extend(new_messages)
db.update({"chat": existing_chat["chat"]}, key=previous_key)
return
previous_chat = (
st.session_state.messages if previous_chat is None else previous_chat
)
if len(previous_chat) > 1 and previous_key == "key":
title = previous_chat[1]["content"]
db.put(
{
"email": user_email,
"chat": previous_chat,
"title": title[:25] + "....." if len(title) > 25 else title,
}
)

if len(previous_chats) >= max_chat_histories:
db.delete(previous_chats[0]["key"])
st.warning(
f"Chat '{previous_chats[0]['title']}' has been removed as you reached the limit of {max_chat_histories} chat histories."
)


def create_message():
"""Creates a default assistant message and initializes a session key."""

st.session_state["messages"] = [
{"role": "assistant", "content": "How may I help you?"}
]
st.session_state["key"] = "key"
return


def handle_new_chat(db, max_chat_histories=5):
"""Handles the initiation of a new chat session.
Displays the remaining chat history count and provides a button to start a new chat.
Args:
db: Deta Base instance.
max_chat_histories (int, optional): Maximum number of chat histories to retain.
Returns:
None
"""
remaining_chats = max_chat_histories - len(
get_previous_chats(db, st.session_state.user_email)
)
st.markdown(f" #### Remaining Chats: `{remaining_chats}/{max_chat_histories}`")
if st.button("➕ New chat"):
database(db, previous_key=st.session_state.key)
create_message()


def create_huggingface_hub():
"""Creates an instance of Hugging Face Hub with specified configurations.
Returns:
HuggingFaceHub: Instance of Hugging Face Hub.
"""
return HuggingFaceHub(
huggingfacehub_api_token=HUGGINGFACE_API_TOKEN,
repo_id=REPO_ID,
model_kwargs={"temperature": 0.2, "max_new_tokens": 180},
)


def hide_main_menu_and_footer():
"""Hides the Streamlit main menu and footer for a cleaner interface."""
st.markdown(
"""<style>#MainMenu {visibility: hidden;}footer {visibility: hidden;}</style>""",
unsafe_allow_html=True,
)


def display_github_badge():
"""Displays a GitHub badge with a link to the Querypls repository."""
st.markdown(
"""<a href='https://github.com/samadpls/Querypls'><img src='https://img.shields.io/github/stars/samadpls/querypls?color=red&label=star%20me&logoColor=red&style=social'></a>""",
unsafe_allow_html=True,
)


def handle_google_login_if_needed(result):
"""Handles Google login if it has not been run yet.
Args:
result (str): Authorization code received from Google.
Returns:
None
"""
try:
if result and "token" in result:
st.session_state.token = result.get("token")
token = st.session_state["token"]
id_token = token["id_token"]
payload = id_token.split(".")[1]
payload += "=" * (-len(payload) % 4)
payload = json.loads(base64.b64decode(payload))
email = payload["email"]
st.session_state.user_email = email
st.session_state.code = True
return
except:
st.warning(
"Seems like there is a network issue. Please check your internet connection."
)
sys.exit()


def display_welcome_message():
"""Displays a welcome message based on user chat history."""
no_chat_history = len(st.session_state.messages) == 1
if no_chat_history:
st.markdown(f"#### Welcome to \n ## 🛢💬Querypls - Prompt to SQL")


def create_oauth2_component():
return OAuth2Component(
CLIENT_ID,
CLIENT_SECRET,
AUTHORIZE_URL,
TOKEN_URL,
REFRESH_TOKEN_URL,
REVOKE_TOKEN_URL,
)

def main():
"""Main function to configure and run the Querypls application."""
configure_page_styles('static/css/styles.css')
Expand Down
82 changes: 82 additions & 0 deletions src/backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import streamlit as st
from streamlit_oauth import OAuth2Component
import sys
import os
import json
import base64

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from src.auth import *
from src.constant import *

def configure_page_styles(file_name):
"""Configures Streamlit page styles for Querypls.
Sets page title, icon, and applies custom CSS styles.
Hides Streamlit main menu and footer for a cleaner interface.
Note:
Ensure 'static/css/styles.css' exists with desired styles.
"""
st.set_page_config(page_title="Querypls", page_icon="💬",layout="wide",)
with open(file_name) as f:
st.markdown('<style>{}</style>'.format(f.read()), unsafe_allow_html=True)

hide_streamlit_style = (
"""<style>#MainMenu {visibility: hidden;}footer {visibility: hidden;}</style>"""
)
st.markdown(hide_streamlit_style, unsafe_allow_html=True)

def hide_main_menu_and_footer():
"""Hides the Streamlit main menu and footer for a cleaner interface."""
st.markdown(
"""<style>#MainMenu {visibility: hidden;}footer {visibility: hidden;}</style>""",
unsafe_allow_html=True,
)

def handle_google_login_if_needed(result):
"""Handles Google login if it has not been run yet.
Args:
result (str): Authorization code received from Google.
Returns:
None
"""
try:
if result and "token" in result:
st.session_state.token = result.get("token")
token = st.session_state["token"]
id_token = token["id_token"]
payload = id_token.split(".")[1]
payload += "=" * (-len(payload) % 4)
payload = json.loads(base64.b64decode(payload))
email = payload["email"]
st.session_state.user_email = email
st.session_state.code = True
return
except:
st.warning(
"Seems like there is a network issue. Please check your internet connection."
)
sys.exit()

def display_github_badge():
"""Displays a GitHub badge with a link to the Querypls repository."""
st.markdown(
"""<a href='https://github.com/samadpls/Querypls'><img src='https://img.shields.io/github/stars/samadpls/querypls?color=red&label=star%20me&logoColor=red&style=social'></a>""",
unsafe_allow_html=True,
)


def create_oauth2_component():
return OAuth2Component(
CLIENT_ID,
CLIENT_SECRET,
AUTHORIZE_URL,
TOKEN_URL,
REFRESH_TOKEN_URL,
REVOKE_TOKEN_URL,
)

61 changes: 61 additions & 0 deletions src/database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import streamlit as st

def get_previous_chats(db, user_email):
"""Fetches previous chat records for a user from the database.
Args:
db: Deta Base instance.
user_email (str): User's email address.
Returns:
list: List of previous chat records.
"""
return db.fetch({"email": user_email}).items


def database(db, previous_key="key", previous_chat=None, max_chat_histories=5):
"""Manages user chat history in the database.
Updates, adds, or removes chat history based on user interaction.
Args:
db: Deta Base instance.
previous_key (str): Key for the previous chat in the database.
previous_chat (list, optional): Previous chat messages.
max_chat_histories (int, optional): Maximum number of chat histories to retain.
Returns:
None
"""
user_email = st.session_state.user_email
previous_chats = get_previous_chats(db, user_email)
existing_chat = db.get(previous_key) if previous_key != "key" else None
if (
previous_chat is not None
and existing_chat is not None
and previous_key != "key"
):
new_messages = [
message for message in previous_chat if message not in existing_chat["chat"]
]
existing_chat["chat"].extend(new_messages)
db.update({"chat": existing_chat["chat"]}, key=previous_key)
return
previous_chat = (
st.session_state.messages if previous_chat is None else previous_chat
)
if len(previous_chat) > 1 and previous_key == "key":
title = previous_chat[1]["content"]
db.put(
{
"email": user_email,
"chat": previous_chat,
"title": title[:25] + "....." if len(title) > 25 else title,
}
)

if len(previous_chats) >= max_chat_histories:
db.delete(previous_chats[0]["key"])
st.warning(
f"Chat '{previous_chats[0]['title']}' has been removed as you reached the limit of {max_chat_histories} chat histories."
)
Loading

0 comments on commit 383266e

Please sign in to comment.