# Definitions from backend/app/api/deps.py (file added by this patch).

# Extracts the bearer token from the request; tokenUrl points at the login
# endpoint under the versioned API prefix.
reusable_oauth2 = OAuth2PasswordBearer(
    tokenUrl=f"{settings.API_V1_STR}/login/access-token"
)


def get_db() -> Generator[Session, None, None]:
    """Yield a database session and close it once the dependant has finished.

    NOTE(review): ``engine`` is imported from ``app.core.db``, and this patch
    shows that module importing ``create_async_engine``. Confirm the engine
    bound here is a synchronous one before using it with ``Session``.
    """
    with Session(engine) as session:
        yield session


SessionDep = Annotated[Session, Depends(get_db)]
TokenDep = Annotated[str, Depends(reusable_oauth2)]


def get_current_user(session: SessionDep, token: TokenDep) -> User:
    """Resolve the bearer token to an active ``User``.

    Args:
        session: Database session used to look the user up.
        token: Encoded JWT taken from the request.

    Returns:
        The user whose primary key equals the token's ``sub`` claim.

    Raises:
        HTTPException: 403 when the token cannot be decoded or its payload
            does not validate, 404 when no user matches, 400 when the user
            is inactive.
    """
    try:
        payload = jwt.decode(
            token, settings.SECRET_KEY, algorithms=[security.ALGORITHM]
        )
        token_data = TokenPayload(**payload)
    except (JWTError, ValidationError) as exc:
        # Chain the cause so the original decoding error stays in tracebacks.
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Could not validate credentials",
        ) from exc
    user = session.get(User, token_data.sub)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
        )
    if not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail="Inactive user"
        )
    return user


CurrentUser = Annotated[User, Depends(get_current_user)]


def get_current_active_superuser(current_user: CurrentUser) -> User:
    """Return *current_user* only if it is a superuser.

    Raises:
        HTTPException: 400 when the user lacks the superuser flag.
    """
    if not current_user.is_superuser:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="The user doesn't have enough privileges",
        )
    return current_user
-from app.core.config import logger +from backend.app.core.config import logger from operator import itemgetter @@ -19,9 +19,9 @@ from langchain_core.runnables import RunnableParallel from langchain_community.vectorstores.pgvector import PGVector from langchain.memory import ConversationBufferMemory -from app.core.config import settings +from backend.app.core.config import settings from langchain.prompts.prompt import PromptTemplate -from pydantic import BaseModel +from app.schemas.chat_schema import ChatBody router = APIRouter() @@ -32,12 +32,8 @@ chat_config = config.get("CHAT_CONFIG", None) -class ChatRequest(BaseModel): - message: str - - @router.post("/chat") -async def chat_action(request: ChatRequest): +async def chat_action(request: ChatBody): embeddings = OpenAIEmbeddings() diff --git a/backend/app/config.yml b/backend/app/config/chat.yml similarity index 100% rename from backend/app/config.yml rename to backend/app/config/chat.yml diff --git a/backend/app/core/config.py b/backend/app/core/config.py index 26a7f09..f25bc95 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -1,6 +1,21 @@ from pydantic_settings import BaseSettings from typing import List from loguru import logger +from typing import Annotated, Any, Literal +import sys + +from pydantic import ( + AnyUrl, + BeforeValidator, +) + + +def parse_cors(v: Any) -> list[str] | str: + if isinstance(v, str) and not v.startswith("["): + return [i.strip() for i in v.split(",")] + elif isinstance(v, list | str): + return v + raise ValueError(v) class Settings(BaseSettings): @@ -23,6 +38,13 @@ class Settings(BaseSettings): TAVILY_API_KEY: str + FIRST_SUPERUSER: str + FIRST_SUPERUSER_PASSWORD: str + + BACKEND_CORS_ORIGINS: Annotated[list[AnyUrl] | str, BeforeValidator(parse_cors)] = ( + [] + ) + @property def ASYNC_DATABASE_URI(self) -> str: return f"postgresql+asyncpg://{self.DB_USER}:{self.DB_PASS}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}" @@ -35,4 +57,19 @@ class Config: 
# --- Definitions from backend/app/core/config.py (added by this patch) ---
class LogConfig:
    """Loguru settings and the helper that installs them."""

    LOGGING_LEVEL = "DEBUG"
    LOGGING_FORMAT = "{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}"

    @staticmethod
    def configure_logging():
        """Remove the currently registered handlers and log to stderr only."""
        logger.remove()

        logger.add(
            sys.stderr, format=LogConfig.LOGGING_FORMAT, level=LogConfig.LOGGING_LEVEL
        )


# Runs at import time, so importing this module reconfigures loguru.
LogConfig.configure_logging()

settings = Settings()


# --- Definitions from backend/app/core/security.py (file added by this patch) ---
# NOTE(review): the "bcrypt" scheme needs the bcrypt package at runtime; the
# pyproject hunk in this patch adds plain `passlib` without the `bcrypt`
# extra. Confirm bcrypt is installed some other way.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")


ALGORITHM = "HS256"


def create_access_token(subject: str | Any, expires_delta: timedelta) -> str:
    """Return a signed JWT for *subject* that expires after *expires_delta*.

    Args:
        subject: Value stored, stringified, in the ``sub`` claim.
        expires_delta: Lifetime of the token, counted from now.

    Returns:
        The token encoded with ``settings.SECRET_KEY`` and ``ALGORITHM``.
    """
    # Function-scope import: the module header imports only datetime/timedelta.
    from datetime import timezone

    # Aware UTC timestamp; datetime.utcnow() is deprecated and returns a naive
    # value that is easy to misread as local time.
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode = {"exp": expire, "sub": str(subject)}
    encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM)
    return encoded_jwt


def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return whether *plain_password* matches *hashed_password*."""
    return pwd_context.verify(plain_password, hashed_password)


def get_password_hash(password: str) -> str:
    """Return the hash of *password* produced by ``pwd_context``."""
    return pwd_context.hash(password)
# Definitions from backend/app/crud/crud.py (file added by this patch).
def create_user(*, session: Session, user_create: UserCreate) -> User:
    """Persist a new user built from *user_create* and return the stored row.

    The plain password is replaced by its hash before the row is written.
    """
    hashed = get_password_hash(user_create.password)
    new_user = User.model_validate(user_create, update={"hashed_password": hashed})
    session.add(new_user)
    session.commit()
    session.refresh(new_user)
    return new_user


def update_user(*, session: Session, db_user: User, user_in: UserUpdate) -> Any:
    """Apply the explicitly set fields of *user_in* to *db_user* and save it.

    When a new password is among those fields, its hash is stored as
    ``hashed_password``.
    """
    changes = user_in.model_dump(exclude_unset=True)
    overrides = {}
    if "password" in changes:
        overrides["hashed_password"] = get_password_hash(changes["password"])
    db_user.sqlmodel_update(changes, update=overrides)
    session.add(db_user)
    session.commit()
    session.refresh(db_user)
    return db_user


def get_user_by_email(*, session: Session, email: str) -> User | None:
    """Return the first user whose e-mail equals *email*, or ``None``."""
    query = select(User).where(User.email == email)
    return session.exec(query).first()


def authenticate(*, session: Session, email: str, password: str) -> User | None:
    """Return the user for *email* if *password* verifies, otherwise ``None``."""
    candidate = get_user_by_email(session=session, email=email)
    if candidate and verify_password(password, candidate.hashed_password):
        return candidate
    return None


def create_item(*, session: Session, item_in: ItemCreate, owner_id: int) -> Item:
    """Persist a new item owned by *owner_id* and return the stored row."""
    new_item = Item.model_validate(item_in, update={"owner_id": owner_id})
    session.add(new_item)
    session.commit()
    session.refresh(new_item)
    return new_item
# Top-level statements from backend/app/main.py (hunk changed by this patch).
# Register CORS only for the explicitly configured origins. With the default
# (empty) BACKEND_CORS_ORIGINS no CORS middleware is installed at all, whereas
# the code removed by this patch allowed every origin ("*").
if settings.BACKEND_CORS_ORIGINS:
    app.add_middleware(
        CORSMiddleware,
        allow_origins=[
            # Drop only the trailing slash that URL normalisation may append;
            # strip("/") would also eat leading slashes.
            str(origin).rstrip("/")
            for origin in settings.BACKEND_CORS_ORIGINS
        ],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
# Definitions from backend/ingestion/core/db.py (hunks changed by this patch).

# Async engine for the asynchronous initialisation path; the commented-out
# copy of init_db in this patch shows it entering `async with engine.begin()`.
engine = create_async_engine(str(settings.ASYNC_DATABASE_URI), echo=True)

# Synchronous engine kept under its own name. The patch rebound `engine` here,
# which replaced the async engine above for every later user of this module.
# NOTE(review): `settings.SYNC_DATABASE_URI` is not visible in this patch;
# confirm the settings class defines it.
sync_engine = create_engine(str(settings.SYNC_DATABASE_URI))


def create_super_user() -> None:
    """Create the first superuser unless a user with that e-mail already exists.

    Reads ``settings.FIRST_SUPERUSER`` (e-mail) and
    ``settings.FIRST_SUPERUSER_PASSWORD`` and delegates the insert to
    ``user_crud.create_user``.
    """
    with Session(sync_engine) as session:
        existing = session.exec(
            select(User).where(User.email == settings.FIRST_SUPERUSER)
        ).first()
        if existing:
            return
        user_in = UserCreate(
            email=settings.FIRST_SUPERUSER,
            password=settings.FIRST_SUPERUSER_PASSWORD,
            is_superuser=True,
        )
        user_crud.create_user(session=session, user_create=user_in)
# Definition from backend/ingestion/crud/user_crud.py (file added by this patch).
def create_user(*, session: Session, user_create: UserCreate) -> User:
    """Store a new user built from *user_create* and return the refreshed row.

    The plain password from *user_create* is hashed and saved as
    ``hashed_password``.
    """
    password_hash = get_password_hash(user_create.password)
    record = User.model_validate(
        user_create, update={"hashed_password": password_hash}
    )
    session.add(record)
    session.commit()
    session.refresh(record)
    return record
# Definitions from backend/ingestion/models/user_model.py (file added by this patch).
# Documentation is kept in `#` comments rather than class docstrings so that
# nothing visible to the models themselves changes.


# Fields shared by every user schema below.
# TODO replace email str with EmailStr when sqlmodel supports it
class UserBase(SQLModel):
    # Declared unique and indexed; applies to the table model derived from this base.
    email: str = Field(unique=True, index=True)
    is_active: bool = True
    is_superuser: bool = False
    full_name: str | None = None


# Properties to receive via API on creation: the shared fields plus a plain password.
class UserCreate(UserBase):
    password: str


# Creation payload without the is_active / is_superuser flags.
# TODO replace email str with EmailStr when sqlmodel supports it
class UserCreateOpen(SQLModel):
    email: str
    password: str
    full_name: str | None = None


# Properties to receive via API on update; email and password become optional.
# TODO replace email str with EmailStr when sqlmodel supports it
class UserUpdate(UserBase):
    email: str | None = None  # type: ignore
    password: str | None = None


# Update payload limited to full_name and email.
# TODO replace email str with EmailStr when sqlmodel supports it
class UserUpdateMe(SQLModel):
    full_name: str | None = None
    email: str | None = None


# Password-change payload carrying both the current and the new password.
class UpdatePassword(SQLModel):
    current_password: str
    new_password: str


# Database model, database table inferred from class name
class User(UserBase, table=True):
    # None until a primary key has been assigned.
    id: int | None = Field(default=None, primary_key=True)
    hashed_password: str
    # Other side of Item.owner.
    items: list["Item"] = Relationship(back_populates="owner")


# Properties to return via API, id is always required
class UserOut(UserBase):
    id: int


# A list of users together with a count.
class UsersOut(SQLModel):
    data: list[UserOut]
    count: int


# Fields shared by every item schema below.
class ItemBase(SQLModel):
    title: str
    description: str | None = None


# Properties to receive on item creation
class ItemCreate(ItemBase):
    title: str


# Properties to receive on item update; title becomes optional.
class ItemUpdate(ItemBase):
    title: str | None = None  # type: ignore


# Database model, database table inferred from class name
class Item(ItemBase, table=True):
    id: int | None = Field(default=None, primary_key=True)
    title: str
    # Foreign key to user.id; declared non-nullable although typed as optional.
    owner_id: int | None = Field(default=None, foreign_key="user.id", nullable=False)
    # Other side of User.items.
    owner: User | None = Relationship(back_populates="items")


# Properties to return via API, id is always required
class ItemOut(ItemBase):
    id: int
    owner_id: int


# A list of items together with a count.
class ItemsOut(SQLModel):
    data: list[ItemOut]
    count: int


# Generic message
class Message(SQLModel):
    message: str


# JSON payload containing access token
class Token(SQLModel):
    access_token: str
    token_type: str = "bearer"


# Contents of JWT token
class TokenPayload(SQLModel):
    sub: int | None = None


# Payload carrying a token together with a new password.
class NewPassword(SQLModel):
    token: str
    new_password: str
b/backend/poetry.lock index 9162c29..c6c10b3 100644 --- a/backend/poetry.lock +++ b/backend/poetry.lock @@ -399,6 +399,24 @@ files = [ {file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"}, ] +[[package]] +name = "ecdsa" +version = "0.18.0" +description = "ECDSA cryptographic signature library (pure python)" +optional = false +python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" +files = [ + {file = "ecdsa-0.18.0-py2.py3-none-any.whl", hash = "sha256:80600258e7ed2f16b9aa1d7c295bd70194109ad5a30fdee0eaeefef1d4c559dd"}, + {file = "ecdsa-0.18.0.tar.gz", hash = "sha256:190348041559e21b22a1d65cee485282ca11a6f81d503fddb84d5017e9ed1e49"}, +] + +[package.dependencies] +six = ">=1.9.0" + +[package.extras] +gmpy = ["gmpy"] +gmpy2 = ["gmpy2"] + [[package]] name = "fastapi" version = "0.110.0" @@ -1106,6 +1124,23 @@ files = [ {file = "packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5"}, ] +[[package]] +name = "passlib" +version = "1.7.4" +description = "comprehensive password hashing framework supporting over 30 schemes" +optional = false +python-versions = "*" +files = [ + {file = "passlib-1.7.4-py2.py3-none-any.whl", hash = "sha256:aa6bca462b8d8bda89c70b382f0c298a20b5560af6cbfa2dce410c0a2fb669f1"}, + {file = "passlib-1.7.4.tar.gz", hash = "sha256:defd50f72b65c5402ab2c573830a6978e5f202ad0d984793c8dde2c4152ebe04"}, +] + +[package.extras] +argon2 = ["argon2-cffi (>=18.2.0)"] +bcrypt = ["bcrypt (>=3.1.0)"] +build-docs = ["cloud-sptheme (>=1.10.1)", "sphinx (>=1.6)", "sphinxcontrib-fulltoc (>=1.2.0)"] +totp = ["cryptography"] + [[package]] name = "pgvector" version = "0.2.5" @@ -1200,6 +1235,17 @@ files = [ {file = "psycopg2_binary-2.9.9-cp39-cp39-win_amd64.whl", hash = "sha256:f7ae5d65ccfbebdfa761585228eb4d0df3a8b15cfb53bd953e713e09fbb12957"}, ] +[[package]] +name = "pyasn1" +version = "0.6.0" +description = "Pure-Python implementation of ASN.1 types and DER/BER/CER 
codecs (X.208)" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pyasn1-0.6.0-py2.py3-none-any.whl", hash = "sha256:cca4bb0f2df5504f02f6f8a775b6e416ff9b0b3b16f7ee80b5a3153d9b804473"}, + {file = "pyasn1-0.6.0.tar.gz", hash = "sha256:3a35ab2c4b5ef98e17dfdec8ab074046fbda76e281c5a706ccd82328cfc8f64c"}, +] + [[package]] name = "pydantic" version = "2.6.4" @@ -1433,6 +1479,27 @@ files = [ [package.extras] cli = ["click (>=5.0)"] +[[package]] +name = "python-jose" +version = "3.3.0" +description = "JOSE implementation in Python" +optional = false +python-versions = "*" +files = [ + {file = "python-jose-3.3.0.tar.gz", hash = "sha256:55779b5e6ad599c6336191246e95eb2293a9ddebd555f796a65f838f07e5d78a"}, + {file = "python_jose-3.3.0-py2.py3-none-any.whl", hash = "sha256:9b1376b023f8b298536eedd47ae1089bcdb848f1535ab30555cd92002d78923a"}, +] + +[package.dependencies] +ecdsa = "!=0.15" +pyasn1 = "*" +rsa = "*" + +[package.extras] +cryptography = ["cryptography (>=3.4.0)"] +pycrypto = ["pyasn1", "pycrypto (>=2.6.0,<2.7.0)"] +pycryptodome = ["pyasn1", "pycryptodome (>=3.3.1,<4.0.0)"] + [[package]] name = "pyyaml" version = "6.0.1" @@ -1634,6 +1701,31 @@ urllib3 = ">=1.21.1,<3" socks = ["PySocks (>=1.5.6,!=1.5.7)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] +[[package]] +name = "rsa" +version = "4.9" +description = "Pure-Python RSA implementation" +optional = false +python-versions = ">=3.6,<4" +files = [ + {file = "rsa-4.9-py3-none-any.whl", hash = "sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7"}, + {file = "rsa-4.9.tar.gz", hash = "sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21"}, +] + +[package.dependencies] +pyasn1 = ">=0.1.3" + +[[package]] +name = "six" +version = "1.16.0" +description = "Python 2 and 3 compatibility utilities" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" +files = [ + {file = "six-1.16.0-py2.py3-none-any.whl", hash = 
"sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, + {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, +] + [[package]] name = "sniffio" version = "1.3.1" @@ -2045,4 +2137,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "27db29dbdef7ff2ecf4c84787bceb8fd404e3029e2bef465b51630232e13fe12" +content-hash = "7a1758f89b9da1c7102e3f6510aae556a6b45292e5b526eeea433d327ef27fdd" diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 94ba795..df1e85f 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -24,6 +24,8 @@ case-converter = "^1.1.0" python-box = "^7.1.1" redis = "^5.0.3" langchainhub = "^0.1.15" +python-jose = "^3.3.0" +passlib = "^1.7.4" [build-system]