Skip to content

Commit

Permalink
fix: handling errors for non-existing endpoints (#4325)
Browse files Browse the repository at this point in the history
<!-- Thanks for your contribution! As part of our Community Growers
initiative 🌱, we're donating Justdiggit bunds in your name to reforest
sub-Saharan Africa. To claim your Community Growers certificate, please
contact David Berenstein in our Slack community or fill in this form
https://tally.so/r/n9XrxK once your PR has been merged. -->

# Description

This PR adds a callback controller to return a proper status code and
error message when a client requests a non-existing endpoint.

This change will help identify incompatibility problems when using new
clients over old server instances as suggested in #4073

Closes #4073

**Type of change**

(Please delete options that are not relevant. Remember to title the PR
according to the type of change)

- [X] Bug fix (non-breaking change which fixes an issue)
- [ ] Breaking change (fix or feature that would cause existing
functionality to not work as expected)

**How Has This Been Tested**

Calling unexisting endpoints locally 

**Checklist**

- [X] I followed the style guidelines of this project
- [X] I did a self-review of my code
- [X] My changes generate no new warnings
- [X] I have added tests that prove my fix is effective or that my
feature works
- [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK)
(see text above)
- [ ] I have added relevant notes to the `CHANGELOG.md` file (See
https://keepachangelog.com/)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
frascuchon and pre-commit-ci[bot] authored Nov 27, 2023
1 parent 3c78e5a commit 4c3c4d5
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 26 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ These are the section headers that we use:
- Fixed error in `ArgillaTrainer`, now we can train for `extractive_question_answering` using a validation sample ([#4204](https://github.com/argilla-io/argilla/pull/4204))
- Fixed error in `ArgillaTrainer`, when training for `sentence-similarity` it didn't work with a list of values per record ([#4211](https://github.com/argilla-io/argilla/pull/4211))
- Fixed error in the unification strategy for `RankingQuestion` ([#4295](https://github.com/argilla-io/argilla/pull/4295))
- Fixed error when requesting non-existing API endpoints. Closes [#4073](https://github.com/argilla-io/argilla/issues/4073) ([#4325](https://github.com/argilla-io/argilla/pull/4325))

### Changed

Expand Down
7 changes: 6 additions & 1 deletion src/argilla/server/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
set the required security dependencies if api security is enabled
"""

from fastapi import APIRouter
from fastapi import APIRouter, HTTPException, Request

from argilla.server.apis.v0.handlers import (
datasets,
Expand Down Expand Up @@ -76,3 +76,8 @@
api_router.include_router(users_v1.router, prefix="/v1")
api_router.include_router(vectors_settings_v1.router, prefix="/v1")
api_router.include_router(workspaces_v1.router, prefix="/v1")


@api_router.route("/{_:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"], include_in_schema=False)
def endpoint_not_found_controller(request: Request):
raise HTTPException(status_code=404, detail=f"Endpoint {request.url.path!r} not found")
25 changes: 8 additions & 17 deletions src/argilla/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,11 +139,7 @@ def _create_statics_folder(path_from):

app.mount(
"/",
RewriteStaticFiles(
directory=temp_statics,
html=True,
check_dir=False,
),
RewriteStaticFiles(directory=temp_statics, html=True, check_dir=False),
name="static",
)

Expand Down Expand Up @@ -241,7 +237,7 @@ async def log_default_user_warning_if_present():
_log_default_user_warning()


argilla_app = FastAPI(
app = FastAPI(
title="Argilla",
description="Argilla API",
# Disable default openapi configuration
Expand All @@ -252,30 +248,25 @@ async def log_default_user_warning_if_present():
)


@argilla_app.get("/docs", include_in_schema=False)
@app.get("/docs", include_in_schema=False)
async def redirect_docs():
return RedirectResponse(url=f"{settings.base_url}api/docs")


@argilla_app.get("/api", include_in_schema=False)
@app.get("/api", include_in_schema=False)
async def redirect_api():
return RedirectResponse(url=f"{settings.base_url}api/docs")


app = FastAPI(docs_url=None)
app.mount("/", argilla_app)

configure_app_logging(app)
configure_database(app)
configure_storage(app)
configure_telemetry(app)

for app_configure in [
configure_app_logging,
configure_database,
configure_storage,
configure_telemetry,
configure_middleware,
configure_api_exceptions,
configure_app_security,
configure_api_router,
configure_app_statics,
]:
app_configure(argilla_app)
app_configure(app)
6 changes: 3 additions & 3 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from argilla.datasets.__init__ import configure_dataset
from argilla.server.database import get_async_db
from argilla.server.models import User, UserRole, Workspace
from argilla.server.server import app, argilla_app
from argilla.server.server import app
from argilla.server.settings import settings
from argilla.utils import telemetry
from argilla.utils.telemetry import TelemetryClient
Expand Down Expand Up @@ -134,13 +134,13 @@ async def override_get_async_db():

mocker.patch("argilla.server.server._get_db_wrapper", wraps=contextlib.asynccontextmanager(override_get_async_db))

argilla_app.dependency_overrides[get_async_db] = override_get_async_db
app.dependency_overrides[get_async_db] = override_get_async_db

raise_server_exceptions = request.param if hasattr(request, "param") else False
with TestClient(app, raise_server_exceptions=raise_server_exceptions) as client:
yield client

argilla_app.dependency_overrides.clear()
app.dependency_overrides.clear()


@pytest.fixture(scope="session")
Expand Down
25 changes: 25 additions & 0 deletions tests/unit/server/api/test_not_found_routes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
from httpx import AsyncClient


@pytest.mark.asyncio
@pytest.mark.parametrize("http_method", ["GET", "POST", "PUT", "DELETE", "PATCH"])
async def test_route_not_found_response(async_client: AsyncClient, http_method: str):
response = await async_client.request(method=http_method, url="/api/not/found/route")

assert response.status_code == 404
assert response.json() == {"detail": "Endpoint '/api/not/found/route' not found"}
10 changes: 5 additions & 5 deletions tests/unit/server/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from argilla.server.database import get_async_db
from argilla.server.models import User, UserRole, Workspace
from argilla.server.search_engine import OpenSearchEngine, SearchEngine, get_search_engine
from argilla.server.server import argilla_app
from argilla.server.server import app
from argilla.server.services.datasets import DatasetsService
from argilla.server.settings import settings
from argilla.utils import telemetry
Expand Down Expand Up @@ -88,13 +88,13 @@ async def override_get_search_engine():

mocker.patch("argilla.server.server._get_db_wrapper", wraps=contextlib.asynccontextmanager(override_get_async_db))

argilla_app.dependency_overrides[get_async_db] = override_get_async_db
argilla_app.dependency_overrides[get_search_engine] = override_get_search_engine
app.dependency_overrides[get_async_db] = override_get_async_db
app.dependency_overrides[get_search_engine] = override_get_search_engine

async with AsyncClient(app=argilla_app, base_url="http://testserver") as async_client:
async with AsyncClient(app=app, base_url="http://testserver") as async_client:
yield async_client

argilla_app.dependency_overrides.clear()
app.dependency_overrides.clear()


@pytest.fixture(autouse=True)
Expand Down

0 comments on commit 4c3c4d5

Please sign in to comment.