diff --git a/xas-standards-api/.pre-commit-config.yaml b/xas-standards-api/.pre-commit-config.yaml index 5bc9f00..6bbed96 100644 --- a/xas-standards-api/.pre-commit-config.yaml +++ b/xas-standards-api/.pre-commit-config.yaml @@ -19,5 +19,5 @@ repos: name: Run ruff stages: [commit] language: system - entry: ruff + entry: ruff check types: [python] diff --git a/xas-standards-api/pyproject.toml b/xas-standards-api/pyproject.toml index e8b3987..2872910 100644 --- a/xas-standards-api/pyproject.toml +++ b/xas-standards-api/pyproject.toml @@ -104,7 +104,7 @@ commands = [tool.ruff] src = ["src", "tests"] line-length = 88 -select = [ +lint.select = [ "C4", # flake8-comprehensions - https://beta.ruff.rs/docs/rules/#flake8-comprehensions-c4 "E", # pycodestyle errors - https://beta.ruff.rs/docs/rules/#error-e "F", # pyflakes rules - https://beta.ruff.rs/docs/rules/#pyflakes-f diff --git a/xas-standards-api/src/xas_standards_api/app.py b/xas-standards-api/src/xas_standards_api/app.py index 30f1522..f4ea506 100644 --- a/xas-standards-api/src/xas_standards_api/app.py +++ b/xas-standards-api/src/xas_standards_api/app.py @@ -2,13 +2,25 @@ import os from typing import Annotated, List, Optional, Union -from fastapi import Depends, FastAPI, File, Form, Query, UploadFile +import requests +from fastapi import ( + Depends, + FastAPI, + File, + Form, + HTTPException, + Query, + UploadFile, + status, +) from fastapi.responses import HTMLResponse +from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer from fastapi.staticfiles import StaticFiles from fastapi_pagination import add_pagination from fastapi_pagination.cursor import CursorPage from fastapi_pagination.ext.sqlalchemy import paginate from sqlmodel import Session, create_engine, select +from starlette.responses import RedirectResponse from .crud import ( add_new_standard, @@ -35,11 +47,19 @@ ) dev = False -lifespan = None +env_value = os.environ.get("FASTAPI_APP_ENV") + +if env_value and env_value == "development": + print("RUNNING IN DEV MODE") + dev = True + +get_bearer_token = HTTPBearer(auto_error=True) url = os.environ.get("POSTGRESURL") build_dir = os.environ.get("FRONTEND_BUILD_DIR") +oidc_user_info_endpoint = os.environ.get("OIDC_USER_INFO_ENDPOINT") + if url: engine = create_engine(url) @@ -52,16 +72,59 @@ def get_session(): yield session -app = FastAPI(lifespan=lifespan) +app = FastAPI() CursorPage = CursorPage.with_custom_options( size=Query(10, ge=1, le=100), ) - add_pagination(app) +@app.get("/login", response_class=RedirectResponse) +async def redirect_home(): + # proxy handles log in so if you reach here go home + return "/" + + +async def get_current_user( + auth: HTTPAuthorizationCredentials = Depends(get_bearer_token), +): + + if auth is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid user token", + ) + + if dev: + return auth.credentials + + if oidc_user_info_endpoint is None: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="User info endpoint error", + ) + + response = requests.get( + url=oidc_user_info_endpoint, + headers={"Authorization": f"Bearer {auth.credentials}"}, + ) + + if response.status_code == 401: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid user token", + ) + + return response.json()["id"] + + +@app.get("/api/user") +async def check(user_id: str = Depends(get_current_user)): + return {"user": user_id} + + @app.get("/api/metadata") def read_metadata(session: Session = Depends(get_session)) -> MetadataResponse: return get_metadata(session) @@ -142,13 +205,14 @@ def add_standard_file( licence: Annotated[str, Form()], additional_files: Optional[list[UploadFile]] = Form(None), sample_comp: Optional[str] = Form(None), + user_id: str = Depends(get_current_user), session: Session = Depends(get_session), ) -> XASStandard: if additional_files: print(f"Additional files {len(additional_files)}") - person = select_or_create_person(session, "test1234") + person = select_or_create_person(session, user_id) form_input = XASStandardInput( submitter_id=person.id, diff --git a/xas-standards-api/src/xas_standards_api/crud.py b/xas-standards-api/src/xas_standards_api/crud.py index 6ed5d77..0ac5d2d 100644 --- a/xas-standards-api/src/xas_standards_api/crud.py +++ b/xas-standards-api/src/xas_standards_api/crud.py @@ -81,23 +81,23 @@ def add_new_standard(session, file1, xs_input: XASStandardInput, additional_file tmp_filename = pvc_location + str(uuid.uuid4()) with open(tmp_filename, "wb") as ntf: - filename = ntf.name ntf.write(file1.file.read()) - xdi_data = xdi.read_xdi(filename) - set_labels = set(xdi_data.array_labels) + xdi_data = xdi.read_xdi(tmp_filename) - fluorescence = "mufluor" in set_labels - transmission = "mutrans" in set_labels - emission = "mutey" in set_labels + set_labels = set(xdi_data.array_labels) - xsd = XASStandardDataInput( - fluorescence=fluorescence, - location=tmp_filename, - original_filename=file1.filename, - emission=emission, - transmission=transmission, - ) + fluorescence = "mufluor" in set_labels + transmission = "mutrans" in set_labels + emission = "mutey" in set_labels + + xsd = XASStandardDataInput( + fluorescence=fluorescence, + location=tmp_filename, + original_filename=file1.filename, + emission=emission, + transmission=transmission, + ) new_standard = XASStandard.model_validate(xs_input) new_standard.xas_standard_data = XASStandardData.model_validate(xsd) diff --git a/xas-standards-client/src/App.tsx b/xas-standards-client/src/App.tsx index 54d1d4a..27a6e71 100644 --- a/xas-standards-client/src/App.tsx +++ b/xas-standards-client/src/App.tsx @@ -7,19 +7,33 @@ import StandardSubmission from "./components/StandardSubmission.tsx"; import WelcomePage from "./components/WelcomePage.tsx"; import { MetadataProvider } from "./contexts/MetadataContext.tsx"; +import { UserProvider } from "./contexts/UserContext.tsx"; + +import LogInPage from "./components/LogInPage.tsx"; +import RequireAuth from "./components/RequireAuth.tsx"; function App() { return (