-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
3828f24
commit ca5d82a
Showing
18 changed files
with
496 additions
and
29 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
from collections.abc import Generator | ||
from typing import Annotated | ||
|
||
from fastapi import Depends, HTTPException, status | ||
from fastapi.security import OAuth2PasswordBearer | ||
from jose import JWTError, jwt | ||
from pydantic import ValidationError | ||
from sqlmodel import Session | ||
|
||
from app.core import security | ||
from backend.app.core.config import settings | ||
from app.core.db import engine | ||
from app.models import TokenPayload, User | ||
|
||
reusable_oauth2 = OAuth2PasswordBearer( | ||
tokenUrl=f"{settings.API_V1_STR}/login/access-token" | ||
) | ||
|
||
|
||
def get_db() -> Generator[Session, None, None]: | ||
with Session(engine) as session: | ||
yield session | ||
|
||
|
||
SessionDep = Annotated[Session, Depends(get_db)] | ||
TokenDep = Annotated[str, Depends(reusable_oauth2)] | ||
|
||
|
||
def get_current_user(session: SessionDep, token: TokenDep) -> User: | ||
try: | ||
payload = jwt.decode( | ||
token, settings.SECRET_KEY, algorithms=[security.ALGORITHM] | ||
) | ||
token_data = TokenPayload(**payload) | ||
except (JWTError, ValidationError): | ||
raise HTTPException( | ||
status_code=status.HTTP_403_FORBIDDEN, | ||
detail="Could not validate credentials", | ||
) | ||
user = session.get(User, token_data.sub) | ||
if not user: | ||
raise HTTPException(status_code=404, detail="User not found") | ||
if not user.is_active: | ||
raise HTTPException(status_code=400, detail="Inactive user") | ||
return user | ||
|
||
|
||
CurrentUser = Annotated[User, Depends(get_current_user)] | ||
|
||
|
||
def get_current_active_superuser(current_user: CurrentUser) -> User: | ||
if not current_user.is_superuser: | ||
raise HTTPException( | ||
status_code=400, detail="The user doesn't have enough privileges" | ||
) | ||
return current_user |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from datetime import datetime, timedelta | ||
from typing import Any | ||
|
||
from jose import jwt | ||
from passlib.context import CryptContext | ||
|
||
from backend.app.core.config import settings | ||
|
||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") | ||
|
||
|
||
ALGORITHM = "HS256" | ||
|
||
|
||
def create_access_token(subject: str | Any, expires_delta: timedelta) -> str: | ||
expire = datetime.utcnow() + expires_delta | ||
to_encode = {"exp": expire, "sub": str(subject)} | ||
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM) | ||
return encoded_jwt | ||
|
||
|
||
def verify_password(plain_password: str, hashed_password: str) -> bool: | ||
return pwd_context.verify(plain_password, hashed_password) | ||
|
||
|
||
def get_password_hash(password: str) -> str: | ||
return pwd_context.hash(password) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
from typing import Any | ||
|
||
from sqlmodel import Session, select | ||
|
||
from app.core.security import get_password_hash, verify_password | ||
from app.models import Item, ItemCreate, User, UserCreate, UserUpdate | ||
|
||
|
||
def create_user(*, session: Session, user_create: UserCreate) -> User: | ||
db_obj = User.model_validate( | ||
user_create, update={"hashed_password": get_password_hash(user_create.password)} | ||
) | ||
session.add(db_obj) | ||
session.commit() | ||
session.refresh(db_obj) | ||
return db_obj | ||
|
||
|
||
def update_user(*, session: Session, db_user: User, user_in: UserUpdate) -> Any: | ||
user_data = user_in.model_dump(exclude_unset=True) | ||
extra_data = {} | ||
if "password" in user_data: | ||
password = user_data["password"] | ||
hashed_password = get_password_hash(password) | ||
extra_data["hashed_password"] = hashed_password | ||
db_user.sqlmodel_update(user_data, update=extra_data) | ||
session.add(db_user) | ||
session.commit() | ||
session.refresh(db_user) | ||
return db_user | ||
|
||
|
||
def get_user_by_email(*, session: Session, email: str) -> User | None: | ||
statement = select(User).where(User.email == email) | ||
session_user = session.exec(statement).first() | ||
return session_user | ||
|
||
|
||
def authenticate(*, session: Session, email: str, password: str) -> User | None: | ||
db_user = get_user_by_email(session=session, email=email) | ||
if not db_user: | ||
return None | ||
if not verify_password(password, db_user.hashed_password): | ||
return None | ||
return db_user | ||
|
||
|
||
def create_item(*, session: Session, item_in: ItemCreate, owner_id: int) -> Item: | ||
db_item = Item.model_validate(item_in, update={"owner_id": owner_id}) | ||
session.add(db_item) | ||
session.commit() | ||
session.refresh(db_item) | ||
return db_item |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
from typing import Any | ||
|
||
from sqlmodel import Session, select | ||
|
||
from app.core.security import get_password_hash | ||
from ingestion.models.user_model import User, UserCreate | ||
|
||
|
||
def create_user(*, session: Session, user_create: UserCreate) -> User: | ||
db_obj = User.model_validate( | ||
user_create, update={"hashed_password": get_password_hash(user_create.password)} | ||
) | ||
session.add(db_obj) | ||
session.commit() | ||
session.refresh(db_obj) | ||
return db_obj |
Oops, something went wrong.