diff --git a/backend/app/api.py b/backend/app/api.py index 87fc163..842617c 100644 --- a/backend/app/api.py +++ b/backend/app/api.py @@ -1,32 +1,31 @@ -import csv import json -from io import TextIOWrapper -from collections import defaultdict +import secrets +import socket from datetime import datetime -from typing import DefaultDict, List, Optional, Set, Tuple +from typing import List, Optional +from bson import ObjectId from fastapi import ( APIRouter, - Request, BackgroundTasks, Depends, - Query, - HTTPException, - status, - Form, File, + Form, + HTTPException, + Query, + Request, + Response, UploadFile, + status, ) -from bson import ObjectId +from fastapi.param_functions import Cookie +from fastapi.responses import HTMLResponse +from fastapi.security import HTTPBasic, HTTPBasicCredentials +from passlib.hash import bcrypt from app.config import Settings, get_settings -from app.models import * from app.database import DatabaseContext - -from fastapi.security import HTTPBasic, HTTPBasicCredentials -from passlib.hash import bcrypt -from fastapi.responses import HTMLResponse -import socket +from app.models import * security = HTTPBasic() router = APIRouter() @@ -42,7 +41,9 @@ def verify_password(plain_password, hashed_password: SecretStr): async def get_current_user( + response: Response, credentials: HTTPBasicCredentials = Depends(security), + session_id: Optional[str] = Cookie(None), db: DatabaseContext = Depends(get_db), ) -> Annotator: user = await db.get_annotator(credentials.username) @@ -50,7 +51,7 @@ async def get_current_user( hostname = socket.gethostname() ip_address = socket.gethostbyname(hostname) # If user has an active session skip bcrypt validation: - if await db.active_session(ip_address): + if session_id and await db.active_session(ip_address, session_id): return user # Else try to login user if not user or not verify_password(credentials.password, user.hashed_password): @@ -60,7 +61,16 @@ async def get_current_user( headers={"WWW-Authenticate": "Basic"}, ) # Create active session for newly logged in user - await db.create_session(ip_address) + session_id = secrets.token_urlsafe() + await db.create_session(ip_address, session_id) + response.set_cookie( + "session_id", + session_id, + secure=True, + httponly=True, + samesite="strict", + max_age=15 * 60, + ) return user @@ -87,7 +97,6 @@ async def audit_handler( query_params=request.query_params, body=body, ) - background.add_task(db.add_audit_event, audit_event) @@ -206,7 +215,7 @@ async def get_segment( segment_id: str, annotator_username: str = Query(None, alias="annotator"), db: DatabaseContext = Depends(get_db), - username: str = Depends(get_current_user), + _=Depends(get_current_user), ): segment: SegmentRecord = await db.get_segment(ObjectId(segment_id)) annotation = segment.annotations.get(annotator_username) @@ -222,7 +231,7 @@ async def update_segment_annotations( annotator_username: str, annotation: Annotation, db: DatabaseContext = Depends(get_db), - username: str = Depends(get_current_user), + _=Depends(get_current_user), ): try: await db.update_annotation(ObjectId(segment_id), annotator_username, annotation) diff --git a/backend/app/database.py b/backend/app/database.py index 5e6845f..058c9d4 100644 --- a/backend/app/database.py +++ b/backend/app/database.py @@ -35,20 +35,20 @@ def __enter__(self): def __exit__(self, *_): self.client.close() - async def active_session(self, user_IP: str): + async def active_session(self, user_ip: str, session_id: str): return await self.sessions.find_one_and_update( - {"user_IP": user_IP}, - {"$set": - {"lastLoginAt": datetime.now()} - } - ) - - async def create_session(self, user_IP: str): - await self.sessions.insert_one({ - "lastLoginAt": datetime.now(), - "user_IP": user_IP - }) - return + {"user_ip": user_ip, "session_id": session_id}, + {"$set": {"last_login": datetime.utcnow()}}, + ) + + async def create_session(self, user_ip: str, session_id: str): + await self.sessions.insert_one( + { + "user_ip": user_ip, + "session_id": session_id, + "last_login": datetime.utcnow(), + } + ) async def set_campaign(self, username: str, campaign: AnnotationCampaign): return await self.annotators.update_one( diff --git a/backend/main.py b/backend/main.py index ae76d99..190b422 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,14 +1,8 @@ -from pathlib import Path - from fastapi import FastAPI, Depends, Form -from fastapi.staticfiles import StaticFiles -from fastapi.middleware.gzip import GZipMiddleware - from app.api import router as api_router, audit_handler app = FastAPI() -app.add_middleware(GZipMiddleware) app.include_router(api_router, dependencies=[Depends(audit_handler)]) # Dev runner diff --git a/backend/seed_db.py b/backend/seed_db.py index 4dbfe31..8497d4b 100644 --- a/backend/seed_db.py +++ b/backend/seed_db.py @@ -3,7 +3,7 @@ from app.config import get_settings from app.models import SegmentRecord -from pymongo import MongoClient +from pymongo import ASCENDING, HASHED, MongoClient from bson import ObjectId # bcrypt.hash("12345") @@ -68,11 +68,19 @@ def generate_annotation(i: int, n_samples=SIGNAL_HZ * SECONDS): db = client.get_database(settings.db_name) annotators = db.get_collection("annotators") segments = db.get_collection("segment_records") + sessions = db.get_collection("active_sessions") with client.start_session() as sess: annotators.drop(session=sess) segments.drop(session=sess) + sessions.drop(session=sess) db.get_collection("audit_events").drop(session=sess) + + sessions.create_index( + [("user_ip", ASCENDING), ("session_id", HASHED)], + session=sess, + ) + sessions.create_index("last_login", expireAfterSeconds=15 * 60, session=sess) with sess.start_transaction(): annotators.insert_many(ANNOTATORS) segments.insert_many([generate_annotation(i) for i in range(N)]) @@ -80,12 +88,5 @@ def generate_annotation(i: int, n_samples=SIGNAL_HZ * SECONDS): segment_ids = [s["_id"] for s in segments.find(projection=[])] annotators.update_many( {}, - { - "$set": { - "current_campaign": { - "name": "training", - "segments": segment_ids - } - } - }, + {"$set": {"current_campaign": {"name": "training", "segments": segment_ids}}}, )