diff --git a/app.py b/app.py index db449c4..21fe095 100644 --- a/app.py +++ b/app.py @@ -1,12 +1,16 @@ import os +from typing import Annotated +import json from dotenv import load_dotenv +from fastapi.responses import JSONResponse if os.getenv("ENVIRONMENT") == "development": load_dotenv() from fastapi.middleware.cors import CORSMiddleware -from fastapi import FastAPI, Response, Request +from fastapi import Cookie, FastAPI, HTTPException, Response, Request, status + from contextlib import asynccontextmanager from src.tasks.router import router as task_router @@ -20,6 +24,11 @@ from src.settings.database import SessionLocal, engine +allow_origins = os.getenv("ORIGINS") +allow_origins = allow_origins.split(",") +allow_methods = ("*",) +allow_headers = ("*",) + @asynccontextmanager async def lifespan(app: FastAPI): @@ -28,6 +37,14 @@ async def lifespan(app: FastAPI): app = FastAPI(lifespan=lifespan) +app.add_middleware( + CORSMiddleware, + allow_origins=allow_origins, + allow_methods=allow_methods, + allow_headers=allow_headers, + allow_credentials=False, + expose_headers=("*",), +) tasks_model.Base.metadata.create_all(bind=engine) projects_model.Base.metadata.create_all(bind=engine) @@ -39,30 +56,57 @@ def health(): return "OK" -@app.get("/", status_code=200) +@app.get("/protected") +def protected(session: Annotated[str | None, Cookie()] = None): + if session is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized" + ) + user_data = json.loads(session) + user_data = dict(email=user_data["email"]) + return JSONResponse(status_code=status.HTTP_200_OK, content=user_data) + + +@app.get("/", status_code=status.HTTP_200_OK) def health(): - return None + return "OK" @app.middleware("http") async def db_session_middleware(request: Request, call_next): - response = Response("Internal server error", status_code=500) + print("db_session_middleware") try: request.state.db = SessionLocal() response = await call_next(request) + except Exception: + response = Response( + "Internal server error", status_code=status.HTTP_500_INTERNAL_SERVER_ERROR + ) finally: request.state.db.close() return response -origins = ["*"] +@app.middleware("http") +async def http_session_middleware(request: Request, call_next): + print("http_session_middleware") + try: + origin = request.headers["Origin"] + if origin not in allow_origins: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized" + ) + response = await call_next(request) + response.headers["Access-Control-Allow-Credentials"] = "true" + response.headers["Access-Control-Allow-Origin"] = origin + response.headers["Vary"] = "Origin" + except Exception: + response = Response( + "Internal server error", status_code=status.HTTP_500_INTERNAL_SERVER_ERROR + ) + return response + -app.add_middleware( - CORSMiddleware, - allow_origins=origins, - allow_methods=["*"], - allow_headers=["*"], -) app.include_router(task_router) app.include_router(project_router) app.include_router(users_router) diff --git a/requirements.txt b/requirements.txt index 4dc4a1f..05c8514 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -fastapi==0.110.0 +fastapi==0.111.0 uvicorn[standard]==0.28.0 SQLAlchemy==2.0.28 pytest==8.1.1 diff --git a/src/projects/router.py b/src/projects/router.py index e5e402f..caf8da0 100644 --- a/src/projects/router.py +++ b/src/projects/router.py @@ -1,4 +1,5 @@ -from fastapi import APIRouter, Depends, status +from typing import Annotated +from fastapi import APIRouter, Depends, status, Cookie from sqlalchemy.orm import Session from src.settings.database import SessionLocal @@ -16,16 +17,12 @@ def get_db(): router = APIRouter() -@router.get("/projects/{title}", response_model=schemas.ProjectSchema) -async def get_project_by_title(title, db: Session = Depends(get_db)): - return service.get_project_by_title(db, title) - - @router.get("/projects/", response_model=list[schemas.ProjectSchema]) async def read_projects( skip: int = 0, limit: int = 100, db: Session = Depends(get_db), + session: Annotated[str | None, Cookie()] = None, ): projects = service.get_projects(db, skip=skip, limit=limit) return projects diff --git a/src/users/router.py b/src/users/router.py index e250b37..3c5557b 100644 --- a/src/users/router.py +++ b/src/users/router.py @@ -1,4 +1,4 @@ -from fastapi import APIRouter, Depends, status +from fastapi import APIRouter, Depends, Request, Response from sqlalchemy.orm import Session from src.settings.database import SessionLocal @@ -28,4 +28,14 @@ async def login( user: schemas.UserSchema, db: Session = Depends(get_db), ): - return service.login(db, user) + import json + + response = service.login(db, user) + response.set_cookie( + key="session", + value=json.dumps({"email": user.email}), + httponly=True, + samesite="None", + secure=True, + ) + return response diff --git a/src/users/service.py b/src/users/service.py index 0ee3d5f..9e8e135 100644 --- a/src/users/service.py +++ b/src/users/service.py @@ -11,8 +11,13 @@ def validate_user(db: Session, user: schemas.UserSchema): stored_user = db.query(models.User).first() if username == stored_user.username: return JSONResponse(content={"is_user_valid": True}) + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, content={"is_user_valid": False} + ) except: - return JSONResponse(content={"is_user_valid": False}) + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, content={"is_user_valid": False} + ) def login(db: Session, user: schemas.UserSchema):