Skip to content

Commit

Permalink
Merge pull request #9 from black-redoc/feat/auth
Browse files Browse the repository at this point in the history
feat/auth
  • Loading branch information
black-redoc committed Jul 24, 2024
2 parents c2a6ef3 + 3c9e18b commit aa2b77b
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 21 deletions.
66 changes: 55 additions & 11 deletions app.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import os
from typing import Annotated
import json
from dotenv import load_dotenv
from fastapi.responses import JSONResponse

if os.getenv("ENVIRONMENT") == "development":
load_dotenv()

from fastapi.middleware.cors import CORSMiddleware

from fastapi import FastAPI, Response, Request
from fastapi import Cookie, FastAPI, HTTPException, Response, Request, status

from contextlib import asynccontextmanager

from src.tasks.router import router as task_router
Expand All @@ -20,6 +24,11 @@

from src.settings.database import SessionLocal, engine

allow_origins = os.getenv("ORIGINS")
allow_origins = allow_origins.split(",")
allow_methods = ("*",)
allow_headers = ("*",)


@asynccontextmanager
async def lifespan(app: FastAPI):
Expand All @@ -28,6 +37,14 @@ async def lifespan(app: FastAPI):


app = FastAPI(lifespan=lifespan)
app.add_middleware(
CORSMiddleware,
allow_origins=allow_origins,
allow_methods=allow_methods,
allow_headers=allow_headers,
allow_credentials=False,
expose_headers=("*",),
)

tasks_model.Base.metadata.create_all(bind=engine)
projects_model.Base.metadata.create_all(bind=engine)
Expand All @@ -39,30 +56,57 @@ def health():
return "OK"


@app.get("/", status_code=200)
@app.get("/protected")
def protected(session: Annotated[str | None, Cookie()] = None):
if session is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized"
)
user_data = json.loads(session)
user_data = dict(email=user_data["email"])
return JSONResponse(status_code=status.HTTP_200_OK, content=user_data)


@app.get("/", status_code=status.HTTP_200_OK)
def health():
return None
return "OK"


@app.middleware("http")
async def db_session_middleware(request: Request, call_next):
response = Response("Internal server error", status_code=500)
print("db_session_middleware")
try:
request.state.db = SessionLocal()
response = await call_next(request)
except Exception:
response = Response(
"Internal server error", status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
)
finally:
request.state.db.close()
return response


origins = ["*"]
@app.middleware("http")
async def http_session_middleware(request: Request, call_next):
print("http_session_middleware")
try:
origin = request.headers["Origin"]
if origin not in allow_origins:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized"
)
response = await call_next(request)
response.headers["Access-Control-Allow-Credentials"] = "true"
response.headers["Access-Control-Allow-Origin"] = origin
response.headers["Vary"] = "Origin"
except Exception:
response = Response(
"Internal server error", status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
)
return response


app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_methods=["*"],
allow_headers=["*"],
)
app.include_router(task_router)
app.include_router(project_router)
app.include_router(users_router)
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
fastapi==0.110.0
fastapi==0.111.0
uvicorn[standard]==0.28.0
SQLAlchemy==2.0.28
pytest==8.1.1
Expand Down
9 changes: 3 additions & 6 deletions src/projects/router.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from fastapi import APIRouter, Depends, status
from typing import Annotated
from fastapi import APIRouter, Depends, status, Cookie
from sqlalchemy.orm import Session

from src.settings.database import SessionLocal
Expand All @@ -16,16 +17,12 @@ def get_db():
router = APIRouter()


@router.get("/projects/{title}", response_model=schemas.ProjectSchema)
async def get_project_by_title(title, db: Session = Depends(get_db)):
return service.get_project_by_title(db, title)


@router.get("/projects/", response_model=list[schemas.ProjectSchema])
async def read_projects(
skip: int = 0,
limit: int = 100,
db: Session = Depends(get_db),
session: Annotated[str | None, Cookie()] = None,
):
projects = service.get_projects(db, skip=skip, limit=limit)
return projects
Expand Down
14 changes: 12 additions & 2 deletions src/users/router.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from fastapi import APIRouter, Depends, status
from fastapi import APIRouter, Depends, Request, Response
from sqlalchemy.orm import Session

from src.settings.database import SessionLocal
Expand Down Expand Up @@ -28,4 +28,14 @@ async def login(
user: schemas.UserSchema,
db: Session = Depends(get_db),
):
return service.login(db, user)
import json

response = service.login(db, user)
response.set_cookie(
key="session",
value=json.dumps({"email": user.email}),
httponly=True,
samesite="None",
secure=True,
)
return response
7 changes: 6 additions & 1 deletion src/users/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,13 @@ def validate_user(db: Session, user: schemas.UserSchema):
stored_user = db.query(models.User).first()
if username == stored_user.username:
return JSONResponse(content={"is_user_valid": True})
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST, content={"is_user_valid": False}
)
except:
return JSONResponse(content={"is_user_valid": False})
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST, content={"is_user_valid": False}
)


def login(db: Session, user: schemas.UserSchema):
Expand Down

0 comments on commit aa2b77b

Please sign in to comment.