Skip to content

Commit

Permalink
Refactored CodeBase and Fixed Linting issues
Browse files Browse the repository at this point in the history
Signed-off-by: samadpls <[email protected]>
  • Loading branch information
samadpls committed Jul 28, 2024
1 parent fb65a6f commit 9f95b44
Show file tree
Hide file tree
Showing 14 changed files with 326 additions and 101 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ This project is licensed under the MIT License. See the [LICENSE](LICENSE) file


> [!Note]
> Querypls, while powered by a 7B model of Satablility AI LLM Model, is currently limited in providing optimal responses for complex queries involving PLSQL or intricate scenarios with multiple table joins.
> Querypls, while powered by a 7B model of Satablility AI LLM Model, is currently limited in providing optimal responses for simple queries.

---

Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[tool.black]
line-length = 79
125 changes: 116 additions & 9 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,118 @@
streamlit==1.28.0
aiohttp==3.9.5
aiosignal==1.3.1
altair==5.3.0
annotated-types==0.7.0
anyio==3.7.1
async-timeout==4.0.3
asyncio==3.4.3
huggingface_hub==0.15.1
httpx_oauth==0.13.0
streamlit==1.28.0
langchain==0.0.336
attrs==23.2.0
black==24.4.2
blinker==1.8.2
cachetools==5.4.0
certifi==2024.7.4
charset-normalizer==3.3.2
click==8.1.7
dataclasses-json==0.6.7
deta==1.2.0
exceptiongroup==1.2.2
filelock==3.15.4
frozenlist==1.4.1
fsspec==2024.6.1
gitdb==4.0.11
GitPython==3.1.43
greenlet==3.0.3
h11==0.14.0
httpcore==0.17.3
httpx==0.24.1
httpx-oauth==0.13.0
huggingface-hub==0.24.2
idna==3.7
importlib-metadata==6.11.0
iniconfig==2.0.0
Jinja2==3.1.4
joblib==1.4.2
jsonpatch==1.33
jsonpointer==3.0.0
jsonschema==4.23.0
jsonschema-specifications==2023.12.1
langchain==0.2.11
langchain-community==0.2.10
langchain-core==0.2.24
langchain-huggingface==0.0.3
langchain-text-splitters==0.2.2
langsmith==0.1.93
markdown-it-py==3.0.0
MarkupSafe==2.1.5
marshmallow==3.21.3
mdurl==0.1.2
mpmath==1.3.0
multidict==6.0.5
mypy-extensions==1.0.0
networkx==3.3
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.5.82
nvidia-nvtx-cu12==12.1.105
orjson==3.10.6
packaging==23.2
pandas==2.2.2
pathspec==0.12.1
pillow==10.4.0
platformdirs==4.2.2
pluggy==1.5.0
protobuf==4.25.4
pyarrow==17.0.0
pydantic==2.8.2
pydantic_core==2.20.1
pydeck==0.9.1
Pygments==2.18.0
pytest==8.3.2
python-dateutil==2.9.0.post0
python-dotenv==1.0.0
black
streamlit_oauth==0.1.5
deta==1.2.
pytest
pytz==2024.1
PyYAML==6.0.1
referencing==0.35.1
regex==2024.7.24
requests==2.32.3
rich==13.7.1
rpds-py==0.19.1
safetensors==0.4.3
scikit-learn==1.5.1
scipy==1.14.0
sentence-transformers==3.0.1
six==1.16.0
smmap==5.0.1
sniffio==1.3.1
SQLAlchemy==2.0.31
streamlit==1.28.0
streamlit-oauth==0.1.5
sympy==1.13.1
tenacity==8.5.0
threadpoolctl==3.5.0
tokenizers==0.19.1
toml==0.10.2
tomli==2.0.1
toolz==0.12.1
torch==2.4.0
tornado==6.4.1
tqdm==4.66.4
transformers==4.43.3
triton==3.0.0
typing-inspect==0.9.0
typing_extensions==4.12.2
tzdata==2024.1
tzlocal==5.2
urllib3==2.2.2
validators==0.33.0
watchdog==4.0.1
yarl==1.9.4
zipp==3.19.2
80 changes: 66 additions & 14 deletions src/app.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,23 @@
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
import streamlit as st
from deta import Deta
import sys
import os
from backend import configure_page_styles, create_oauth2_component, display_github_badge, handle_google_login_if_needed, hide_main_menu_and_footer
from frontend import create_message, display_logo_and_heading, display_previous_chats, display_welcome_message, handle_new_chat
from backend import (
configure_page_styles,
create_oauth2_component,
display_github_badge,
handle_google_login_if_needed,
hide_main_menu_and_footer,
)
from frontend import (
create_message,
display_logo_and_heading,
display_previous_chats,
display_welcome_message,
handle_new_chat,
)
from model import create_huggingface_hub


Expand All @@ -14,9 +26,10 @@
from src.auth import *
from src.constant import *


def main():
"""Main function to configure and run the Querypls application."""
configure_page_styles('static/css/styles.css')
configure_page_styles("static/css/styles.css")
deta = Deta(DETA_PROJECT_KEY)
if "model" not in st.session_state:
llm = create_huggingface_hub()
Expand Down Expand Up @@ -47,7 +60,25 @@ def main():
"Connect with Google",
REDIRECT_URI,
SCOPE,
icon="data:image/svg+xml;charset=utf-8,%3Csvg xmlns='http://www.w3.org/2000/svg' xmlns:xlink='http://www.w3.org/1999/xlink' viewBox='0 0 48 48'%3E%3Cdefs%3E%3Cpath id='a' d='M44.5 20H24v8.5h11.8C34.7 33.9 30.1 37 24 37c-7.2 0-13-5.8-13-13s5.8-13 13-13c3.1 0 5.9 1.1 8.1 2.9l6.4-6.4C34.6 4.1 29.6 2 24 2 11.8 2 2 11.8 2 24s9.8 22 22 22c11 0 21-8 21-22 0-1.3-.2-2.7-.5-4z'/%3E%3C/defs%3E%3CclipPath id='b'%3E%3Cuse xlink:href='%23a' overflow='visible'/%3E%3C/clipPath%3E%3Cpath clip-path='url(%23b)' fill='%23FBBC05' d='M0 37V11l17 13z'/%3E%3Cpath clip-path='url(%23b)' fill='%23EA4335' d='M0 11l17 13 7-6.1L48 14V0H0z'/%3E%3Cpath clip-path='url(%23b)' fill='%2334A853' d='M0 37l30-23 7.9 1L48 0v48H0z'/%3E%3Cpath clip-path='url(%23b)' fill='%234285F4' d='M48 48L17 24l-4-3 35-10z'/%3E%3C/svg%3E",
icon="data:image/svg+xml;charset=utf-8,%3Csvg \
xmlns='http://www.w3.org/2000/svg' \
xmlns:xlink='http://www.w3.org/1999/xlink' \
viewBox='0 0 48 48'%3E%3Cdefs%3E%3Cpath id='a' \
d='M44.5 20H24v8.5h11.8C34.7 33.9 30.1 37 24 37c-7.2 \
0-13-5.8-13-13s5.8-13 13-13c3.1 0 5.9 1.1 8.1 \
2.9l6.4-6.4C34.6 4.1 29.6 2 24 2 11.8 2 2 11.8 2 \
24s9.8 22 22 22c11 0 21-8 21-22 \
0-1.3-.2-2.7-.5-4z'/%3E%3C/defs%3E%3CclipPath \
id='b'%3E%3Cuse xlink:href='%23a' \
overflow='visible'/%3E%3C/clipPath%3E%3Cpath \
clip-path='url(%23b)' fill='%23FBBC05' \
d='M0 37V11l17 13z'/%3E%3Cpath clip-path='url(%23b)' \
fill='%23EA4335' d='M0 11l17 13 7-6.1L48 \
14V0H0z'/%3E%3Cpath clip-path='url(%23b)' \
fill='%2334A853' d='M0 37l30-23 7.9 1L48 \
0v48H0z'/%3E%3Cpath clip-path='url(%23b)' \
fill='%234285F4' d='M48 48L17 24l-4-3 \
35-10z'/%3E%3C/svg%3E",
use_container_width=True,
)
handle_google_login_if_needed(result)
Expand All @@ -74,21 +105,42 @@ def main():
st.markdown(message["content"], unsafe_allow_html=True)

if prompt := st.chat_input(disabled=(st.session_state.code is False)):
st.session_state.messages.append({"role": "user", "content": prompt})
st.session_state.messages.append(
{"role": "user", "content": prompt}
)
with st.chat_message("user"):
st.write(prompt)

prompting = PromptTemplate(template=TEMPLATE, input_variables=["question"])
if "model" in st.session_state:
llm_chain = LLMChain(prompt=prompting, llm=st.session_state.model)
prompt_template = PromptTemplate(
template=TEMPLATE, input_variables=["question"]
)

if "model" in st.session_state:
llm_chain = (
prompt_template
| st.session_state.model
| StrOutputParser()
)
if st.session_state.messages[-1]["role"] != "assistant":
with st.chat_message("assistant"):
with st.spinner("Generating..."):
response = llm_chain.run(prompt)
st.markdown(response, unsafe_allow_html=True)
message = {"role": "assistant", "content": response}
st.session_state.messages.append(message)
response = llm_chain.invoke(prompt)
import re

code_block_match = re.search(
r"```sql(.*?)```", response, re.DOTALL
)
if code_block_match:
code_block = code_block_match.group(1)
st.markdown(
f"```sql\n{code_block}\n```",
unsafe_allow_html=True,
)
message = {
"role": "assistant",
"content": f"```sql\n{code_block}\n```",
}
st.session_state.messages.append(message)


if __name__ == "__main__":
Expand Down
8 changes: 6 additions & 2 deletions src/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@ async def get_email(client: GoogleOAuth2, token: str):
user_id, user_email = await client.get_id_email(token)
return user_id, user_email


def get_login_str():
client: GoogleOAuth2 = GoogleOAuth2(CLIENT_ID, CLIENT_SECRET)
authorization_url = asyncio.run(get_authorization_url(client, REDIRECT_URI))
return f"""<a href="{authorization_url}" target="_self"><button class="button-51" role="button">Login with Google</button></a>"""
authorization_url = asyncio.run(
get_authorization_url(client, REDIRECT_URI)
)
return f"""<a href="{authorization_url}" target="_self">\
<button class="button-51" role="button">Login with Google</button></a>"""
39 changes: 26 additions & 13 deletions src/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from src.auth import *
from src.constant import *


def configure_page_styles(file_name):
"""Configures Streamlit page styles for Querypls.
Expand All @@ -19,22 +20,30 @@ def configure_page_styles(file_name):
Note:
Ensure 'static/css/styles.css' exists with desired styles.
"""
st.set_page_config(page_title="Querypls", page_icon="💬",layout="wide",)
st.set_page_config(
page_title="Querypls",
page_icon="💬",
layout="wide",
)
with open(file_name) as f:
st.markdown('<style>{}</style>'.format(f.read()), unsafe_allow_html=True)
st.markdown(
"<style>{}</style>".format(f.read()), unsafe_allow_html=True
)

hide_streamlit_style = (
"""<style>#MainMenu {visibility: hidden;}footer {visibility: hidden;}</style>"""
)
hide_streamlit_style = """<style>#MainMenu {visibility: hidden;}\
footer {visibility: hidden;}</style>"""
st.markdown(hide_streamlit_style, unsafe_allow_html=True)


def hide_main_menu_and_footer():
"""Hides the Streamlit main menu and footer for a cleaner interface."""
st.markdown(
"""<style>#MainMenu {visibility: hidden;}footer {visibility: hidden;}</style>""",
"""<style>#MainMenu {visibility: hidden;}\
footer {visibility: hidden;}</style>""",
unsafe_allow_html=True,
)



def handle_google_login_if_needed(result):
"""Handles Google login if it has not been run yet.
Expand All @@ -56,19 +65,24 @@ def handle_google_login_if_needed(result):
st.session_state.user_email = email
st.session_state.code = True
return
except:
except Exception:
st.warning(
"Seems like there is a network issue. Please check your internet connection."
"Seems like there is a network issue. \
Please check your internet connection."
)
sys.exit()



def display_github_badge():
"""Displays a GitHub badge with a link to the Querypls repository."""
st.markdown(
"""<a href='https://github.com/samadpls/Querypls'><img src='https://img.shields.io/github/stars/samadpls/querypls?color=red&label=star%20me&logoColor=red&style=social'></a>""",
"""<a href='https://github.com/samadpls/Querypls'>\
<img src='https://img.shields.io/github/stars/samadpls/querypls\
?color=red&label=star%20me&logoColor=red&style=social'\
></a>""",
unsafe_allow_html=True,
)


def create_oauth2_component():
return OAuth2Component(
Expand All @@ -79,4 +93,3 @@ def create_oauth2_component():
REFRESH_TOKEN_URL,
REVOKE_TOKEN_URL,
)

5 changes: 4 additions & 1 deletion src/database.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import streamlit as st


def get_previous_chats(db, user_email):
"""Fetches previous chat records for a user from the database.
Expand Down Expand Up @@ -36,7 +37,9 @@ def database(db, previous_key="key", previous_chat=None, max_chat_histories=5):
and previous_key != "key"
):
new_messages = [
message for message in previous_chat if message not in existing_chat["chat"]
message
for message in previous_chat
if message not in existing_chat["chat"]
]
existing_chat["chat"].extend(new_messages)
db.update({"chat": existing_chat["chat"]}, key=previous_key)
Expand Down
Loading

0 comments on commit 9f95b44

Please sign in to comment.