Skip to content

Commit

Permalink
ADD: super user ingestion
Browse files Browse the repository at this point in the history
  • Loading branch information
mazzasaverio committed Apr 1, 2024
1 parent 3828f24 commit ca5d82a
Show file tree
Hide file tree
Showing 18 changed files with 496 additions and 29 deletions.
56 changes: 56 additions & 0 deletions backend/app/api/deps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from collections.abc import Generator
from typing import Annotated

from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from pydantic import ValidationError
from sqlmodel import Session

from app.core import security
from backend.app.core.config import settings
from app.core.db import engine
from app.models import TokenPayload, User

reusable_oauth2 = OAuth2PasswordBearer(
tokenUrl=f"{settings.API_V1_STR}/login/access-token"
)


def get_db() -> Generator[Session, None, None]:
with Session(engine) as session:
yield session


SessionDep = Annotated[Session, Depends(get_db)]
TokenDep = Annotated[str, Depends(reusable_oauth2)]


def get_current_user(session: SessionDep, token: TokenDep) -> User:
try:
payload = jwt.decode(
token, settings.SECRET_KEY, algorithms=[security.ALGORITHM]
)
token_data = TokenPayload(**payload)
except (JWTError, ValidationError):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Could not validate credentials",
)
user = session.get(User, token_data.sub)
if not user:
raise HTTPException(status_code=404, detail="User not found")
if not user.is_active:
raise HTTPException(status_code=400, detail="Inactive user")
return user


CurrentUser = Annotated[User, Depends(get_current_user)]


def get_current_active_superuser(current_user: CurrentUser) -> User:
if not current_user.is_superuser:
raise HTTPException(
status_code=400, detail="The user doesn't have enough privileges"
)
return current_user
12 changes: 4 additions & 8 deletions backend/app/api/routes/qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from fastapi import APIRouter


from app.core.config import logger
from backend.app.core.config import logger


from operator import itemgetter
Expand All @@ -19,9 +19,9 @@
from langchain_core.runnables import RunnableParallel
from langchain_community.vectorstores.pgvector import PGVector
from langchain.memory import ConversationBufferMemory
from app.core.config import settings
from backend.app.core.config import settings
from langchain.prompts.prompt import PromptTemplate
from pydantic import BaseModel
from app.schemas.chat_schema import ChatBody

router = APIRouter()

Expand All @@ -32,12 +32,8 @@
chat_config = config.get("CHAT_CONFIG", None)


class ChatRequest(BaseModel):
message: str


@router.post("/chat")
async def chat_action(request: ChatRequest):
async def chat_action(request: ChatBody):

embeddings = OpenAIEmbeddings()

Expand Down
File renamed without changes.
37 changes: 37 additions & 0 deletions backend/app/core/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,21 @@
from pydantic_settings import BaseSettings
from typing import List
from loguru import logger
from typing import Annotated, Any, Literal
import sys

from pydantic import (
AnyUrl,
BeforeValidator,
)


def parse_cors(v: Any) -> list[str] | str:
if isinstance(v, str) and not v.startswith("["):
return [i.strip() for i in v.split(",")]
elif isinstance(v, list | str):
return v
raise ValueError(v)


class Settings(BaseSettings):
Expand All @@ -23,6 +38,13 @@ class Settings(BaseSettings):

TAVILY_API_KEY: str

FIRST_SUPERUSER: str
FIRST_SUPERUSER_PASSWORD: str

BACKEND_CORS_ORIGINS: Annotated[list[AnyUrl] | str, BeforeValidator(parse_cors)] = (
[]
)

@property
def ASYNC_DATABASE_URI(self) -> str:
return f"postgresql+asyncpg://{self.DB_USER}:{self.DB_PASS}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}"
Expand All @@ -35,4 +57,19 @@ class Config:
env_file = "../.env"


class LogConfig:
LOGGING_LEVEL = "DEBUG"
LOGGING_FORMAT = "<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level}</level> | <level>{message}</level>"

@staticmethod
def configure_logging():
logger.remove()

logger.add(
sys.stderr, format=LogConfig.LOGGING_FORMAT, level=LogConfig.LOGGING_LEVEL
)


LogConfig.configure_logging()

settings = Settings()
2 changes: 1 addition & 1 deletion backend/app/core/db.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from sqlmodel import SQLModel
from sqlalchemy.ext.asyncio import create_async_engine
from app.core.config import settings
from backend.app.core.config import settings
from loguru import logger
from sqlalchemy.orm import sessionmaker
from sqlmodel.ext.asyncio.session import AsyncSession
Expand Down
27 changes: 27 additions & 0 deletions backend/app/core/security.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from datetime import datetime, timedelta
from typing import Any

from jose import jwt
from passlib.context import CryptContext

from backend.app.core.config import settings

pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")


ALGORITHM = "HS256"


def create_access_token(subject: str | Any, expires_delta: timedelta) -> str:
expire = datetime.utcnow() + expires_delta
to_encode = {"exp": expire, "sub": str(subject)}
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt


def verify_password(plain_password: str, hashed_password: str) -> bool:
return pwd_context.verify(plain_password, hashed_password)


def get_password_hash(password: str) -> str:
return pwd_context.hash(password)
53 changes: 53 additions & 0 deletions backend/app/crud/crud.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from typing import Any

from sqlmodel import Session, select

from app.core.security import get_password_hash, verify_password
from app.models import Item, ItemCreate, User, UserCreate, UserUpdate


def create_user(*, session: Session, user_create: UserCreate) -> User:
db_obj = User.model_validate(
user_create, update={"hashed_password": get_password_hash(user_create.password)}
)
session.add(db_obj)
session.commit()
session.refresh(db_obj)
return db_obj


def update_user(*, session: Session, db_user: User, user_in: UserUpdate) -> Any:
user_data = user_in.model_dump(exclude_unset=True)
extra_data = {}
if "password" in user_data:
password = user_data["password"]
hashed_password = get_password_hash(password)
extra_data["hashed_password"] = hashed_password
db_user.sqlmodel_update(user_data, update=extra_data)
session.add(db_user)
session.commit()
session.refresh(db_user)
return db_user


def get_user_by_email(*, session: Session, email: str) -> User | None:
statement = select(User).where(User.email == email)
session_user = session.exec(statement).first()
return session_user


def authenticate(*, session: Session, email: str, password: str) -> User | None:
db_user = get_user_by_email(session=session, email=email)
if not db_user:
return None
if not verify_password(password, db_user.hashed_password):
return None
return db_user


def create_item(*, session: Session, item_in: ItemCreate, owner_id: int) -> Item:
db_item = Item.model_validate(item_in, update={"owner_id": owner_id})
session.add(db_item)
session.commit()
session.refresh(db_item)
return db_item
21 changes: 12 additions & 9 deletions backend/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from contextlib import asynccontextmanager
from fastapi.middleware.cors import CORSMiddleware
from app.api.main import api_router
from app.core.config import settings
from backend.app.core.config import settings

from typing import Dict

Expand All @@ -11,14 +11,17 @@
docs_url=f"{settings.API_V1_STR}/docs",
)


app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Set all CORS enabled origins
if settings.BACKEND_CORS_ORIGINS:
app.add_middleware(
CORSMiddleware,
allow_origins=[
str(origin).strip("/") for origin in settings.BACKEND_CORS_ORIGINS
],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)


@app.get("/metrics")
Expand Down
Empty file added backend/ingestion/__init__.py
Empty file.
File renamed without changes.
5 changes: 4 additions & 1 deletion backend/ingestion/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@ class Settings(BaseSettings):

OPENAI_API_KEY: str

FIRST_SUPERUSER: str
FIRST_SUPERUSER_PASSWORD: str

@property
def ASYNC_DATABASE_URI(self) -> str:
return f"postgresql+asyncpg://{self.DB_USER}:{self.DB_PASS}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}"

class Config:
env_file = "../.env"
env_file = ".env"


class LogConfig:
Expand Down
61 changes: 60 additions & 1 deletion backend/ingestion/helpers/db.py → backend/ingestion/core/db.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,22 @@
from sqlmodel import SQLModel
from sqlalchemy.ext.asyncio import create_async_engine
from app.core.config import settings
from backend.app.core.config import settings
import asyncpg
import psycopg2
from loguru import logger


from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select

from ingestion.models.user_model import (
UserCreate,
User,
)

from backend.ingestion.crud import user_crud


engine = create_async_engine(str(settings.ASYNC_DATABASE_URI), echo=True)


Expand Down Expand Up @@ -51,6 +63,24 @@ def create_database(database_name, user, password, host, port):
logger.error(f"Error creating database: {e}")


# async def init_db() -> None:
# create_database(
# settings.DB_NAME,
# settings.DB_USER,
# settings.DB_PASS,
# settings.DB_HOST,
# settings.DB_PORT,
# )
# async with engine.begin() as conn:
# # Use run_sync to execute the create_all method in an asynchronous context
# await conn.run_sync(SQLModel.metadata.create_all)

# # Your existing database initialization logic here
# # For example, creating extensions or setting up initial data
# await create_extension()
# logger.info("Database initialized and all tables created if they didn't exist.")


async def init_db() -> None:
create_database(
settings.DB_NAME,
Expand All @@ -67,3 +97,32 @@ async def init_db() -> None:
# For example, creating extensions or setting up initial data
await create_extension()
logger.info("Database initialized and all tables created if they didn't exist.")


from sqlmodel import Session, create_engine, select

from ingestion.crud import user_crud
from sqlmodel import Session

engine = create_engine(str(settings.SYNC_DATABASE_URI))


# make sure all SQLModel models are imported (app.models) before initializing DB
# otherwise, SQLModel might fail to initialize relationships properly
# for more details: https://github.com/tiangolo/full-stack-fastapi-template/issues/28


def create_super_user() -> None:

with Session(engine) as session:

user = session.exec(
select(User).where(User.email == settings.FIRST_SUPERUSER)
).first()
if not user:
user_in = UserCreate(
email=settings.FIRST_SUPERUSER,
password=settings.FIRST_SUPERUSER_PASSWORD,
is_superuser=True,
)
user = user_crud.create_user(session=session, user_create=user_in)
16 changes: 16 additions & 0 deletions backend/ingestion/crud/user_crud.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from typing import Any

from sqlmodel import Session, select

from app.core.security import get_password_hash
from ingestion.models.user_model import User, UserCreate


def create_user(*, session: Session, user_create: UserCreate) -> User:
db_obj = User.model_validate(
user_create, update={"hashed_password": get_password_hash(user_create.password)}
)
session.add(db_obj)
session.commit()
session.refresh(db_obj)
return db_obj
Loading

0 comments on commit ca5d82a

Please sign in to comment.