Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Add authentication to /query route #323

Merged
merged 6 commits into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions app/api/routers/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,41 @@

from typing import List

from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2

from .. import crud
from .. import crud, security
from ..models import CohortQueryResponse, QueryModel
from ..security import verify_token

router = APIRouter(prefix="/query", tags=["query"])

# Adapted from info in https://github.com/tiangolo/fastapi/discussions/9137#discussioncomment-5157382
oauth2_scheme = OAuth2(
flows={
"implicit": {
"authorizationUrl": "https://accounts.google.com/o/oauth2/auth",
}
},
# Don't automatically error out when request is not authenticated, to support optional authentication
auto_error=False,
)


@router.get("/", response_model=List[CohortQueryResponse])
async def get_query(query: QueryModel = Depends(QueryModel)):
async def get_query(
query: QueryModel = Depends(QueryModel),
token: str | None = Depends(oauth2_scheme),
):
"""When a GET request is sent, return list of dicts corresponding to subject-level metadata aggregated by dataset."""
if security.AUTH_ENABLED:
if token is None:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not authenticated",
)
verify_token(token)

response = await crud.get(
query.min_age,
query.max_age,
Expand Down
45 changes: 45 additions & 0 deletions app/api/security.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""Functions for handling authentication. Same ones as used in Neurobagel's federation API."""

import os

from fastapi import HTTPException, status
from fastapi.security.utils import get_authorization_scheme_param
from google.auth.exceptions import GoogleAuthError
from google.auth.transport import requests
from google.oauth2 import id_token

AUTH_ENABLED = os.environ.get("NB_ENABLE_AUTH", "True").lower() == "true"
CLIENT_ID = os.environ.get("NB_QUERY_CLIENT_ID", None)


def check_client_id():
"""Check if the CLIENT_ID environment variable is set."""
# By default, if CLIENT_ID is not provided to verify_oauth2_token,
# Google will simply skip verifying the audience claim of ID tokens.
# This however can be a security risk, so we mandate that CLIENT_ID is set.
if AUTH_ENABLED and CLIENT_ID is None:
raise ValueError(
"Authentication has been enabled (NB_ENABLE_AUTH) but the environment variable NB_QUERY_CLIENT_ID is not set. "
"Please set NB_QUERY_CLIENT_ID to the Google client ID for your Neurobagel query tool deployment, to verify the audience claim of ID tokens."
)


def verify_token(token: str):
"""Verify the Google ID token. Raise an HTTPException if the token is invalid."""
# Adapted from https://developers.google.com/identity/gsi/web/guides/verify-google-id-token#python
try:
# Extract the token from the "Bearer" scheme
# (See https://github.com/tiangolo/fastapi/blob/master/fastapi/security/oauth2.py#L473-L485)
# TODO: Check also if scheme of token is "Bearer"?
_, param = get_authorization_scheme_param(token)
id_info = id_token.verify_oauth2_token(
param, requests.Request(), CLIENT_ID
)
# TODO: Remove print statement or turn into logging
print("Token verified: ", id_info)
except (GoogleAuthError, ValueError) as exc:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f"Invalid token: {exc}",
headers={"WWW-Authenticate": "Bearer"},
) from exc
10 changes: 9 additions & 1 deletion app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from .api import utility as util
from .api.routers import attributes, query
from .api.security import check_client_id

app = FastAPI(
default_response_class=ORJSONResponse, docs_url=None, redoc_url=None
Expand Down Expand Up @@ -77,7 +78,14 @@ def overridden_redoc():

@app.on_event("startup")
async def auth_check():
"""Checks whether username and password environment variables are set."""
"""
Checks whether authentication has been enabled for API queries and whether the
username and password environment variables for the graph backend have been set.

TODO: Refactor once startup events have been replaced by lifespan event
"""
check_client_id()

if (
# TODO: Check if this error is still raised when variables are empty strings
os.environ.get(util.GRAPH_USERNAME.name) is None
Expand Down
14 changes: 13 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
anyio==3.6.2
attrs==22.1.0
cachetools==5.3.3
certifi==2024.7.4
cfgv==3.3.1
coverage==7.0.0
charset-normalizer==3.3.2
click==8.1.3
colorama==0.4.6
coverage==7.0.0
distlib==0.3.6
exceptiongroup==1.0.4
fastapi==0.110.1
filelock==3.8.0
google-auth==2.32.0
h11==0.14.0
httpcore==0.16.2
httpx==0.23.1
Expand All @@ -22,15 +26,23 @@ pandas==1.5.2
platformdirs==2.5.4
pluggy==1.0.0
pre-commit==3.6.0
pyasn1==0.6.0
pyasn1_modules==0.4.0
pydantic==1.10.13
pyparsing==3.0.9
pytest==7.2.0
python-dateutil==2.8.2
pytz==2022.7
PyYAML==6.0
requests==2.32.3
rfc3986==1.5.0
rsa==4.9
six==1.16.0
sniffio==1.3.0
starlette==0.37.2
toml==0.10.2
tomli==2.0.1
typing_extensions==4.11.0
urllib3==2.2.2
uvicorn==0.20.0
virtualenv==20.16.7
33 changes: 33 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,46 @@ def test_app():
yield client


@pytest.fixture
def disable_auth(monkeypatch):
"""
Disable the authentication requirement for the API to skip startup checks
(for when the tested route does not require authentication).
"""
monkeypatch.setattr("app.api.security.AUTH_ENABLED", False)


@pytest.fixture(scope="function")
def set_test_credentials(monkeypatch):
"""Set random username and password to avoid error from startup check for set credentials."""
monkeypatch.setenv(util.GRAPH_USERNAME.name, "SomeUser")
monkeypatch.setenv(util.GRAPH_PASSWORD.name, "SomePassword")


@pytest.fixture()
def mock_verify_token():
"""Mock a successful token verification that does not raise any exceptions."""

def _verify_token(token):
return None

return _verify_token


@pytest.fixture()
def set_mock_verify_token(monkeypatch, mock_verify_token):
"""Set the verify_token function to a mock that does not raise any exceptions."""
monkeypatch.setattr(
"app.api.routers.query.verify_token", mock_verify_token
)


@pytest.fixture()
def mock_auth_header() -> dict:
"""Create an authorization header with a mock token that is well-formed for testing purposes."""
return {"Authorization": "Bearer foo"}


@pytest.fixture()
def test_data():
"""Create valid aggregate response data for two toy datasets for testing."""
Expand Down
34 changes: 25 additions & 9 deletions tests/test_app_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
from app.api import utility as util


def test_start_app_without_environment_vars_fails(test_app, monkeypatch):
@pytest.mark.filterwarnings("ignore:.*NB_API_ALLOWED_ORIGINS")
def test_start_app_without_environment_vars_fails(
test_app, monkeypatch, disable_auth
):
"""Given non-existing username and password environment variables, raises an informative RuntimeError."""
monkeypatch.delenv(util.GRAPH_USERNAME.name, raising=False)
monkeypatch.delenv(util.GRAPH_PASSWORD.name, raising=False)
Expand All @@ -24,21 +27,27 @@ def test_start_app_without_environment_vars_fails(test_app, monkeypatch):
)


def test_app_with_invalid_environment_vars(test_app, monkeypatch):
"""Given invalid environment variables, returns a 401 status code."""
@pytest.mark.filterwarnings("ignore:.*NB_API_ALLOWED_ORIGINS")
def test_app_with_invalid_environment_vars(
test_app, monkeypatch, mock_auth_header, set_mock_verify_token
):
"""Given invalid environment variables for the graph, returns a 401 status code."""
monkeypatch.setenv(util.GRAPH_USERNAME.name, "something")
monkeypatch.setenv(util.GRAPH_PASSWORD.name, "cool")

def mock_httpx_post(**kwargs):
return httpx.Response(status_code=401)

monkeypatch.setattr(httpx, "post", mock_httpx_post)
response = test_app.get("/query/")
response = test_app.get("/query/", headers=mock_auth_header)
assert response.status_code == 401


def test_app_with_unset_allowed_origins(
test_app, monkeypatch, set_test_credentials
test_app,
monkeypatch,
set_test_credentials,
disable_auth,
):
"""Tests that when the environment variable for allowed origins has not been set, a warning is raised and the app uses a default value."""
monkeypatch.delenv(util.ALLOWED_ORIGINS.name, raising=False)
Expand Down Expand Up @@ -90,6 +99,7 @@ def test_app_with_set_allowed_origins(
allowed_origins,
parsed_origins,
expectation,
disable_auth,
):
"""
Test that when the environment variable for allowed origins has been explicitly set, the app correctly parses it into a list
Expand All @@ -108,8 +118,11 @@ def test_app_with_set_allowed_origins(
)


@pytest.mark.filterwarnings("ignore:.*NB_API_ALLOWED_ORIGINS")
def test_stored_vocab_lookup_file_created_on_startup(
test_app, set_test_credentials
test_app,
set_test_credentials,
disable_auth,
):
"""Test that on startup, a non-empty temporary lookup file is created for term ID-label mappings for the locally stored SNOMED CT vocabulary."""
with test_app:
Expand All @@ -118,8 +131,9 @@ def test_stored_vocab_lookup_file_created_on_startup(
assert term_labels_path.stat().st_size > 0


@pytest.mark.filterwarnings("ignore:.*NB_API_ALLOWED_ORIGINS")
def test_external_vocab_is_fetched_on_startup(
test_app, monkeypatch, set_test_credentials
test_app, monkeypatch, set_test_credentials, disable_auth
):
"""
Tests that on startup, a GET request is made to the Cognitive Atlas API and that when the request succeeds,
Expand Down Expand Up @@ -160,8 +174,9 @@ def mock_httpx_get(**kwargs):
}


@pytest.mark.filterwarnings("ignore:.*NB_API_ALLOWED_ORIGINS")
def test_failed_vocab_fetching_on_startup_raises_warning(
test_app, monkeypatch, set_test_credentials
test_app, monkeypatch, set_test_credentials, disable_auth
):
"""
Tests that when a GET request to the Cognitive Atlas API has a non-success response code (e.g., due to service being unavailable),
Expand All @@ -186,8 +201,9 @@ def mock_httpx_get(**kwargs):
)


@pytest.mark.filterwarnings("ignore:.*NB_API_ALLOWED_ORIGINS")
def test_network_error_on_startup_raises_warning(
test_app, monkeypatch, set_test_credentials
test_app, monkeypatch, set_test_credentials, disable_auth
):
"""
Tests that when a GET request to the Cognitive Atlas API fails due to a network error (i.e., while issuing the request),
Expand Down
9 changes: 5 additions & 4 deletions tests/test_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,17 @@
from app.api import utility as util


def test_root(test_app, set_test_credentials):
def test_root(test_app):
"""Given a GET request to the root endpoint, Check for 200 status and expected content."""

with test_app:
response = test_app.get("/")
response = test_app.get("/")

assert response.status_code == 200
assert "Welcome to the Neurobagel REST API!" in response.text
assert '<a href="/docs">documentation</a>' in response.text


@pytest.mark.filterwarnings("ignore:.*NB_API_ALLOWED_ORIGINS")
@pytest.mark.parametrize(
"valid_data_element_URI",
["nb:Diagnosis", "nb:Assessment"],
Expand All @@ -28,6 +28,7 @@ def test_get_terms_valid_data_element_URI(
mock_successful_get_terms,
valid_data_element_URI,
monkeypatch,
disable_auth,
):
"""Given a valid data element URI, returns a 200 status code and a non-empty list of terms for that data element."""
monkeypatch.setattr(crud, "get_terms", mock_successful_get_terms)
Expand All @@ -54,7 +55,7 @@ def test_get_terms_invalid_data_element_URI(


def test_get_terms_for_attribute_with_vocab_lookup(
test_app, monkeypatch, set_test_credentials
test_app, monkeypatch, set_test_credentials, disable_auth
):
"""
Given a valid data element URI with a vocabulary lookup file available, returns prefixed term URIs and their human-readable labels (where found)
Expand Down
Loading