diff --git a/Dockerfile b/Dockerfile
index b8a17d6..ef2b5d4 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -11,4 +11,4 @@ COPY ./app /usr/src/app
# NB_API_PORT, representing the port on which the API will be exposed,
# is an environment variable that will always have a default value of 8000 when building the image
# but can be overridden when running the container.
-ENTRYPOINT uvicorn app.main:app --proxy-headers --host 0.0.0.0 --port ${NB_API_PORT:-8000}
+ENTRYPOINT uvicorn app.main:app --proxy-headers --forwarded-allow-ips=* --host 0.0.0.0 --port ${NB_API_PORT:-8000}
diff --git a/app/api/crud.py b/app/api/crud.py
index 636b431..b22237e 100644
--- a/app/api/crud.py
+++ b/app/api/crud.py
@@ -103,7 +103,7 @@ async def get(
params["image_modal"] = image_modal
tasks = [
- util.send_get_request(node_url + "query/", params)
+ util.send_get_request(node_url + "query", params)
for node_url in node_urls
]
responses = await asyncio.gather(*tasks, return_exceptions=True)
diff --git a/app/api/routers/nodes.py b/app/api/routers/nodes.py
index 15febb8..a088ea7 100644
--- a/app/api/routers/nodes.py
+++ b/app/api/routers/nodes.py
@@ -5,7 +5,7 @@
router = APIRouter(prefix="/nodes", tags=["nodes"])
-@router.get("/")
+@router.get("")
async def get_nodes():
"""Returns a dict of all available nodes apis where key is node URL and value is node name."""
return [
diff --git a/app/api/routers/query.py b/app/api/routers/query.py
index d6f0e5e..15f1e4b 100644
--- a/app/api/routers/query.py
+++ b/app/api/routers/query.py
@@ -34,7 +34,7 @@
# TODO: if our response model for fully successful vs. not fully successful responses grows more complex in the future,
# consider additionally using https://fastapi.tiangolo.com/advanced/additional-responses/#additional-response-with-model to document
# example responses for different status codes in the OpenAPI docs (less relevant for now since there is only one response model).
-@router.get("/", response_model=CombinedQueryResponse)
+@router.get("", response_model=CombinedQueryResponse)
async def get_query(
response: Response,
query: QueryModel = Depends(QueryModel),
diff --git a/app/main.py b/app/main.py
index 6e6de54..60432db 100644
--- a/app/main.py
+++ b/app/main.py
@@ -38,6 +38,7 @@ async def lifespan(app: FastAPI):
docs_url=None,
redoc_url=None,
lifespan=lifespan,
+ redirect_slashes=False,
)
app.add_middleware(
diff --git a/requirements.txt b/requirements.txt
index ed881d1..6c7adc9 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,7 +9,7 @@ colorama==0.4.6
coverage==7.0.0
distlib==0.3.6
exceptiongroup==1.0.4
-fastapi==0.95.2
+fastapi==0.110.1
filelock==3.8.0
google-auth==2.32.0
h11==0.14.0
@@ -44,10 +44,10 @@ rpds-py==0.13.2
rsa==4.9
six==1.16.0
sniffio==1.3.0
-starlette==0.27.0
+starlette==0.37.2
toml==0.10.2
tomli==2.0.1
-typing_extensions==4.4.0
+typing_extensions==4.12.2
urllib3==2.2.0
uvicorn==0.20.0
virtualenv==20.16.7
diff --git a/tests/conftest.py b/tests/conftest.py
index 3d67611..9f9cba4 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -12,7 +12,13 @@ def test_app():
yield client
-@pytest.fixture
+@pytest.fixture()
+def enable_auth(monkeypatch):
+ """Enable the authentication requirement for the API."""
+ monkeypatch.setattr("app.api.security.AUTH_ENABLED", True)
+
+
+@pytest.fixture()
def disable_auth(monkeypatch):
"""
Disable the authentication requirement for the API to skip startup checks
@@ -62,3 +68,21 @@ async def _mock_httpx_get_with_connect_error(self, **kwargs):
raise httpx.ConnectError("Some connection error")
return _mock_httpx_get_with_connect_error
+
+
+@pytest.fixture()
+def mocked_single_matching_dataset_result():
+ """Valid aggregate query result for a single matching dataset."""
+ return {
+ "dataset_uuid": "http://neurobagel.org/vocab/12345",
+ "dataset_name": "QPN",
+ "dataset_portal_uri": "https://rpq-qpn.ca/en/researchers-section/databases/",
+ "dataset_total_subjects": 200,
+ "num_matching_subjects": 5,
+ "records_protected": True,
+ "subject_data": "protected",
+ "image_modals": [
+ "http://purl.org/nidash/nidm#T1Weighted",
+ "http://purl.org/nidash/nidm#T2Weighted",
+ ],
+ }
diff --git a/tests/test_attributes.py b/tests/test_attributes.py
index 613d257..056e9e3 100644
--- a/tests/test_attributes.py
+++ b/tests/test_attributes.py
@@ -2,22 +2,6 @@
from fastapi import status
-def test_root(test_app, set_valid_test_federation_nodes):
- """Given a GET request to the root endpoint, Check for 200 status and expected content."""
-
- response = test_app.get("/")
-
- assert response.status_code == status.HTTP_200_OK
- assert all(
- substring in response.text
- for substring in [
- "Welcome to",
- "Neurobagel",
- 'documentation',
- ]
- )
-
-
def test_partially_failed_terms_fetching_handled_gracefully(
test_app, monkeypatch, set_valid_test_federation_nodes, caplog
):
diff --git a/tests/test_nodes.py b/tests/test_nodes.py
index 56eb42d..ed159a6 100644
--- a/tests/test_nodes.py
+++ b/tests/test_nodes.py
@@ -41,7 +41,7 @@ def mock_httpx_get(**kwargs):
monkeypatch.setattr(httpx, "get", mock_httpx_get)
with test_app:
- response = test_app.get("/nodes/")
+ response = test_app.get("/nodes")
assert util.FEDERATION_NODES == {
"https://firstpublicnode.org/": "First Public Node",
"https://secondpublicnode.org/": "Second Public Node",
@@ -77,7 +77,7 @@ def mock_httpx_get(**kwargs):
monkeypatch.setattr(httpx, "get", mock_httpx_get)
with test_app:
- response = test_app.get("/nodes/")
+ response = test_app.get("/nodes")
assert util.FEDERATION_NODES == {
"https://mylocalnode.org/": "Local Node"
}
@@ -123,7 +123,7 @@ def mock_httpx_get(**kwargs):
with pytest.warns(UserWarning) as w:
with test_app:
- response = test_app.get("/nodes/")
+ response = test_app.get("/nodes")
assert util.FEDERATION_NODES == {
"https://firstpublicnode.org/": "First Public Node",
"https://secondpublicnode.org/": "Second Public Node",
diff --git a/tests/test_query.py b/tests/test_query.py
index 7693592..44f9dd9 100644
--- a/tests/test_query.py
+++ b/tests/test_query.py
@@ -12,24 +12,6 @@ def mock_token():
return "Bearer foo"
-@pytest.fixture()
-def mocked_single_matching_dataset_result():
- """Valid aggregate query result for a single matching dataset."""
- return {
- "dataset_uuid": "http://neurobagel.org/vocab/12345",
- "dataset_name": "QPN",
- "dataset_portal_uri": "https://rpq-qpn.ca/en/researchers-section/databases/",
- "dataset_total_subjects": 200,
- "num_matching_subjects": 5,
- "records_protected": True,
- "subject_data": "protected",
- "image_modals": [
- "http://purl.org/nidash/nidm#T1Weighted",
- "http://purl.org/nidash/nidm#T2Weighted",
- ],
- }
-
-
def test_partial_node_failure_responses_handled_gracefully(
monkeypatch,
test_app,
@@ -47,7 +29,7 @@ def test_partial_node_failure_responses_handled_gracefully(
async def mock_httpx_get(self, **kwargs):
# The self parameter is necessary to match the signature of the method being mocked,
# which is a class method of the httpx.AsyncClient class (see https://www.python-httpx.org/api/#asyncclient).
- if kwargs["url"] == "https://firstpublicnode.org/query/":
+ if kwargs["url"] == "https://firstpublicnode.org/query":
return httpx.Response(
status_code=200, json=[mocked_single_matching_dataset_result]
)
@@ -59,7 +41,7 @@ async def mock_httpx_get(self, **kwargs):
monkeypatch.setattr(httpx.AsyncClient, "get", mock_httpx_get)
response = test_app.get(
- "/query/",
+ "/query",
headers={"Authorization": mock_token},
)
@@ -127,7 +109,7 @@ def test_partial_node_request_failures_handled_gracefully(
"""
async def mock_httpx_get(self, **kwargs):
- if kwargs["url"] == "https://firstpublicnode.org/query/":
+ if kwargs["url"] == "https://firstpublicnode.org/query":
return httpx.Response(
status_code=200, json=[mocked_single_matching_dataset_result]
)
@@ -137,7 +119,7 @@ async def mock_httpx_get(self, **kwargs):
monkeypatch.setattr(httpx.AsyncClient, "get", mock_httpx_get)
response = test_app.get(
- "/query/",
+ "/query",
headers={"Authorization": mock_token},
)
@@ -183,7 +165,7 @@ def test_all_nodes_failure_handled_gracefully(
)
response = test_app.get(
- "/query/",
+ "/query",
headers={"Authorization": mock_token},
)
@@ -225,7 +207,7 @@ async def mock_httpx_get(self, **kwargs):
monkeypatch.setattr(httpx.AsyncClient, "get", mock_httpx_get)
response = test_app.get(
- "/query/",
+ "/query",
headers={"Authorization": mock_token},
)
@@ -256,6 +238,6 @@ async def mock_httpx_get(self, **kwargs):
monkeypatch.setattr(httpx.AsyncClient, "get", mock_httpx_get)
- response = test_app.get("/query/")
+ response = test_app.get("/query")
assert response.status_code == status.HTTP_200_OK
diff --git a/tests/test_routing.py b/tests/test_routing.py
new file mode 100644
index 0000000..471e864
--- /dev/null
+++ b/tests/test_routing.py
@@ -0,0 +1,63 @@
+import httpx
+import pytest
+from fastapi import status
+
+
+@pytest.mark.parametrize(
+ "root_path",
+ ["/", ""],
+)
+def test_root(test_app, set_valid_test_federation_nodes, root_path):
+ """Given a GET request to the root endpoint, Check for 200 status and expected content."""
+
+ response = test_app.get(root_path, follow_redirects=False)
+
+ assert response.status_code == status.HTTP_200_OK
+ assert all(
+ substring in response.text
+ for substring in [
+ "Welcome to",
+ "Neurobagel",
+ 'documentation',
+ ]
+ )
+
+
+@pytest.mark.parametrize(
+ "valid_route",
+ ["/query", "/query?min_age=20", "/nodes"],
+)
+def test_request_without_trailing_slash_not_redirected(
+ test_app,
+ monkeypatch,
+ set_valid_test_federation_nodes,
+ mocked_single_matching_dataset_result,
+ disable_auth,
+ valid_route,
+):
+ """Test that a request to a route without a / is not redirected to have a trailing slash."""
+
+ async def mock_httpx_get(self, **kwargs):
+ return httpx.Response(
+ status_code=200, json=[mocked_single_matching_dataset_result]
+ )
+
+ monkeypatch.setattr(httpx.AsyncClient, "get", mock_httpx_get)
+
+ response = test_app.get(valid_route, follow_redirects=False)
+ assert response.status_code == status.HTTP_200_OK
+
+
+@pytest.mark.parametrize(
+ "invalid_route",
+ ["/query/", "/query/?min_age=20", "/nodes/", "/attributes/nb:SomeClass/"],
+)
+def test_request_including_trailing_slash_fails(
+ test_app, disable_auth, invalid_route
+):
+ """
+ Test that a request to routes including a trailing slash, where none is expected,
+ is *not* redirected to exclude the slash, and returns a 404.
+ """
+ response = test_app.get(invalid_route)
+ assert response.status_code == status.HTTP_404_NOT_FOUND
diff --git a/tests/test_security.py b/tests/test_security.py
index b657168..eb1e326 100644
--- a/tests/test_security.py
+++ b/tests/test_security.py
@@ -5,14 +5,13 @@
def test_missing_client_id_raises_error_when_auth_enabled(
- monkeypatch, test_app
+ monkeypatch, test_app, enable_auth
):
"""Test that a missing client ID raises an error on startup when authentication is enabled."""
# We're using what should be default values of CLIENT_ID and AUTH_ENABLED here
# (if the corresponding environment variables are unset),
# but we set the values explicitly here for clarity
monkeypatch.setattr("app.api.security.CLIENT_ID", None)
- monkeypatch.setattr("app.api.security.AUTH_ENABLED", True)
with pytest.raises(ValueError) as exc_info:
with test_app:
@@ -52,12 +51,20 @@ def test_invalid_token_raises_error(invalid_token):
[{}, {"Authorization": ""}, {"badheader": "badvalue"}],
)
def test_query_with_malformed_auth_header_fails(
- test_app, set_mock_verify_token, invalid_auth_header
+ test_app,
+ set_mock_verify_token,
+ enable_auth,
+ invalid_auth_header,
+ monkeypatch,
):
- """Test that a request to the /query route with a missing or malformed authorization header, fails ."""
+ """
+ Test that when authentication is enabled, a request to the /query route with a
+ missing or malformed authorization header fails.
+ """
+ monkeypatch.setattr("app.api.security.CLIENT_ID", "foo.id")
response = test_app.get(
- "/query/",
+ "/query",
headers=invalid_auth_header,
)