From 445234b34ee582015538f1ce1e1da0829daa00c6 Mon Sep 17 00:00:00 2001 From: Abdul Samad Siddiqui Date: Mon, 12 Feb 2024 20:06:09 +0000 Subject: [PATCH] Revert "Refactor into multiple file (#10)" This reverts commit abfb7f70a8a03a071b56950989c24070e2cb1edd. --- .../workflows/{pytest.yml => unittests.yml} | 6 +- requirements.txt | 3 +- src/__init__.py | 0 src/app.py | 237 +++++++++++++++++- src/backend.py | 82 ------ src/database.py | 61 ----- src/frontend.py | 78 ------ src/model.py | 23 -- test/test_auth.py | 56 +++++ tests/__init__.py | 0 tests/test_auth.py | 51 ---- tests/test_backend.py | 38 --- tests/test_frontend.py | 71 ------ 13 files changed, 293 insertions(+), 413 deletions(-) rename .github/workflows/{pytest.yml => unittests.yml} (84%) delete mode 100644 src/__init__.py delete mode 100644 src/backend.py delete mode 100644 src/database.py delete mode 100644 src/frontend.py delete mode 100644 src/model.py create mode 100644 test/test_auth.py delete mode 100644 tests/__init__.py delete mode 100644 tests/test_auth.py delete mode 100644 tests/test_backend.py delete mode 100644 tests/test_frontend.py diff --git a/.github/workflows/pytest.yml b/.github/workflows/unittests.yml similarity index 84% rename from .github/workflows/pytest.yml rename to .github/workflows/unittests.yml index a185623..4891457 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/unittests.yml @@ -1,4 +1,4 @@ -name: Run Pytest +name: Run Unittests on: [push, pull_request] @@ -20,6 +20,6 @@ jobs: python3 -m pip install --upgrade pip pip3 install -r requirements.txt - - name: Run Pytest + - name: Run unittests run: | - pytest + python3 -m unittest diff --git a/requirements.txt b/requirements.txt index 7c5fba4..1692f88 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,5 +7,4 @@ langchain==0.0.336 python-dotenv==1.0.0 black streamlit_oauth==0.1.5 -deta==1.2. -pytest \ No newline at end of file +deta==1.2. \ No newline at end of file diff --git a/src/__init__.py b/src/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/app.py b/src/app.py index 8fc361a..e5d5194 100644 --- a/src/app.py +++ b/src/app.py @@ -1,19 +1,248 @@ from langchain.chains import LLMChain +from langchain.llms import HuggingFaceHub from langchain.prompts import PromptTemplate import streamlit as st +from streamlit_oauth import OAuth2Component from deta import Deta import sys +import time import os -from backend import configure_page_styles, create_oauth2_component, display_github_badge, handle_google_login_if_needed, hide_main_menu_and_footer -from frontend import create_message, display_logo_and_heading, display_previous_chats, display_welcome_message, handle_new_chat -from model import create_huggingface_hub - +import json +import base64 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from src.auth import * from src.constant import * + +def configure_page_styles(file_name): + """Configures Streamlit page styles for Querypls. + + Sets page title, icon, and applies custom CSS styles. + Hides Streamlit main menu and footer for a cleaner interface. + + Note: + Ensure 'static/css/styles.css' exists with desired styles. + """ + st.set_page_config(page_title="Querypls", page_icon="💬",layout="wide",) + with open(file_name) as f: + st.markdown(''.format(f.read()), unsafe_allow_html=True) + + hide_streamlit_style = ( + """""" + ) + st.markdown(hide_streamlit_style, unsafe_allow_html=True) + + +def display_logo_and_heading(): + """Displays the Querypls logo.""" + st.image("static/image/logo.png") + + +def get_previous_chats(db, user_email): + """Fetches previous chat records for a user from the database. + + Args: + db: Deta Base instance. + user_email (str): User's email address. + + Returns: + list: List of previous chat records. + """ + return db.fetch({"email": user_email}).items + + +def display_previous_chats(db): + """Displays previous chat records. + + Retrieves and displays a list of previous chat records for the user. + Allows the user to select a chat to view. + + Args: + db: Deta Base instance. + + Returns: + None + """ + previous_chats = get_previous_chats(db, st.session_state.user_email) + reversed_chats = reversed(previous_chats) + + for chat in reversed_chats: + if st.button(chat["title"], key=chat["key"]): + update_session_state(db, chat) + + +def update_session_state(db, chat): + """Updates the session state with selected chat information. + + Args: + db: Deta Base instance. + chat (dict): Selected chat information. + + Returns: + None + """ + previous_chat = st.session_state["messages"] + previous_key = st.session_state["key"] + st.session_state["messages"] = chat["chat"] + st.session_state["key"] = chat["key"] + database(db, previous_key, previous_chat) + + +def database(db, previous_key="key", previous_chat=None, max_chat_histories=5): + """Manages user chat history in the database. + + Updates, adds, or removes chat history based on user interaction. + + Args: + db: Deta Base instance. + previous_key (str): Key for the previous chat in the database. + previous_chat (list, optional): Previous chat messages. + max_chat_histories (int, optional): Maximum number of chat histories to retain. + + Returns: + None + """ + user_email = st.session_state.user_email + previous_chats = get_previous_chats(db, user_email) + existing_chat = db.get(previous_key) if previous_key != "key" else None + if ( + previous_chat is not None + and existing_chat is not None + and previous_key != "key" + ): + new_messages = [ + message for message in previous_chat if message not in existing_chat["chat"] + ] + existing_chat["chat"].extend(new_messages) + db.update({"chat": existing_chat["chat"]}, key=previous_key) + return + previous_chat = ( + st.session_state.messages if previous_chat is None else previous_chat + ) + if len(previous_chat) > 1 and previous_key == "key": + title = previous_chat[1]["content"] + db.put( + { + "email": user_email, + "chat": previous_chat, + "title": title[:25] + "....." if len(title) > 25 else title, + } + ) + + if len(previous_chats) >= max_chat_histories: + db.delete(previous_chats[0]["key"]) + st.warning( + f"Chat '{previous_chats[0]['title']}' has been removed as you reached the limit of {max_chat_histories} chat histories." + ) + + +def create_message(): + """Creates a default assistant message and initializes a session key.""" + + st.session_state["messages"] = [ + {"role": "assistant", "content": "How may I help you?"} + ] + st.session_state["key"] = "key" + return + + +def handle_new_chat(db, max_chat_histories=5): + """Handles the initiation of a new chat session. + + Displays the remaining chat history count and provides a button to start a new chat. + + Args: + db: Deta Base instance. + max_chat_histories (int, optional): Maximum number of chat histories to retain. + + Returns: + None + """ + remaining_chats = max_chat_histories - len( + get_previous_chats(db, st.session_state.user_email) + ) + st.markdown(f" #### Remaining Chats: `{remaining_chats}/{max_chat_histories}`") + if st.button("➕ New chat"): + database(db, previous_key=st.session_state.key) + create_message() + + +def create_huggingface_hub(): + """Creates an instance of Hugging Face Hub with specified configurations. + + Returns: + HuggingFaceHub: Instance of Hugging Face Hub. + """ + return HuggingFaceHub( + huggingfacehub_api_token=HUGGINGFACE_API_TOKEN, + repo_id=REPO_ID, + model_kwargs={"temperature": 0.2, "max_new_tokens": 180}, + ) + + +def hide_main_menu_and_footer(): + """Hides the Streamlit main menu and footer for a cleaner interface.""" + st.markdown( + """""", + unsafe_allow_html=True, + ) + + +def display_github_badge(): + """Displays a GitHub badge with a link to the Querypls repository.""" + st.markdown( + """""", + unsafe_allow_html=True, + ) + + +def handle_google_login_if_needed(result): + """Handles Google login if it has not been run yet. + + Args: + result (str): Authorization code received from Google. + + Returns: + None + """ + try: + if result and "token" in result: + st.session_state.token = result.get("token") + token = st.session_state["token"] + id_token = token["id_token"] + payload = id_token.split(".")[1] + payload += "=" * (-len(payload) % 4) + payload = json.loads(base64.b64decode(payload)) + email = payload["email"] + st.session_state.user_email = email + st.session_state.code = True + return + except: + st.warning( + "Seems like there is a network issue. Please check your internet connection." + ) + sys.exit() + + +def display_welcome_message(): + """Displays a welcome message based on user chat history.""" + no_chat_history = len(st.session_state.messages) == 1 + if no_chat_history: + st.markdown(f"#### Welcome to \n ## 🛢💬Querypls - Prompt to SQL") + + +def create_oauth2_component(): + return OAuth2Component( + CLIENT_ID, + CLIENT_SECRET, + AUTHORIZE_URL, + TOKEN_URL, + REFRESH_TOKEN_URL, + REVOKE_TOKEN_URL, + ) + def main(): """Main function to configure and run the Querypls application.""" configure_page_styles('static/css/styles.css') diff --git a/src/backend.py b/src/backend.py deleted file mode 100644 index d820c42..0000000 --- a/src/backend.py +++ /dev/null @@ -1,82 +0,0 @@ -import streamlit as st -from streamlit_oauth import OAuth2Component -import sys -import os -import json -import base64 - -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) - -from src.auth import * -from src.constant import * - -def configure_page_styles(file_name): - """Configures Streamlit page styles for Querypls. - - Sets page title, icon, and applies custom CSS styles. - Hides Streamlit main menu and footer for a cleaner interface. - - Note: - Ensure 'static/css/styles.css' exists with desired styles. - """ - st.set_page_config(page_title="Querypls", page_icon="💬",layout="wide",) - with open(file_name) as f: - st.markdown(''.format(f.read()), unsafe_allow_html=True) - - hide_streamlit_style = ( - """""" - ) - st.markdown(hide_streamlit_style, unsafe_allow_html=True) - -def hide_main_menu_and_footer(): - """Hides the Streamlit main menu and footer for a cleaner interface.""" - st.markdown( - """""", - unsafe_allow_html=True, - ) - -def handle_google_login_if_needed(result): - """Handles Google login if it has not been run yet. - - Args: - result (str): Authorization code received from Google. - - Returns: - None - """ - try: - if result and "token" in result: - st.session_state.token = result.get("token") - token = st.session_state["token"] - id_token = token["id_token"] - payload = id_token.split(".")[1] - payload += "=" * (-len(payload) % 4) - payload = json.loads(base64.b64decode(payload)) - email = payload["email"] - st.session_state.user_email = email - st.session_state.code = True - return - except: - st.warning( - "Seems like there is a network issue. Please check your internet connection." - ) - sys.exit() - -def display_github_badge(): - """Displays a GitHub badge with a link to the Querypls repository.""" - st.markdown( - """""", - unsafe_allow_html=True, - ) - - -def create_oauth2_component(): - return OAuth2Component( - CLIENT_ID, - CLIENT_SECRET, - AUTHORIZE_URL, - TOKEN_URL, - REFRESH_TOKEN_URL, - REVOKE_TOKEN_URL, - ) - diff --git a/src/database.py b/src/database.py deleted file mode 100644 index 8c110de..0000000 --- a/src/database.py +++ /dev/null @@ -1,61 +0,0 @@ -import streamlit as st - -def get_previous_chats(db, user_email): - """Fetches previous chat records for a user from the database. - - Args: - db: Deta Base instance. - user_email (str): User's email address. - - Returns: - list: List of previous chat records. - """ - return db.fetch({"email": user_email}).items - - -def database(db, previous_key="key", previous_chat=None, max_chat_histories=5): - """Manages user chat history in the database. - - Updates, adds, or removes chat history based on user interaction. - - Args: - db: Deta Base instance. - previous_key (str): Key for the previous chat in the database. - previous_chat (list, optional): Previous chat messages. - max_chat_histories (int, optional): Maximum number of chat histories to retain. - - Returns: - None - """ - user_email = st.session_state.user_email - previous_chats = get_previous_chats(db, user_email) - existing_chat = db.get(previous_key) if previous_key != "key" else None - if ( - previous_chat is not None - and existing_chat is not None - and previous_key != "key" - ): - new_messages = [ - message for message in previous_chat if message not in existing_chat["chat"] - ] - existing_chat["chat"].extend(new_messages) - db.update({"chat": existing_chat["chat"]}, key=previous_key) - return - previous_chat = ( - st.session_state.messages if previous_chat is None else previous_chat - ) - if len(previous_chat) > 1 and previous_key == "key": - title = previous_chat[1]["content"] - db.put( - { - "email": user_email, - "chat": previous_chat, - "title": title[:25] + "....." if len(title) > 25 else title, - } - ) - - if len(previous_chats) >= max_chat_histories: - db.delete(previous_chats[0]["key"]) - st.warning( - f"Chat '{previous_chats[0]['title']}' has been removed as you reached the limit of {max_chat_histories} chat histories." - ) diff --git a/src/frontend.py b/src/frontend.py deleted file mode 100644 index 337d08a..0000000 --- a/src/frontend.py +++ /dev/null @@ -1,78 +0,0 @@ -import streamlit as st -from src.database import database, get_previous_chats - -def display_logo_and_heading(): - """Displays the Querypls logo.""" - st.image("static/image/logo.png") - -def display_welcome_message(): - """Displays a welcome message based on user chat history.""" - no_chat_history = len(st.session_state.messages) == 1 - if no_chat_history: - st.markdown(f"#### Welcome to \n ## 🛢💬Querypls - Prompt to SQL") - -def handle_new_chat(db, max_chat_histories=5): - """Handles the initiation of a new chat session. - - Displays the remaining chat history count and provides a button to start a new chat. - - Args: - db: Deta Base instance. - max_chat_histories (int, optional): Maximum number of chat histories to retain. - - Returns: - None - """ - remaining_chats = max_chat_histories - len( - get_previous_chats(db, st.session_state.user_email) - ) - st.markdown(f" #### Remaining Chats: `{remaining_chats}/{max_chat_histories}`") - if st.button("➕ New chat"): - database(db, previous_key=st.session_state.key) - create_message() - -def display_previous_chats(db): - """Displays previous chat records. - - Retrieves and displays a list of previous chat records for the user. - Allows the user to select a chat to view. - - Args: - db: Deta Base instance. - - Returns: - None - """ - previous_chats = get_previous_chats(db, st.session_state.user_email) - reversed_chats = reversed(previous_chats) - - for chat in reversed_chats: - if st.button(chat["title"], key=chat["key"]): - update_session_state(db, chat) - -def create_message(): - """Creates a default assistant message and initializes a session key.""" - - st.session_state["messages"] = [ - {"role": "assistant", "content": "How may I help you?"} - ] - st.session_state["key"] = "key" - return - -def update_session_state(db, chat): - """Updates the session state with selected chat information. - - Args: - db: Deta Base instance. - chat (dict): Selected chat information. - - Returns: - None - """ - previous_chat = st.session_state["messages"] - previous_key = st.session_state["key"] - st.session_state["messages"] = chat["chat"] - st.session_state["key"] = chat["key"] - database(db, previous_key, previous_chat) - - diff --git a/src/model.py b/src/model.py deleted file mode 100644 index c005289..0000000 --- a/src/model.py +++ /dev/null @@ -1,23 +0,0 @@ -from langchain.llms import HuggingFaceHub -import sys -import os - - -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) - -from src.auth import * -from src.constant import * - -def create_huggingface_hub(): - """Creates an instance of Hugging Face Hub with specified configurations. - - Returns: - HuggingFaceHub: Instance of Hugging Face Hub. - """ - return HuggingFaceHub( - huggingfacehub_api_token=HUGGINGFACE_API_TOKEN, - repo_id=REPO_ID, - model_kwargs={"temperature": 0.2, "max_new_tokens": 180}, - ) - - diff --git a/test/test_auth.py b/test/test_auth.py new file mode 100644 index 0000000..06ded02 --- /dev/null +++ b/test/test_auth.py @@ -0,0 +1,56 @@ +import unittest +from unittest.mock import AsyncMock, patch +from httpx_oauth.clients.google import GoogleOAuth2 +from src.constant import * +from src.auth import ( + get_authorization_url, + get_access_token, + get_email, + get_login_str, +) + + +class TestGoogleOAuth2Methods(unittest.IsolatedAsyncioTestCase): + async def test_get_authorization_url(self): + client = GoogleOAuth2("client_id", "client_secret") + redirect_uri = "http://example.com/callback" + with patch.object( + client, "get_authorization_url", new=AsyncMock() + ) as mock_method: + await get_authorization_url(client, redirect_uri) + mock_method.assert_called_once_with( + redirect_uri, scope=["profile", "email"] + ) + + async def test_get_access_token(self): + client = GoogleOAuth2("client_id", "client_secret") + redirect_uri = "http://example.com/callback" + code = "code" + with patch.object(client, "get_access_token", new=AsyncMock()) as mock_method: + await get_access_token(client, redirect_uri, code) + mock_method.assert_called_once_with(code, redirect_uri) + + async def test_get_email(self): + client = GoogleOAuth2("client_id", "client_secret") + token = "token" + with patch.object( + client, + "get_id_email", + new=AsyncMock(return_value=("user_id", "user_email")), + ) as mock_method: + user_id, user_email = await get_email(client, token) + mock_method.assert_called_once_with(token) + self.assertEqual(user_id, "user_id") + self.assertEqual(user_email, "user_email") + + def test_get_login_str(self): + with patch("asyncio.run") as mock_run: + mock_run.return_value = "authorization_url" + result = get_login_str() + mock_run.assert_called_once() + self.assertIn('', result) + self.assertIn("Login with Google", result) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/test_auth.py b/tests/test_auth.py deleted file mode 100644 index 1d7b22f..0000000 --- a/tests/test_auth.py +++ /dev/null @@ -1,51 +0,0 @@ -import pytest -from unittest.mock import AsyncMock, patch -from httpx_oauth.clients.google import GoogleOAuth2 -from src.constant import * -from src.auth import ( - get_authorization_url, - get_access_token, - get_email, - get_login_str, -) - -@pytest.mark.asyncio -async def test_get_authorization_url(): - client = GoogleOAuth2("client_id", "client_secret") - redirect_uri = "http://example.com/callback" - with patch.object(client, "get_authorization_url", new=AsyncMock()) as mock_method: - await get_authorization_url(client, redirect_uri) - mock_method.assert_called_once_with( - redirect_uri, scope=["profile", "email"] - ) - -@pytest.mark.asyncio -async def test_get_access_token(): - client = GoogleOAuth2("client_id", "client_secret") - redirect_uri = "http://example.com/callback" - code = "code" - with patch.object(client, "get_access_token", new=AsyncMock()) as mock_method: - await get_access_token(client, redirect_uri, code) - mock_method.assert_called_once_with(code, redirect_uri) - -@pytest.mark.asyncio -async def test_get_email(): - client = GoogleOAuth2("client_id", "client_secret") - token = "token" - with patch.object( - client, - "get_id_email", - new=AsyncMock(return_value=("user_id", "user_email")), - ) as mock_method: - user_id, user_email = await get_email(client, token) - mock_method.assert_called_once_with(token) - assert user_id == "user_id" - assert user_email == "user_email" - -def test_get_login_str(): - with patch("asyncio.run") as mock_run: - mock_run.return_value = "authorization_url" - result = get_login_str() - mock_run.assert_called_once() - assert '' in result - assert "Login with Google" in result diff --git a/tests/test_backend.py b/tests/test_backend.py deleted file mode 100644 index 231b49a..0000000 --- a/tests/test_backend.py +++ /dev/null @@ -1,38 +0,0 @@ -import pytest -from unittest.mock import patch, MagicMock -import sys, os -from src.backend import * -from src.constant import * - -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) - -@pytest.fixture -def mock_open(): - with patch('builtins.open', new_callable=MagicMock) as mock_open: - yield mock_open - -@pytest.fixture -def mock_markdown(): - with patch('streamlit.markdown') as mock_markdown: - yield mock_markdown - -@pytest.fixture -def mock_set_page_config(): - with patch('streamlit.set_page_config') as mock_set_page_config: - yield mock_set_page_config - -@pytest.fixture -def mock_oauth2_component(): - with patch('streamlit_oauth.OAuth2Component') as mock_oauth2_component: - yield mock_oauth2_component - -def test_configure_page_styles(mock_open, mock_markdown, mock_set_page_config): - mock_open.return_value.__enter__.return_value.read.return_value = 'test' - configure_page_styles('test_file') - mock_set_page_config.assert_called_once_with(page_title="Querypls", page_icon="💬", layout="wide") - mock_markdown.assert_called() - mock_open.assert_called_once_with('test_file') - -def test_hide_main_menu_and_footer(mock_markdown): - hide_main_menu_and_footer() - mock_markdown.assert_called_once_with("""""", unsafe_allow_html=True) diff --git a/tests/test_frontend.py b/tests/test_frontend.py deleted file mode 100644 index 99b6dba..0000000 --- a/tests/test_frontend.py +++ /dev/null @@ -1,71 +0,0 @@ -import pytest -from unittest.mock import patch, MagicMock -import streamlit as st -from src.frontend import ( - display_logo_and_heading, - display_welcome_message, - handle_new_chat, - display_previous_chats, - create_message, - update_session_state, -) - -@pytest.fixture -def mock_st(): - return MagicMock() - -@pytest.fixture -def mock_db(): - return MagicMock() - -class MockSessionState: - def __getitem__(self, key): - return getattr(self, key) - - def __setitem__(self, key, value): - setattr(self, key, value) - -def initialize_session_state(messages=None, key=None, user_email=None): - st.session_state = MockSessionState() - st.session_state.messages = messages or [] - st.session_state.key = key - st.session_state.user_email = user_email - -def test_display_logo_and_heading(mock_st): - with patch.object(st, "image") as mock_image: - display_logo_and_heading() - mock_image.assert_called_once_with("static/image/logo.png") - -def test_display_welcome_message(mock_st): - with patch.object(st, "markdown") as mock_markdown: - with patch.object(st, "session_state", MockSessionState()): - initialize_session_state(messages=[{"role": "assistant", "content": "How may I help you?"}]) - display_welcome_message() - mock_markdown.assert_called_once_with("#### Welcome to \n ## 🛢💬Querypls - Prompt to SQL") - -def test_handle_new_chat(mock_db, mock_st): - with patch("src.frontend.get_previous_chats") as mock_get_previous_chats: - mock_get_previous_chats.return_value = [] - with patch.object(st, "markdown") as mock_markdown, patch.object(st, "button") as mock_button: - with patch.object(st, "session_state", MockSessionState()): - initialize_session_state(messages=[], user_email="test@example.com") - handle_new_chat(mock_db, max_chat_histories=5) - mock_markdown.assert_called_once_with(" #### Remaining Chats: `5/5`") - mock_button.assert_called_once_with("➕ New chat") - -def test_create_message(): - with patch.object(st, "session_state", MockSessionState()): - initialize_session_state(messages=[], key=None) - create_message() - assert st.session_state.messages == [{"role": "assistant", "content": "How may I help you?"}] - assert st.session_state.key == "key" - -def test_update_session_state(mock_db): - chat = {"chat": [{"role": "user", "content": "Hello"}], "key": "new_key"} - with patch.object(st, "session_state", MockSessionState()): - initialize_session_state(messages=[{"role": "assistant", "content": "How may I help you?"}], key="old_key") - with patch("src.frontend.database") as mock_database: - update_session_state(mock_db, chat) - mock_database.assert_called_once_with(mock_db, "old_key", [{"role": "assistant", "content": "How may I help you?"}]) - assert st.session_state.messages == chat["chat"] - assert st.session_state.key == chat["key"]