diff --git a/Dockerfile b/Dockerfile index b8a17d6..ef2b5d4 100644 --- a/Dockerfile +++ b/Dockerfile @@ -11,4 +11,4 @@ COPY ./app /usr/src/app # NB_API_PORT, representing the port on which the API will be exposed, # is an environment variable that will always have a default value of 8000 when building the image # but can be overridden when running the container. -ENTRYPOINT uvicorn app.main:app --proxy-headers --host 0.0.0.0 --port ${NB_API_PORT:-8000} +ENTRYPOINT uvicorn app.main:app --proxy-headers --forwarded-allow-ips=* --host 0.0.0.0 --port ${NB_API_PORT:-8000} diff --git a/app/api/crud.py b/app/api/crud.py index 636b431..b22237e 100644 --- a/app/api/crud.py +++ b/app/api/crud.py @@ -103,7 +103,7 @@ async def get( params["image_modal"] = image_modal tasks = [ - util.send_get_request(node_url + "query/", params) + util.send_get_request(node_url + "query", params) for node_url in node_urls ] responses = await asyncio.gather(*tasks, return_exceptions=True) diff --git a/app/api/routers/nodes.py b/app/api/routers/nodes.py index 15febb8..a088ea7 100644 --- a/app/api/routers/nodes.py +++ b/app/api/routers/nodes.py @@ -5,7 +5,7 @@ router = APIRouter(prefix="/nodes", tags=["nodes"]) -@router.get("/") +@router.get("") async def get_nodes(): """Returns a dict of all available nodes apis where key is node URL and value is node name.""" return [ diff --git a/app/api/routers/query.py b/app/api/routers/query.py index d6f0e5e..15f1e4b 100644 --- a/app/api/routers/query.py +++ b/app/api/routers/query.py @@ -34,7 +34,7 @@ # TODO: if our response model for fully successful vs. not fully successful responses grows more complex in the future, # consider additionally using https://fastapi.tiangolo.com/advanced/additional-responses/#additional-response-with-model to document # example responses for different status codes in the OpenAPI docs (less relevant for now since there is only one response model). -@router.get("/", response_model=CombinedQueryResponse) +@router.get("", response_model=CombinedQueryResponse) async def get_query( response: Response, query: QueryModel = Depends(QueryModel), diff --git a/app/main.py b/app/main.py index 6e6de54..60432db 100644 --- a/app/main.py +++ b/app/main.py @@ -38,6 +38,7 @@ async def lifespan(app: FastAPI): docs_url=None, redoc_url=None, lifespan=lifespan, + redirect_slashes=False, ) app.add_middleware( diff --git a/requirements.txt b/requirements.txt index ed881d1..6c7adc9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,7 +9,7 @@ colorama==0.4.6 coverage==7.0.0 distlib==0.3.6 exceptiongroup==1.0.4 -fastapi==0.95.2 +fastapi==0.110.1 filelock==3.8.0 google-auth==2.32.0 h11==0.14.0 @@ -44,10 +44,10 @@ rpds-py==0.13.2 rsa==4.9 six==1.16.0 sniffio==1.3.0 -starlette==0.27.0 +starlette==0.37.2 toml==0.10.2 tomli==2.0.1 -typing_extensions==4.4.0 +typing_extensions==4.12.2 urllib3==2.2.0 uvicorn==0.20.0 virtualenv==20.16.7 diff --git a/tests/conftest.py b/tests/conftest.py index 3d67611..9f9cba4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,7 +12,13 @@ def test_app(): yield client -@pytest.fixture +@pytest.fixture() +def enable_auth(monkeypatch): + """Enable the authentication requirement for the API.""" + monkeypatch.setattr("app.api.security.AUTH_ENABLED", True) + + +@pytest.fixture() def disable_auth(monkeypatch): """ Disable the authentication requirement for the API to skip startup checks @@ -62,3 +68,21 @@ async def _mock_httpx_get_with_connect_error(self, **kwargs): raise httpx.ConnectError("Some connection error") return _mock_httpx_get_with_connect_error + + +@pytest.fixture() +def mocked_single_matching_dataset_result(): + """Valid aggregate query result for a single matching dataset.""" + return { + "dataset_uuid": "http://neurobagel.org/vocab/12345", + "dataset_name": "QPN", + "dataset_portal_uri": "https://rpq-qpn.ca/en/researchers-section/databases/", + "dataset_total_subjects": 200, + "num_matching_subjects": 5, + "records_protected": True, + "subject_data": "protected", + "image_modals": [ + "http://purl.org/nidash/nidm#T1Weighted", + "http://purl.org/nidash/nidm#T2Weighted", + ], + } diff --git a/tests/test_attributes.py b/tests/test_attributes.py index 613d257..056e9e3 100644 --- a/tests/test_attributes.py +++ b/tests/test_attributes.py @@ -2,22 +2,6 @@ from fastapi import status -def test_root(test_app, set_valid_test_federation_nodes): - """Given a GET request to the root endpoint, Check for 200 status and expected content.""" - - response = test_app.get("/") - - assert response.status_code == status.HTTP_200_OK - assert all( - substring in response.text - for substring in [ - "Welcome to", - "Neurobagel", - 'documentation', - ] - ) - - def test_partially_failed_terms_fetching_handled_gracefully( test_app, monkeypatch, set_valid_test_federation_nodes, caplog ): diff --git a/tests/test_nodes.py b/tests/test_nodes.py index 56eb42d..ed159a6 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -41,7 +41,7 @@ def mock_httpx_get(**kwargs): monkeypatch.setattr(httpx, "get", mock_httpx_get) with test_app: - response = test_app.get("/nodes/") + response = test_app.get("/nodes") assert util.FEDERATION_NODES == { "https://firstpublicnode.org/": "First Public Node", "https://secondpublicnode.org/": "Second Public Node", @@ -77,7 +77,7 @@ def mock_httpx_get(**kwargs): monkeypatch.setattr(httpx, "get", mock_httpx_get) with test_app: - response = test_app.get("/nodes/") + response = test_app.get("/nodes") assert util.FEDERATION_NODES == { "https://mylocalnode.org/": "Local Node" } @@ -123,7 +123,7 @@ def mock_httpx_get(**kwargs): with pytest.warns(UserWarning) as w: with test_app: - response = test_app.get("/nodes/") + response = test_app.get("/nodes") assert util.FEDERATION_NODES == { "https://firstpublicnode.org/": "First Public Node", "https://secondpublicnode.org/": "Second Public Node", diff --git a/tests/test_query.py b/tests/test_query.py index 7693592..44f9dd9 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -12,24 +12,6 @@ def mock_token(): return "Bearer foo" -@pytest.fixture() -def mocked_single_matching_dataset_result(): - """Valid aggregate query result for a single matching dataset.""" - return { - "dataset_uuid": "http://neurobagel.org/vocab/12345", - "dataset_name": "QPN", - "dataset_portal_uri": "https://rpq-qpn.ca/en/researchers-section/databases/", - "dataset_total_subjects": 200, - "num_matching_subjects": 5, - "records_protected": True, - "subject_data": "protected", - "image_modals": [ - "http://purl.org/nidash/nidm#T1Weighted", - "http://purl.org/nidash/nidm#T2Weighted", - ], - } - - def test_partial_node_failure_responses_handled_gracefully( monkeypatch, test_app, @@ -47,7 +29,7 @@ def test_partial_node_failure_responses_handled_gracefully( async def mock_httpx_get(self, **kwargs): # The self parameter is necessary to match the signature of the method being mocked, # which is a class method of the httpx.AsyncClient class (see https://www.python-httpx.org/api/#asyncclient). - if kwargs["url"] == "https://firstpublicnode.org/query/": + if kwargs["url"] == "https://firstpublicnode.org/query": return httpx.Response( status_code=200, json=[mocked_single_matching_dataset_result] ) @@ -59,7 +41,7 @@ async def mock_httpx_get(self, **kwargs): monkeypatch.setattr(httpx.AsyncClient, "get", mock_httpx_get) response = test_app.get( - "/query/", + "/query", headers={"Authorization": mock_token}, ) @@ -127,7 +109,7 @@ def test_partial_node_request_failures_handled_gracefully( """ async def mock_httpx_get(self, **kwargs): - if kwargs["url"] == "https://firstpublicnode.org/query/": + if kwargs["url"] == "https://firstpublicnode.org/query": return httpx.Response( status_code=200, json=[mocked_single_matching_dataset_result] ) @@ -137,7 +119,7 @@ async def mock_httpx_get(self, **kwargs): monkeypatch.setattr(httpx.AsyncClient, "get", mock_httpx_get) response = test_app.get( - "/query/", + "/query", headers={"Authorization": mock_token}, ) @@ -183,7 +165,7 @@ def test_all_nodes_failure_handled_gracefully( ) response = test_app.get( - "/query/", + "/query", headers={"Authorization": mock_token}, ) @@ -225,7 +207,7 @@ async def mock_httpx_get(self, **kwargs): monkeypatch.setattr(httpx.AsyncClient, "get", mock_httpx_get) response = test_app.get( - "/query/", + "/query", headers={"Authorization": mock_token}, ) @@ -256,6 +238,6 @@ async def mock_httpx_get(self, **kwargs): monkeypatch.setattr(httpx.AsyncClient, "get", mock_httpx_get) - response = test_app.get("/query/") + response = test_app.get("/query") assert response.status_code == status.HTTP_200_OK diff --git a/tests/test_routing.py b/tests/test_routing.py new file mode 100644 index 0000000..471e864 --- /dev/null +++ b/tests/test_routing.py @@ -0,0 +1,63 @@ +import httpx +import pytest +from fastapi import status + + +@pytest.mark.parametrize( + "root_path", + ["/", ""], +) +def test_root(test_app, set_valid_test_federation_nodes, root_path): + """Given a GET request to the root endpoint, Check for 200 status and expected content.""" + + response = test_app.get(root_path, follow_redirects=False) + + assert response.status_code == status.HTTP_200_OK + assert all( + substring in response.text + for substring in [ + "Welcome to", + "Neurobagel", + 'documentation', + ] + ) + + +@pytest.mark.parametrize( + "valid_route", + ["/query", "/query?min_age=20", "/nodes"], +) +def test_request_without_trailing_slash_not_redirected( + test_app, + monkeypatch, + set_valid_test_federation_nodes, + mocked_single_matching_dataset_result, + disable_auth, + valid_route, +): + """Test that a request to a route without a / is not redirected to have a trailing slash.""" + + async def mock_httpx_get(self, **kwargs): + return httpx.Response( + status_code=200, json=[mocked_single_matching_dataset_result] + ) + + monkeypatch.setattr(httpx.AsyncClient, "get", mock_httpx_get) + + response = test_app.get(valid_route, follow_redirects=False) + assert response.status_code == status.HTTP_200_OK + + +@pytest.mark.parametrize( + "invalid_route", + ["/query/", "/query/?min_age=20", "/nodes/", "/attributes/nb:SomeClass/"], +) +def test_request_including_trailing_slash_fails( + test_app, disable_auth, invalid_route +): + """ + Test that a request to routes including a trailing slash, where none is expected, + is *not* redirected to exclude the slash, and returns a 404. + """ + response = test_app.get(invalid_route) + assert response.status_code == status.HTTP_404_NOT_FOUND diff --git a/tests/test_security.py b/tests/test_security.py index b657168..eb1e326 100644 --- a/tests/test_security.py +++ b/tests/test_security.py @@ -5,14 +5,13 @@ def test_missing_client_id_raises_error_when_auth_enabled( - monkeypatch, test_app + monkeypatch, test_app, enable_auth ): """Test that a missing client ID raises an error on startup when authentication is enabled.""" # We're using what should be default values of CLIENT_ID and AUTH_ENABLED here # (if the corresponding environment variables are unset), # but we set the values explicitly here for clarity monkeypatch.setattr("app.api.security.CLIENT_ID", None) - monkeypatch.setattr("app.api.security.AUTH_ENABLED", True) with pytest.raises(ValueError) as exc_info: with test_app: @@ -52,12 +51,20 @@ def test_invalid_token_raises_error(invalid_token): [{}, {"Authorization": ""}, {"badheader": "badvalue"}], ) def test_query_with_malformed_auth_header_fails( - test_app, set_mock_verify_token, invalid_auth_header + test_app, + set_mock_verify_token, + enable_auth, + invalid_auth_header, + monkeypatch, ): - """Test that a request to the /query route with a missing or malformed authorization header, fails .""" + """ + Test that when authentication is enabled, a request to the /query route with a + missing or malformed authorization header fails. + """ + monkeypatch.setattr("app.api.security.CLIENT_ID", "foo.id") response = test_app.get( - "/query/", + "/query", headers=invalid_auth_header, )