diff --git a/CHANGELOG.md b/CHANGELOG.md index 6a8c999704..d3be59f63f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,7 @@ These are the section headers that we use: - Fixed error in `ArgillaTrainer`, now we can train for `extractive_question_answering` using a validation sample ([#4204](https://github.com/argilla-io/argilla/pull/4204)) - Fixed error in `ArgillaTrainer`, when training for `sentence-similarity` it didn't work with a list of values per record ([#4211](https://github.com/argilla-io/argilla/pull/4211)) - Fixed error in the unification strategy for `RankingQuestion` ([#4295](https://github.com/argilla-io/argilla/pull/4295)) +- Fixed error when requesting non-existing API endpoints. Closes [#4073](https://github.com/argilla-io/argilla/issues/4073) ([#4325](https://github.com/argilla-io/argilla/pull/4325)) ### Changed diff --git a/src/argilla/server/routes.py b/src/argilla/server/routes.py index 77b6a6dac2..ff24a42888 100644 --- a/src/argilla/server/routes.py +++ b/src/argilla/server/routes.py @@ -18,7 +18,7 @@ set the required security dependencies if api security is enabled """ -from fastapi import APIRouter +from fastapi import APIRouter, HTTPException, Request from argilla.server.apis.v0.handlers import ( datasets, @@ -76,3 +76,8 @@ api_router.include_router(users_v1.router, prefix="/v1") api_router.include_router(vectors_settings_v1.router, prefix="/v1") api_router.include_router(workspaces_v1.router, prefix="/v1") + + +@api_router.route("/{_:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"], include_in_schema=False) +def endpoint_not_found_controller(request: Request): + raise HTTPException(status_code=404, detail=f"Endpoint {request.url.path!r} not found") diff --git a/src/argilla/server/server.py b/src/argilla/server/server.py index 3e0b86c24a..b837c2e102 100644 --- a/src/argilla/server/server.py +++ b/src/argilla/server/server.py @@ -139,11 +139,7 @@ def _create_statics_folder(path_from): app.mount( "/", - RewriteStaticFiles( - directory=temp_statics, - html=True, - check_dir=False, - ), + RewriteStaticFiles(directory=temp_statics, html=True, check_dir=False), name="static", ) @@ -241,7 +237,7 @@ async def log_default_user_warning_if_present(): _log_default_user_warning() -argilla_app = FastAPI( +app = FastAPI( title="Argilla", description="Argilla API", # Disable default openapi configuration @@ -252,30 +248,25 @@ async def log_default_user_warning_if_present(): ) -@argilla_app.get("/docs", include_in_schema=False) +@app.get("/docs", include_in_schema=False) async def redirect_docs(): return RedirectResponse(url=f"{settings.base_url}api/docs") -@argilla_app.get("/api", include_in_schema=False) +@app.get("/api", include_in_schema=False) async def redirect_api(): return RedirectResponse(url=f"{settings.base_url}api/docs") -app = FastAPI(docs_url=None) -app.mount("/", argilla_app) - -configure_app_logging(app) -configure_database(app) -configure_storage(app) -configure_telemetry(app) - for app_configure in [ configure_app_logging, + configure_database, + configure_storage, + configure_telemetry, configure_middleware, configure_api_exceptions, configure_app_security, configure_api_router, configure_app_statics, ]: - app_configure(argilla_app) + app_configure(app) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index f712aa81dd..d9eec3c220 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -30,7 +30,7 @@ from argilla.datasets.__init__ import configure_dataset from argilla.server.database import get_async_db from argilla.server.models import User, UserRole, Workspace -from argilla.server.server import app, argilla_app +from argilla.server.server import app from argilla.server.settings import settings from argilla.utils import telemetry from argilla.utils.telemetry import TelemetryClient @@ -134,13 +134,13 @@ async def override_get_async_db(): mocker.patch("argilla.server.server._get_db_wrapper", wraps=contextlib.asynccontextmanager(override_get_async_db)) - argilla_app.dependency_overrides[get_async_db] = override_get_async_db + app.dependency_overrides[get_async_db] = override_get_async_db raise_server_exceptions = request.param if hasattr(request, "param") else False with TestClient(app, raise_server_exceptions=raise_server_exceptions) as client: yield client - argilla_app.dependency_overrides.clear() + app.dependency_overrides.clear() @pytest.fixture(scope="session") diff --git a/tests/unit/server/api/test_not_found_routes.py b/tests/unit/server/api/test_not_found_routes.py new file mode 100644 index 0000000000..ba4c6f1ced --- /dev/null +++ b/tests/unit/server/api/test_not_found_routes.py @@ -0,0 +1,25 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from httpx import AsyncClient + + +@pytest.mark.asyncio +@pytest.mark.parametrize("http_method", ["GET", "POST", "PUT", "DELETE", "PATCH"]) +async def test_route_not_found_response(async_client: AsyncClient, http_method: str): + response = await async_client.request(method=http_method, url="/api/not/found/route") + + assert response.status_code == 404 + assert response.json() == {"detail": "Endpoint '/api/not/found/route' not found"} diff --git a/tests/unit/server/conftest.py b/tests/unit/server/conftest.py index f8717e07d5..373efef48e 100644 --- a/tests/unit/server/conftest.py +++ b/tests/unit/server/conftest.py @@ -24,7 +24,7 @@ from argilla.server.database import get_async_db from argilla.server.models import User, UserRole, Workspace from argilla.server.search_engine import OpenSearchEngine, SearchEngine, get_search_engine -from argilla.server.server import argilla_app +from argilla.server.server import app from argilla.server.services.datasets import DatasetsService from argilla.server.settings import settings from argilla.utils import telemetry @@ -88,13 +88,13 @@ async def override_get_search_engine(): mocker.patch("argilla.server.server._get_db_wrapper", wraps=contextlib.asynccontextmanager(override_get_async_db)) - argilla_app.dependency_overrides[get_async_db] = override_get_async_db - argilla_app.dependency_overrides[get_search_engine] = override_get_search_engine + app.dependency_overrides[get_async_db] = override_get_async_db + app.dependency_overrides[get_search_engine] = override_get_search_engine - async with AsyncClient(app=argilla_app, base_url="http://testserver") as async_client: + async with AsyncClient(app=app, base_url="http://testserver") as async_client: yield async_client - argilla_app.dependency_overrides.clear() + app.dependency_overrides.clear() @pytest.fixture(autouse=True)