Skip to content

Commit

Permalink
[ENH] Add authentication to /query route (#323)
Browse files Browse the repository at this point in the history
* add implicit OAuth flow + Google token verification for /query route

* add dependencies for Google auth library

* check auth env vars on startup

* mock token/token verification and disable auth as needed in tests

* add tests of auth utilities and filter irrelevant warnings

* test empty query succeeds when auth is disabled
  • Loading branch information
alyssadai authored Jul 17, 2024
1 parent 7f03522 commit 07a6afa
Show file tree
Hide file tree
Showing 9 changed files with 383 additions and 55 deletions.
30 changes: 27 additions & 3 deletions app/api/routers/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,41 @@

from typing import List

from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2

from .. import crud
from .. import crud, security
from ..models import CohortQueryResponse, QueryModel
from ..security import verify_token

router = APIRouter(prefix="/query", tags=["query"])

# Adapted from info in https://github.com/tiangolo/fastapi/discussions/9137#discussioncomment-5157382
oauth2_scheme = OAuth2(
flows={
"implicit": {
"authorizationUrl": "https://accounts.google.com/o/oauth2/auth",
}
},
# Don't automatically error out when request is not authenticated, to support optional authentication
auto_error=False,
)


@router.get("/", response_model=List[CohortQueryResponse])
async def get_query(query: QueryModel = Depends(QueryModel)):
async def get_query(
query: QueryModel = Depends(QueryModel),
token: str | None = Depends(oauth2_scheme),
):
"""When a GET request is sent, return list of dicts corresponding to subject-level metadata aggregated by dataset."""
if security.AUTH_ENABLED:
if token is None:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not authenticated",
)
verify_token(token)

response = await crud.get(
query.min_age,
query.max_age,
Expand Down
45 changes: 45 additions & 0 deletions app/api/security.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""Functions for handling authentication. Same ones as used in Neurobagel's federation API."""

import os

from fastapi import HTTPException, status
from fastapi.security.utils import get_authorization_scheme_param
from google.auth.exceptions import GoogleAuthError
from google.auth.transport import requests
from google.oauth2 import id_token

AUTH_ENABLED = os.environ.get("NB_ENABLE_AUTH", "True").lower() == "true"
CLIENT_ID = os.environ.get("NB_QUERY_CLIENT_ID", None)


def check_client_id():
"""Check if the CLIENT_ID environment variable is set."""
# By default, if CLIENT_ID is not provided to verify_oauth2_token,
# Google will simply skip verifying the audience claim of ID tokens.
# This however can be a security risk, so we mandate that CLIENT_ID is set.
if AUTH_ENABLED and CLIENT_ID is None:
raise ValueError(
"Authentication has been enabled (NB_ENABLE_AUTH) but the environment variable NB_QUERY_CLIENT_ID is not set. "
"Please set NB_QUERY_CLIENT_ID to the Google client ID for your Neurobagel query tool deployment, to verify the audience claim of ID tokens."
)


def verify_token(token: str):
"""Verify the Google ID token. Raise an HTTPException if the token is invalid."""
# Adapted from https://developers.google.com/identity/gsi/web/guides/verify-google-id-token#python
try:
# Extract the token from the "Bearer" scheme
# (See https://github.com/tiangolo/fastapi/blob/master/fastapi/security/oauth2.py#L473-L485)
# TODO: Check also if scheme of token is "Bearer"?
_, param = get_authorization_scheme_param(token)
id_info = id_token.verify_oauth2_token(
param, requests.Request(), CLIENT_ID
)
# TODO: Remove print statement or turn into logging
print("Token verified: ", id_info)
except (GoogleAuthError, ValueError) as exc:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f"Invalid token: {exc}",
headers={"WWW-Authenticate": "Bearer"},
) from exc
10 changes: 9 additions & 1 deletion app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from .api import utility as util
from .api.routers import attributes, query
from .api.security import check_client_id

app = FastAPI(
default_response_class=ORJSONResponse, docs_url=None, redoc_url=None
Expand Down Expand Up @@ -77,7 +78,14 @@ def overridden_redoc():

@app.on_event("startup")
async def auth_check():
"""Checks whether username and password environment variables are set."""
"""
Checks whether authentication has been enabled for API queries and whether the
username and password environment variables for the graph backend have been set.
TODO: Refactor once startup events have been replaced by lifespan event
"""
check_client_id()

if (
# TODO: Check if this error is still raised when variables are empty strings
os.environ.get(util.GRAPH_USERNAME.name) is None
Expand Down
14 changes: 13 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
anyio==3.6.2
attrs==22.1.0
cachetools==5.3.3
certifi==2024.7.4
cfgv==3.3.1
coverage==7.0.0
charset-normalizer==3.3.2
click==8.1.3
colorama==0.4.6
coverage==7.0.0
distlib==0.3.6
exceptiongroup==1.0.4
fastapi==0.110.1
filelock==3.8.0
google-auth==2.32.0
h11==0.14.0
httpcore==0.16.2
httpx==0.23.1
Expand All @@ -22,15 +26,23 @@ pandas==1.5.2
platformdirs==2.5.4
pluggy==1.0.0
pre-commit==3.6.0
pyasn1==0.6.0
pyasn1_modules==0.4.0
pydantic==1.10.13
pyparsing==3.0.9
pytest==7.2.0
python-dateutil==2.8.2
pytz==2022.7
PyYAML==6.0
requests==2.32.3
rfc3986==1.5.0
rsa==4.9
six==1.16.0
sniffio==1.3.0
starlette==0.37.2
toml==0.10.2
tomli==2.0.1
typing_extensions==4.11.0
urllib3==2.2.2
uvicorn==0.20.0
virtualenv==20.16.7
33 changes: 33 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,46 @@ def test_app():
yield client


@pytest.fixture
def disable_auth(monkeypatch):
"""
Disable the authentication requirement for the API to skip startup checks
(for when the tested route does not require authentication).
"""
monkeypatch.setattr("app.api.security.AUTH_ENABLED", False)


@pytest.fixture(scope="function")
def set_test_credentials(monkeypatch):
"""Set random username and password to avoid error from startup check for set credentials."""
monkeypatch.setenv(util.GRAPH_USERNAME.name, "SomeUser")
monkeypatch.setenv(util.GRAPH_PASSWORD.name, "SomePassword")


@pytest.fixture()
def mock_verify_token():
"""Mock a successful token verification that does not raise any exceptions."""

def _verify_token(token):
return None

return _verify_token


@pytest.fixture()
def set_mock_verify_token(monkeypatch, mock_verify_token):
"""Set the verify_token function to a mock that does not raise any exceptions."""
monkeypatch.setattr(
"app.api.routers.query.verify_token", mock_verify_token
)


@pytest.fixture()
def mock_auth_header() -> dict:
"""Create an authorization header with a mock token that is well-formed for testing purposes."""
return {"Authorization": "Bearer foo"}


@pytest.fixture()
def test_data():
"""Create valid aggregate response data for two toy datasets for testing."""
Expand Down
34 changes: 25 additions & 9 deletions tests/test_app_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
from app.api import utility as util


def test_start_app_without_environment_vars_fails(test_app, monkeypatch):
@pytest.mark.filterwarnings("ignore:.*NB_API_ALLOWED_ORIGINS")
def test_start_app_without_environment_vars_fails(
test_app, monkeypatch, disable_auth
):
"""Given non-existing username and password environment variables, raises an informative RuntimeError."""
monkeypatch.delenv(util.GRAPH_USERNAME.name, raising=False)
monkeypatch.delenv(util.GRAPH_PASSWORD.name, raising=False)
Expand All @@ -24,21 +27,27 @@ def test_start_app_without_environment_vars_fails(test_app, monkeypatch):
)


def test_app_with_invalid_environment_vars(test_app, monkeypatch):
"""Given invalid environment variables, returns a 401 status code."""
@pytest.mark.filterwarnings("ignore:.*NB_API_ALLOWED_ORIGINS")
def test_app_with_invalid_environment_vars(
test_app, monkeypatch, mock_auth_header, set_mock_verify_token
):
"""Given invalid environment variables for the graph, returns a 401 status code."""
monkeypatch.setenv(util.GRAPH_USERNAME.name, "something")
monkeypatch.setenv(util.GRAPH_PASSWORD.name, "cool")

def mock_httpx_post(**kwargs):
return httpx.Response(status_code=401)

monkeypatch.setattr(httpx, "post", mock_httpx_post)
response = test_app.get("/query/")
response = test_app.get("/query/", headers=mock_auth_header)
assert response.status_code == 401


def test_app_with_unset_allowed_origins(
test_app, monkeypatch, set_test_credentials
test_app,
monkeypatch,
set_test_credentials,
disable_auth,
):
"""Tests that when the environment variable for allowed origins has not been set, a warning is raised and the app uses a default value."""
monkeypatch.delenv(util.ALLOWED_ORIGINS.name, raising=False)
Expand Down Expand Up @@ -90,6 +99,7 @@ def test_app_with_set_allowed_origins(
allowed_origins,
parsed_origins,
expectation,
disable_auth,
):
"""
Test that when the environment variable for allowed origins has been explicitly set, the app correctly parses it into a list
Expand All @@ -108,8 +118,11 @@ def test_app_with_set_allowed_origins(
)


@pytest.mark.filterwarnings("ignore:.*NB_API_ALLOWED_ORIGINS")
def test_stored_vocab_lookup_file_created_on_startup(
test_app, set_test_credentials
test_app,
set_test_credentials,
disable_auth,
):
"""Test that on startup, a non-empty temporary lookup file is created for term ID-label mappings for the locally stored SNOMED CT vocabulary."""
with test_app:
Expand All @@ -118,8 +131,9 @@ def test_stored_vocab_lookup_file_created_on_startup(
assert term_labels_path.stat().st_size > 0


@pytest.mark.filterwarnings("ignore:.*NB_API_ALLOWED_ORIGINS")
def test_external_vocab_is_fetched_on_startup(
test_app, monkeypatch, set_test_credentials
test_app, monkeypatch, set_test_credentials, disable_auth
):
"""
Tests that on startup, a GET request is made to the Cognitive Atlas API and that when the request succeeds,
Expand Down Expand Up @@ -160,8 +174,9 @@ def mock_httpx_get(**kwargs):
}


@pytest.mark.filterwarnings("ignore:.*NB_API_ALLOWED_ORIGINS")
def test_failed_vocab_fetching_on_startup_raises_warning(
test_app, monkeypatch, set_test_credentials
test_app, monkeypatch, set_test_credentials, disable_auth
):
"""
Tests that when a GET request to the Cognitive Atlas API has a non-success response code (e.g., due to service being unavailable),
Expand All @@ -186,8 +201,9 @@ def mock_httpx_get(**kwargs):
)


@pytest.mark.filterwarnings("ignore:.*NB_API_ALLOWED_ORIGINS")
def test_network_error_on_startup_raises_warning(
test_app, monkeypatch, set_test_credentials
test_app, monkeypatch, set_test_credentials, disable_auth
):
"""
Tests that when a GET request to the Cognitive Atlas API fails due to a network error (i.e., while issuing the request),
Expand Down
9 changes: 5 additions & 4 deletions tests/test_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,17 @@
from app.api import utility as util


def test_root(test_app, set_test_credentials):
def test_root(test_app):
"""Given a GET request to the root endpoint, Check for 200 status and expected content."""

with test_app:
response = test_app.get("/")
response = test_app.get("/")

assert response.status_code == 200
assert "Welcome to the Neurobagel REST API!" in response.text
assert '<a href="/docs">documentation</a>' in response.text


@pytest.mark.filterwarnings("ignore:.*NB_API_ALLOWED_ORIGINS")
@pytest.mark.parametrize(
"valid_data_element_URI",
["nb:Diagnosis", "nb:Assessment"],
Expand All @@ -28,6 +28,7 @@ def test_get_terms_valid_data_element_URI(
mock_successful_get_terms,
valid_data_element_URI,
monkeypatch,
disable_auth,
):
"""Given a valid data element URI, returns a 200 status code and a non-empty list of terms for that data element."""
monkeypatch.setattr(crud, "get_terms", mock_successful_get_terms)
Expand All @@ -54,7 +55,7 @@ def test_get_terms_invalid_data_element_URI(


def test_get_terms_for_attribute_with_vocab_lookup(
test_app, monkeypatch, set_test_credentials
test_app, monkeypatch, set_test_credentials, disable_auth
):
"""
Given a valid data element URI with a vocabulary lookup file available, returns prefixed term URIs and their human-readable labels (where found)
Expand Down
Loading

0 comments on commit 07a6afa

Please sign in to comment.