From 383266e2dac9ddcc3fc8c28a641d0cc8b2eeb10a Mon Sep 17 00:00:00 2001 From: samadpls Date: Tue, 13 Feb 2024 00:25:08 +0500 Subject: [PATCH] Added new files and functions for Hugging Face Hub integration, backend testing, and authentication Signed-off-by: samadpls --- src/__init__.py | 0 src/app.py | 237 +---------------------------------------- src/backend.py | 82 ++++++++++++++ src/database.py | 61 +++++++++++ src/frontend.py | 78 ++++++++++++++ src/model.py | 23 ++++ tests/__init__.py | 0 tests/test_auth.py | 51 +++++++++ tests/test_backend.py | 38 +++++++ tests/test_frontend.py | 71 ++++++++++++ 10 files changed, 408 insertions(+), 233 deletions(-) create mode 100644 src/__init__.py create mode 100644 src/backend.py create mode 100644 src/database.py create mode 100644 src/frontend.py create mode 100644 src/model.py create mode 100644 tests/__init__.py create mode 100644 tests/test_auth.py create mode 100644 tests/test_backend.py create mode 100644 tests/test_frontend.py diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/app.py b/src/app.py index e5d5194..8fc361a 100644 --- a/src/app.py +++ b/src/app.py @@ -1,248 +1,19 @@ from langchain.chains import LLMChain -from langchain.llms import HuggingFaceHub from langchain.prompts import PromptTemplate import streamlit as st -from streamlit_oauth import OAuth2Component from deta import Deta import sys -import time import os -import json -import base64 +from backend import configure_page_styles, create_oauth2_component, display_github_badge, handle_google_login_if_needed, hide_main_menu_and_footer +from frontend import create_message, display_logo_and_heading, display_previous_chats, display_welcome_message, handle_new_chat +from model import create_huggingface_hub + sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from src.auth import * from src.constant import * - -def configure_page_styles(file_name): - """Configures Streamlit page styles for Querypls. - - Sets page title, icon, and applies custom CSS styles. - Hides Streamlit main menu and footer for a cleaner interface. - - Note: - Ensure 'static/css/styles.css' exists with desired styles. - """ - st.set_page_config(page_title="Querypls", page_icon="💬",layout="wide",) - with open(file_name) as f: - st.markdown(''.format(f.read()), unsafe_allow_html=True) - - hide_streamlit_style = ( - """""" - ) - st.markdown(hide_streamlit_style, unsafe_allow_html=True) - - -def display_logo_and_heading(): - """Displays the Querypls logo.""" - st.image("static/image/logo.png") - - -def get_previous_chats(db, user_email): - """Fetches previous chat records for a user from the database. - - Args: - db: Deta Base instance. - user_email (str): User's email address. - - Returns: - list: List of previous chat records. - """ - return db.fetch({"email": user_email}).items - - -def display_previous_chats(db): - """Displays previous chat records. - - Retrieves and displays a list of previous chat records for the user. - Allows the user to select a chat to view. - - Args: - db: Deta Base instance. - - Returns: - None - """ - previous_chats = get_previous_chats(db, st.session_state.user_email) - reversed_chats = reversed(previous_chats) - - for chat in reversed_chats: - if st.button(chat["title"], key=chat["key"]): - update_session_state(db, chat) - - -def update_session_state(db, chat): - """Updates the session state with selected chat information. - - Args: - db: Deta Base instance. - chat (dict): Selected chat information. - - Returns: - None - """ - previous_chat = st.session_state["messages"] - previous_key = st.session_state["key"] - st.session_state["messages"] = chat["chat"] - st.session_state["key"] = chat["key"] - database(db, previous_key, previous_chat) - - -def database(db, previous_key="key", previous_chat=None, max_chat_histories=5): - """Manages user chat history in the database. - - Updates, adds, or removes chat history based on user interaction. - - Args: - db: Deta Base instance. - previous_key (str): Key for the previous chat in the database. - previous_chat (list, optional): Previous chat messages. - max_chat_histories (int, optional): Maximum number of chat histories to retain. - - Returns: - None - """ - user_email = st.session_state.user_email - previous_chats = get_previous_chats(db, user_email) - existing_chat = db.get(previous_key) if previous_key != "key" else None - if ( - previous_chat is not None - and existing_chat is not None - and previous_key != "key" - ): - new_messages = [ - message for message in previous_chat if message not in existing_chat["chat"] - ] - existing_chat["chat"].extend(new_messages) - db.update({"chat": existing_chat["chat"]}, key=previous_key) - return - previous_chat = ( - st.session_state.messages if previous_chat is None else previous_chat - ) - if len(previous_chat) > 1 and previous_key == "key": - title = previous_chat[1]["content"] - db.put( - { - "email": user_email, - "chat": previous_chat, - "title": title[:25] + "....." if len(title) > 25 else title, - } - ) - - if len(previous_chats) >= max_chat_histories: - db.delete(previous_chats[0]["key"]) - st.warning( - f"Chat '{previous_chats[0]['title']}' has been removed as you reached the limit of {max_chat_histories} chat histories." - ) - - -def create_message(): - """Creates a default assistant message and initializes a session key.""" - - st.session_state["messages"] = [ - {"role": "assistant", "content": "How may I help you?"} - ] - st.session_state["key"] = "key" - return - - -def handle_new_chat(db, max_chat_histories=5): - """Handles the initiation of a new chat session. - - Displays the remaining chat history count and provides a button to start a new chat. - - Args: - db: Deta Base instance. - max_chat_histories (int, optional): Maximum number of chat histories to retain. - - Returns: - None - """ - remaining_chats = max_chat_histories - len( - get_previous_chats(db, st.session_state.user_email) - ) - st.markdown(f" #### Remaining Chats: `{remaining_chats}/{max_chat_histories}`") - if st.button("➕ New chat"): - database(db, previous_key=st.session_state.key) - create_message() - - -def create_huggingface_hub(): - """Creates an instance of Hugging Face Hub with specified configurations. - - Returns: - HuggingFaceHub: Instance of Hugging Face Hub. - """ - return HuggingFaceHub( - huggingfacehub_api_token=HUGGINGFACE_API_TOKEN, - repo_id=REPO_ID, - model_kwargs={"temperature": 0.2, "max_new_tokens": 180}, - ) - - -def hide_main_menu_and_footer(): - """Hides the Streamlit main menu and footer for a cleaner interface.""" - st.markdown( - """""", - unsafe_allow_html=True, - ) - - -def display_github_badge(): - """Displays a GitHub badge with a link to the Querypls repository.""" - st.markdown( - """""", - unsafe_allow_html=True, - ) - - -def handle_google_login_if_needed(result): - """Handles Google login if it has not been run yet. - - Args: - result (str): Authorization code received from Google. - - Returns: - None - """ - try: - if result and "token" in result: - st.session_state.token = result.get("token") - token = st.session_state["token"] - id_token = token["id_token"] - payload = id_token.split(".")[1] - payload += "=" * (-len(payload) % 4) - payload = json.loads(base64.b64decode(payload)) - email = payload["email"] - st.session_state.user_email = email - st.session_state.code = True - return - except: - st.warning( - "Seems like there is a network issue. Please check your internet connection." - ) - sys.exit() - - -def display_welcome_message(): - """Displays a welcome message based on user chat history.""" - no_chat_history = len(st.session_state.messages) == 1 - if no_chat_history: - st.markdown(f"#### Welcome to \n ## 🛢💬Querypls - Prompt to SQL") - - -def create_oauth2_component(): - return OAuth2Component( - CLIENT_ID, - CLIENT_SECRET, - AUTHORIZE_URL, - TOKEN_URL, - REFRESH_TOKEN_URL, - REVOKE_TOKEN_URL, - ) - def main(): """Main function to configure and run the Querypls application.""" configure_page_styles('static/css/styles.css') diff --git a/src/backend.py b/src/backend.py new file mode 100644 index 0000000..d820c42 --- /dev/null +++ b/src/backend.py @@ -0,0 +1,82 @@ +import streamlit as st +from streamlit_oauth import OAuth2Component +import sys +import os +import json +import base64 + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + +from src.auth import * +from src.constant import * + +def configure_page_styles(file_name): + """Configures Streamlit page styles for Querypls. + + Sets page title, icon, and applies custom CSS styles. + Hides Streamlit main menu and footer for a cleaner interface. + + Note: + Ensure 'static/css/styles.css' exists with desired styles. + """ + st.set_page_config(page_title="Querypls", page_icon="💬",layout="wide",) + with open(file_name) as f: + st.markdown(''.format(f.read()), unsafe_allow_html=True) + + hide_streamlit_style = ( + """""" + ) + st.markdown(hide_streamlit_style, unsafe_allow_html=True) + +def hide_main_menu_and_footer(): + """Hides the Streamlit main menu and footer for a cleaner interface.""" + st.markdown( + """""", + unsafe_allow_html=True, + ) + +def handle_google_login_if_needed(result): + """Handles Google login if it has not been run yet. + + Args: + result (str): Authorization code received from Google. + + Returns: + None + """ + try: + if result and "token" in result: + st.session_state.token = result.get("token") + token = st.session_state["token"] + id_token = token["id_token"] + payload = id_token.split(".")[1] + payload += "=" * (-len(payload) % 4) + payload = json.loads(base64.b64decode(payload)) + email = payload["email"] + st.session_state.user_email = email + st.session_state.code = True + return + except: + st.warning( + "Seems like there is a network issue. Please check your internet connection." + ) + sys.exit() + +def display_github_badge(): + """Displays a GitHub badge with a link to the Querypls repository.""" + st.markdown( + """""", + unsafe_allow_html=True, + ) + + +def create_oauth2_component(): + return OAuth2Component( + CLIENT_ID, + CLIENT_SECRET, + AUTHORIZE_URL, + TOKEN_URL, + REFRESH_TOKEN_URL, + REVOKE_TOKEN_URL, + ) + diff --git a/src/database.py b/src/database.py new file mode 100644 index 0000000..8c110de --- /dev/null +++ b/src/database.py @@ -0,0 +1,61 @@ +import streamlit as st + +def get_previous_chats(db, user_email): + """Fetches previous chat records for a user from the database. + + Args: + db: Deta Base instance. + user_email (str): User's email address. + + Returns: + list: List of previous chat records. + """ + return db.fetch({"email": user_email}).items + + +def database(db, previous_key="key", previous_chat=None, max_chat_histories=5): + """Manages user chat history in the database. + + Updates, adds, or removes chat history based on user interaction. + + Args: + db: Deta Base instance. + previous_key (str): Key for the previous chat in the database. + previous_chat (list, optional): Previous chat messages. + max_chat_histories (int, optional): Maximum number of chat histories to retain. + + Returns: + None + """ + user_email = st.session_state.user_email + previous_chats = get_previous_chats(db, user_email) + existing_chat = db.get(previous_key) if previous_key != "key" else None + if ( + previous_chat is not None + and existing_chat is not None + and previous_key != "key" + ): + new_messages = [ + message for message in previous_chat if message not in existing_chat["chat"] + ] + existing_chat["chat"].extend(new_messages) + db.update({"chat": existing_chat["chat"]}, key=previous_key) + return + previous_chat = ( + st.session_state.messages if previous_chat is None else previous_chat + ) + if len(previous_chat) > 1 and previous_key == "key": + title = previous_chat[1]["content"] + db.put( + { + "email": user_email, + "chat": previous_chat, + "title": title[:25] + "....." if len(title) > 25 else title, + } + ) + + if len(previous_chats) >= max_chat_histories: + db.delete(previous_chats[0]["key"]) + st.warning( + f"Chat '{previous_chats[0]['title']}' has been removed as you reached the limit of {max_chat_histories} chat histories." + ) diff --git a/src/frontend.py b/src/frontend.py new file mode 100644 index 0000000..337d08a --- /dev/null +++ b/src/frontend.py @@ -0,0 +1,78 @@ +import streamlit as st +from src.database import database, get_previous_chats + +def display_logo_and_heading(): + """Displays the Querypls logo.""" + st.image("static/image/logo.png") + +def display_welcome_message(): + """Displays a welcome message based on user chat history.""" + no_chat_history = len(st.session_state.messages) == 1 + if no_chat_history: + st.markdown(f"#### Welcome to \n ## 🛢💬Querypls - Prompt to SQL") + +def handle_new_chat(db, max_chat_histories=5): + """Handles the initiation of a new chat session. + + Displays the remaining chat history count and provides a button to start a new chat. + + Args: + db: Deta Base instance. + max_chat_histories (int, optional): Maximum number of chat histories to retain. + + Returns: + None + """ + remaining_chats = max_chat_histories - len( + get_previous_chats(db, st.session_state.user_email) + ) + st.markdown(f" #### Remaining Chats: `{remaining_chats}/{max_chat_histories}`") + if st.button("➕ New chat"): + database(db, previous_key=st.session_state.key) + create_message() + +def display_previous_chats(db): + """Displays previous chat records. + + Retrieves and displays a list of previous chat records for the user. + Allows the user to select a chat to view. + + Args: + db: Deta Base instance. + + Returns: + None + """ + previous_chats = get_previous_chats(db, st.session_state.user_email) + reversed_chats = reversed(previous_chats) + + for chat in reversed_chats: + if st.button(chat["title"], key=chat["key"]): + update_session_state(db, chat) + +def create_message(): + """Creates a default assistant message and initializes a session key.""" + + st.session_state["messages"] = [ + {"role": "assistant", "content": "How may I help you?"} + ] + st.session_state["key"] = "key" + return + +def update_session_state(db, chat): + """Updates the session state with selected chat information. + + Args: + db: Deta Base instance. + chat (dict): Selected chat information. + + Returns: + None + """ + previous_chat = st.session_state["messages"] + previous_key = st.session_state["key"] + st.session_state["messages"] = chat["chat"] + st.session_state["key"] = chat["key"] + database(db, previous_key, previous_chat) + + diff --git a/src/model.py b/src/model.py new file mode 100644 index 0000000..c005289 --- /dev/null +++ b/src/model.py @@ -0,0 +1,23 @@ +from langchain.llms import HuggingFaceHub +import sys +import os + + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + +from src.auth import * +from src.constant import * + +def create_huggingface_hub(): + """Creates an instance of Hugging Face Hub with specified configurations. + + Returns: + HuggingFaceHub: Instance of Hugging Face Hub. + """ + return HuggingFaceHub( + huggingfacehub_api_token=HUGGINGFACE_API_TOKEN, + repo_id=REPO_ID, + model_kwargs={"temperature": 0.2, "max_new_tokens": 180}, + ) + + diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 0000000..1d7b22f --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,51 @@ +import pytest +from unittest.mock import AsyncMock, patch +from httpx_oauth.clients.google import GoogleOAuth2 +from src.constant import * +from src.auth import ( + get_authorization_url, + get_access_token, + get_email, + get_login_str, +) + +@pytest.mark.asyncio +async def test_get_authorization_url(): + client = GoogleOAuth2("client_id", "client_secret") + redirect_uri = "http://example.com/callback" + with patch.object(client, "get_authorization_url", new=AsyncMock()) as mock_method: + await get_authorization_url(client, redirect_uri) + mock_method.assert_called_once_with( + redirect_uri, scope=["profile", "email"] + ) + +@pytest.mark.asyncio +async def test_get_access_token(): + client = GoogleOAuth2("client_id", "client_secret") + redirect_uri = "http://example.com/callback" + code = "code" + with patch.object(client, "get_access_token", new=AsyncMock()) as mock_method: + await get_access_token(client, redirect_uri, code) + mock_method.assert_called_once_with(code, redirect_uri) + +@pytest.mark.asyncio +async def test_get_email(): + client = GoogleOAuth2("client_id", "client_secret") + token = "token" + with patch.object( + client, + "get_id_email", + new=AsyncMock(return_value=("user_id", "user_email")), + ) as mock_method: + user_id, user_email = await get_email(client, token) + mock_method.assert_called_once_with(token) + assert user_id == "user_id" + assert user_email == "user_email" + +def test_get_login_str(): + with patch("asyncio.run") as mock_run: + mock_run.return_value = "authorization_url" + result = get_login_str() + mock_run.assert_called_once() + assert '' in result + assert "Login with Google" in result diff --git a/tests/test_backend.py b/tests/test_backend.py new file mode 100644 index 0000000..231b49a --- /dev/null +++ b/tests/test_backend.py @@ -0,0 +1,38 @@ +import pytest +from unittest.mock import patch, MagicMock +import sys, os +from src.backend import * +from src.constant import * + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + +@pytest.fixture +def mock_open(): + with patch('builtins.open', new_callable=MagicMock) as mock_open: + yield mock_open + +@pytest.fixture +def mock_markdown(): + with patch('streamlit.markdown') as mock_markdown: + yield mock_markdown + +@pytest.fixture +def mock_set_page_config(): + with patch('streamlit.set_page_config') as mock_set_page_config: + yield mock_set_page_config + +@pytest.fixture +def mock_oauth2_component(): + with patch('streamlit_oauth.OAuth2Component') as mock_oauth2_component: + yield mock_oauth2_component + +def test_configure_page_styles(mock_open, mock_markdown, mock_set_page_config): + mock_open.return_value.__enter__.return_value.read.return_value = 'test' + configure_page_styles('test_file') + mock_set_page_config.assert_called_once_with(page_title="Querypls", page_icon="💬", layout="wide") + mock_markdown.assert_called() + mock_open.assert_called_once_with('test_file') + +def test_hide_main_menu_and_footer(mock_markdown): + hide_main_menu_and_footer() + mock_markdown.assert_called_once_with("""""", unsafe_allow_html=True) diff --git a/tests/test_frontend.py b/tests/test_frontend.py new file mode 100644 index 0000000..99b6dba --- /dev/null +++ b/tests/test_frontend.py @@ -0,0 +1,71 @@ +import pytest +from unittest.mock import patch, MagicMock +import streamlit as st +from src.frontend import ( + display_logo_and_heading, + display_welcome_message, + handle_new_chat, + display_previous_chats, + create_message, + update_session_state, +) + +@pytest.fixture +def mock_st(): + return MagicMock() + +@pytest.fixture +def mock_db(): + return MagicMock() + +class MockSessionState: + def __getitem__(self, key): + return getattr(self, key) + + def __setitem__(self, key, value): + setattr(self, key, value) + +def initialize_session_state(messages=None, key=None, user_email=None): + st.session_state = MockSessionState() + st.session_state.messages = messages or [] + st.session_state.key = key + st.session_state.user_email = user_email + +def test_display_logo_and_heading(mock_st): + with patch.object(st, "image") as mock_image: + display_logo_and_heading() + mock_image.assert_called_once_with("static/image/logo.png") + +def test_display_welcome_message(mock_st): + with patch.object(st, "markdown") as mock_markdown: + with patch.object(st, "session_state", MockSessionState()): + initialize_session_state(messages=[{"role": "assistant", "content": "How may I help you?"}]) + display_welcome_message() + mock_markdown.assert_called_once_with("#### Welcome to \n ## 🛢💬Querypls - Prompt to SQL") + +def test_handle_new_chat(mock_db, mock_st): + with patch("src.frontend.get_previous_chats") as mock_get_previous_chats: + mock_get_previous_chats.return_value = [] + with patch.object(st, "markdown") as mock_markdown, patch.object(st, "button") as mock_button: + with patch.object(st, "session_state", MockSessionState()): + initialize_session_state(messages=[], user_email="test@example.com") + handle_new_chat(mock_db, max_chat_histories=5) + mock_markdown.assert_called_once_with(" #### Remaining Chats: `5/5`") + mock_button.assert_called_once_with("➕ New chat") + +def test_create_message(): + with patch.object(st, "session_state", MockSessionState()): + initialize_session_state(messages=[], key=None) + create_message() + assert st.session_state.messages == [{"role": "assistant", "content": "How may I help you?"}] + assert st.session_state.key == "key" + +def test_update_session_state(mock_db): + chat = {"chat": [{"role": "user", "content": "Hello"}], "key": "new_key"} + with patch.object(st, "session_state", MockSessionState()): + initialize_session_state(messages=[{"role": "assistant", "content": "How may I help you?"}], key="old_key") + with patch("src.frontend.database") as mock_database: + update_session_state(mock_db, chat) + mock_database.assert_called_once_with(mock_db, "old_key", [{"role": "assistant", "content": "How may I help you?"}]) + assert st.session_state.messages == chat["chat"] + assert st.session_state.key == chat["key"]