From 95cd02729bbd259c924ec1e68ece236ad4baa60f Mon Sep 17 00:00:00 2001 From: Bruno Lenzi Date: Mon, 7 Aug 2023 20:08:47 +0200 Subject: [PATCH 1/2] feat: add limit and offset parameters to routes that fetch lists from db - add limit and offset parameters with descriptions, to all routes that fetch lists from db (fix #243, #250) - use crud.fetch_all in all the routes above, to limit the number of items returned for non-admin users (fix #276) - forbid /installation/site-devices to access site outside of group, instead of filtering devices - improve test_installations.py (test_get_active_devices_on_site) - Dockerfile-dev: update FROM image name --- src/Dockerfile-dev | 2 +- src/app/api/crud/base.py | 12 +++- src/app/api/endpoints/accesses.py | 13 ++-- src/app/api/endpoints/alerts.py | 55 +++++++--------- src/app/api/endpoints/devices.py | 33 ++++++---- src/app/api/endpoints/events.py | 87 ++++++++++++++------------ src/app/api/endpoints/groups.py | 12 ++-- src/app/api/endpoints/installations.py | 68 +++++++++++--------- src/app/api/endpoints/media.py | 24 ++++--- src/app/api/endpoints/notifications.py | 12 ++-- src/app/api/endpoints/recipients.py | 20 ++++-- src/app/api/endpoints/sites.py | 20 +++--- src/app/api/endpoints/users.py | 24 ++++--- src/app/api/endpoints/webhooks.py | 13 ++-- src/tests/routes/test_groups.py | 26 ++++++-- src/tests/routes/test_installations.py | 3 +- 16 files changed, 254 insertions(+), 170 deletions(-) diff --git a/src/Dockerfile-dev b/src/Dockerfile-dev index 10ce2069..3d13466a 100644 --- a/src/Dockerfile-dev +++ b/src/Dockerfile-dev @@ -1,4 +1,4 @@ -FROM pyroapi:python3.8-alpine3.10 +FROM pyronear/pyro-api:python3.8-alpine3.10 # copy requirements file COPY requirements-dev.txt requirements-dev.txt diff --git a/src/app/api/crud/base.py b/src/app/api/crud/base.py index 6a19804f..5a3f26b4 100644 --- a/src/app/api/crud/base.py +++ b/src/app/api/crud/base.py @@ -8,6 +8,8 @@ from fastapi import HTTPException, Path, status from pydantic import BaseModel from sqlalchemy import Table +from sqlalchemy.orm import Query +from sqlalchemy.sql import Select from app.db import database @@ -41,8 +43,11 @@ async def fetch_all( query_filters: Optional[Dict[str, Any]] = None, exclusions: Optional[Dict[str, Any]] = None, limit: int = 50, + offset: Optional[int] = None, + query: Optional[Select] = None, ) -> List[Mapping[str, Any]]: - query = table.select().order_by(table.c.id.desc()) + if query is None: + query = table.select() if isinstance(query_filters, dict): for key, value in query_filters.items(): query = query.where(getattr(table.c, key) == value) @@ -50,7 +55,10 @@ async def fetch_all( if isinstance(exclusions, dict): for key, value in exclusions.items(): query = query.where(getattr(table.c, key) != value) - return (await database.fetch_all(query=query.limit(limit)))[::-1] + query = query.order_by(table.c.id.desc()).limit(limit).offset(offset) + if isinstance(query, Query): + return [item.__dict__ for item in query[::-1]] + return (await database.fetch_all(query=query))[::-1] async def fetch_one(table: Table, query_filters: Dict[str, Any]) -> Optional[Mapping[str, Any]]: diff --git a/src/app/api/endpoints/accesses.py b/src/app/api/endpoints/accesses.py index 0d861b91..79cfffe2 100644 --- a/src/app/api/endpoints/accesses.py +++ b/src/app/api/endpoints/accesses.py @@ -3,9 +3,10 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. -from typing import List +from typing import List, Optional -from fastapi import APIRouter, Path, Security +from fastapi import APIRouter, Path, Query, Security +from typing_extensions import Annotated from app.api import crud from app.api.deps import get_current_access @@ -25,8 +26,12 @@ async def get_access(access_id: int = Path(..., gt=0), _=Security(get_current_ac @router.get("/", response_model=List[AccessRead], summary="Get the list of all accesses") -async def fetch_accesses(_=Security(get_current_access, scopes=[AccessType.admin])): +async def fetch_accesses( + limit: Annotated[int, Query(description="maximum number of items", ge=1, le=1000)] = 50, + offset: Annotated[Optional[int], Query(description="number of items to skip", ge=0)] = None, + _=Security(get_current_access, scopes=[AccessType.admin]), +): """ Retrieves the list of all accesses and their information """ - return await crud.fetch_all(accesses) + return await crud.fetch_all(accesses, limit=limit, offset=offset) diff --git a/src/app/api/endpoints/alerts.py b/src/app/api/endpoints/alerts.py index f006e495..95116a73 100644 --- a/src/app/api/endpoints/alerts.py +++ b/src/app/api/endpoints/alerts.py @@ -5,10 +5,10 @@ from functools import partial from string import Template -from typing import List, cast +from typing import List, Optional, cast -from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Path, Security, status -from sqlalchemy import select +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Path, Query, Security, status +from typing_extensions import Annotated from app.api import crud from app.api.crud.authorizations import check_group_read, is_admin_access @@ -18,7 +18,7 @@ from app.api.endpoints.notifications import send_notification from app.api.endpoints.recipients import fetch_recipients_for_group from app.api.external import post_request -from app.db import alerts, events, media +from app.db import alerts, media from app.models import Access, AccessType, Alert, Device, Event from app.schemas import AlertBase, AlertIn, AlertOut, DeviceOut, NotificationIn, RecipientOut @@ -123,19 +123,22 @@ async def get_alert( @router.get("/", response_model=List[AlertOut], summary="Get the list of all alerts") async def fetch_alerts( - requester=Security(get_current_access, scopes=[AccessType.admin, AccessType.user]), session=Depends(get_db) + limit: Annotated[int, Query(description="maximum number of items", ge=1, le=1000)] = 50, + offset: Annotated[Optional[int], Query(description="number of items to skip", ge=0)] = None, + requester=Security(get_current_access, scopes=[AccessType.admin, AccessType.user]), + session=Depends(get_db), ): """ Retrieves the list of all alerts and their information """ - if await is_admin_access(requester.id): - return await crud.fetch_all(alerts) - else: - retrieved_alerts = ( - session.query(Alert).join(Device).join(Access).filter(Access.group_id == requester.group_id).all() - ) - retrieved_alerts = [x.__dict__ for x in retrieved_alerts] - return retrieved_alerts + return await crud.fetch_all( + alerts, + query=None + if await is_admin_access(requester.id) + else session.query(Alert).join(Device).join(Access).filter(Access.group_id == requester.group_id), + limit=limit, + offset=offset, + ) @router.delete("/{alert_id}/", response_model=AlertOut, summary="Delete a specific alert") @@ -148,25 +151,15 @@ async def delete_alert(alert_id: int = Path(..., gt=0), _=Security(get_current_a @router.get("/ongoing", response_model=List[AlertOut], summary="Get the list of ongoing alerts") async def fetch_ongoing_alerts( - requester=Security(get_current_access, scopes=[AccessType.admin, AccessType.user]), session=Depends(get_db) + limit: Annotated[int, Query(description="maximum number of items", ge=1, le=1000)] = 50, + offset: Annotated[Optional[int], Query(description="number of items to skip", ge=0)] = None, + requester=Security(get_current_access, scopes=[AccessType.admin, AccessType.user]), + session=Depends(get_db), ): """ Retrieves the list of ongoing alerts and their information """ - if await is_admin_access(requester.id): - query = ( - alerts.select().where(alerts.c.event_id.in_(select([events.c.id]).where(events.c.end_ts.is_(None)))) - ).order_by(alerts.c.id.desc()) - - return (await crud.base.database.fetch_all(query=query.limit(50)))[::-1] - else: - retrieved_alerts = ( - session.query(Alert) - .join(Event) - .filter(Event.end_ts.is_(None)) - .join(Device) - .join(Access) - .filter(Access.group_id == requester.group_id) - ) - retrieved_alerts = [x.__dict__ for x in retrieved_alerts.all()] - return retrieved_alerts + query = session.query(Alert).join(Event).filter(Event.end_ts.is_(None)) + if not await is_admin_access(requester.id): + query = query.join(Device).join(Access).filter(Access.group_id == requester.group_id) + return await crud.fetch_all(alerts, query=query, limit=limit, offset=offset) diff --git a/src/app/api/endpoints/devices.py b/src/app/api/endpoints/devices.py index cda6ccf9..a6e8848b 100644 --- a/src/app/api/endpoints/devices.py +++ b/src/app/api/endpoints/devices.py @@ -4,9 +4,10 @@ # See LICENSE or go to for full license details. from datetime import datetime -from typing import List, cast +from typing import List, Optional, cast -from fastapi import APIRouter, Depends, HTTPException, Path, Security, status +from fastapi import APIRouter, Depends, HTTPException, Path, Query, Security, status +from typing_extensions import Annotated from app.api import crud from app.api.crud.authorizations import is_admin_access @@ -80,18 +81,22 @@ async def get_my_device(me: DeviceOut = Security(get_current_device, scopes=["de @router.get("/", response_model=List[DeviceOut], summary="Get the list of all devices") async def fetch_devices( - requester=Security(get_current_access, scopes=[AccessType.admin, AccessType.user]), session=Depends(get_db) + limit: Annotated[int, Query(description="maximum number of items", ge=1, le=1000)] = 50, + offset: Annotated[Optional[int], Query(description="number of items to skip", ge=0)] = None, + requester=Security(get_current_access, scopes=[AccessType.admin, AccessType.user]), + session=Depends(get_db), ): """ Retrieves the list of all devices and their information """ - if await is_admin_access(requester.id): - return await crud.fetch_all(devices) - else: - retrieved_devices = session.query(Device).join(Access).filter(Access.group_id == requester.group_id).all() - retrieved_devices = [x.__dict__ for x in retrieved_devices] - - return retrieved_devices + return await crud.fetch_all( + devices, + query=None + if await is_admin_access(requester.id) + else session.query(Device).join(Access).filter(Access.group_id == requester.group_id), + limit=limit, + offset=offset, + ) @router.put("/{device_id}/", response_model=DeviceOut, summary="Update information about a specific device") @@ -115,11 +120,15 @@ async def delete_device(device_id: int = Path(..., gt=0), _=Security(get_current @router.get( "/my-devices", response_model=List[DeviceOut], summary="Get the list of all devices belonging to the current user" ) -async def fetch_my_devices(me: UserRead = Security(get_current_user, scopes=[AccessType.admin, AccessType.user])): +async def fetch_my_devices( + limit: Annotated[int, Query(description="maximum number of items", ge=1, le=1000)] = 50, + offset: Annotated[Optional[int], Query(description="number of items to skip", ge=0)] = None, + me: UserRead = Security(get_current_user, scopes=[AccessType.admin, AccessType.user]), +): """ Retrieves the list of all devices and the information which are owned by the current user """ - return await crud.fetch_all(devices, {"owner_id": me.id}) + return await crud.fetch_all(devices, {"owner_id": me.id}, limit=limit, offset=offset) @router.put("/heartbeat", response_model=DeviceOut, summary="Update the last ping of the current device") diff --git a/src/app/api/endpoints/events.py b/src/app/api/endpoints/events.py index 3e9fdb2d..f1fc78e6 100644 --- a/src/app/api/endpoints/events.py +++ b/src/app/api/endpoints/events.py @@ -3,11 +3,11 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. -from typing import List, cast +from typing import List, Optional, cast -from fastapi import APIRouter, Depends, Path, Security, status +from fastapi import APIRouter, Depends, Path, Query, Security, status from pydantic import PositiveInt -from sqlalchemy import and_ +from typing_extensions import Annotated from app.api import crud from app.api.crud.authorizations import check_group_read, check_group_update, is_admin_access @@ -44,40 +44,43 @@ async def get_event( @router.get("/", response_model=List[EventOut], summary="Get the list of all events") async def fetch_events( - requester=Security(get_current_access, scopes=[AccessType.admin, AccessType.user]), session=Depends(get_db) + limit: Annotated[int, Query(description="maximum number of items", ge=1, le=1000)] = 50, + offset: Annotated[Optional[int], Query(description="number of items to skip", ge=0)] = None, + requester=Security(get_current_access, scopes=[AccessType.admin, AccessType.user]), + session=Depends(get_db), ): """ Retrieves the list of all events and their information """ - if await is_admin_access(requester.id): - return await crud.fetch_all(events) - else: - retrieved_events = ( - session.query(Event).join(Alert).join(Device).join(Access).filter(Access.group_id == requester.group_id) - ) - retrieved_events = [x.__dict__ for x in retrieved_events.all()] - return retrieved_events + return await crud.fetch_all( + events, + query=None + if await is_admin_access(requester.id) + else session.query(Event).join(Alert).join(Device).join(Access).filter(Access.group_id == requester.group_id), + limit=limit, + offset=offset, + ) @router.get("/past", response_model=List[EventOut], summary="Get the list of all past events") async def fetch_past_events( - requester=Security(get_current_access, scopes=[AccessType.admin, AccessType.user]), session=Depends(get_db) + limit: Annotated[int, Query(description="maximum number of items", ge=1, le=1000)] = 50, + offset: Annotated[Optional[int], Query(description="number of items to skip", ge=0)] = None, + requester=Security(get_current_access, scopes=[AccessType.admin, AccessType.user]), + session=Depends(get_db), ): """ - Retrieves the list of all events and their information + Retrieves the list of all events without end timestamp and their information """ - if await is_admin_access(requester.id): - return await crud.fetch_all(events, exclusions={"end_ts": None}) - else: - retrieved_events = ( - session.query(Event) - .join(Alert) - .join(Device) - .join(Access) - .filter(and_(Access.group_id == requester.group_id, Event.end_ts.isnot(None))) - ) - retrieved_events = [x.__dict__ for x in retrieved_events.all()] - return retrieved_events + return await crud.fetch_all( + events, + exclusions={"end_ts": None}, + query=None + if await is_admin_access(requester.id) + else session.query(Event).join(Alert).join(Device).join(Access).filter(Access.group_id == requester.group_id), + limit=limit, + offset=offset, + ) @router.put("/{event_id}/", response_model=EventOut, summary="Update information about a specific event") @@ -122,28 +125,30 @@ async def delete_event( "/unacknowledged", response_model=List[EventOut], summary="Get the list of events that haven't been acknowledged" ) async def fetch_unacknowledged_events( - requester=Security(get_current_access, scopes=[AccessType.admin, AccessType.user]), session=Depends(get_db) + limit: Annotated[int, Query(description="maximum number of items", ge=1, le=1000)] = 50, + offset: Annotated[Optional[int], Query(description="number of items to skip", ge=0)] = None, + requester=Security(get_current_access, scopes=[AccessType.admin, AccessType.user]), + session=Depends(get_db), ): """ - Retrieves the list of non confirmed alerts and their information + Retrieves the list of unacknowledged alerts and their information """ - if await is_admin_access(requester.id): - return await crud.fetch_all(events, {"is_acknowledged": False}) - else: - retrieved_events = ( - session.query(Event) - .join(Alert) - .join(Device) - .join(Access) - .filter(and_(Access.group_id == requester.group_id, Event.is_acknowledged.is_(False))) - ) - retrieved_events = [x.__dict__ for x in retrieved_events.all()] - return retrieved_events + return await crud.fetch_all( + events, + {"is_acknowledged": False}, + query=None + if await is_admin_access(requester.id) + else session.query(Event).join(Alert).join(Device).join(Access).filter(Access.group_id == requester.group_id), + limit=limit, + offset=offset, + ) @router.get("/{event_id}/alerts", response_model=List[AlertOut], summary="Get the list of alerts for event") async def fetch_alerts_for_event( event_id: PositiveInt, + limit: Annotated[int, Query(description="maximum number of items", ge=1, le=1000)] = 50, + offset: Annotated[Optional[int], Query(description="number of items to skip", ge=0)] = None, requester=Security(get_current_access, scopes=[AccessType.admin, AccessType.user]), session=Depends(get_db), ): @@ -152,4 +157,4 @@ async def fetch_alerts_for_event( """ requested_group_id = await get_entity_group_id(events, event_id) await check_group_read(requester.id, cast(int, requested_group_id)) - return await crud.base.database.fetch_all(query=alerts.select().where(alerts.c.event_id == event_id)) + return await crud.fetch_all(alerts, {"event_id": event_id}, limit=limit, offset=offset) diff --git a/src/app/api/endpoints/groups.py b/src/app/api/endpoints/groups.py index 54dcf394..7814f250 100644 --- a/src/app/api/endpoints/groups.py +++ b/src/app/api/endpoints/groups.py @@ -3,9 +3,10 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. -from typing import List +from typing import List, Optional -from fastapi import APIRouter, Path, Security, status +from fastapi import APIRouter, Path, Query, Security, status +from typing_extensions import Annotated from app.api import crud from app.api.deps import get_current_access @@ -35,11 +36,14 @@ async def get_group(group_id: int = Path(..., gt=0)): @router.get("/", response_model=List[GroupOut], summary="Get the list of all groups") -async def fetch_groups(): +async def fetch_groups( + limit: Annotated[int, Query(description="maximum number of items", ge=1, le=1000)] = 50, + offset: Annotated[Optional[int], Query(description="number of items to skip", ge=0)] = None, +): """ Retrieves the list of all groups and their information """ - return await crud.fetch_all(groups) + return await crud.fetch_all(groups, limit=limit, offset=offset) @router.put("/{group_id}/", response_model=GroupOut, summary="Update information about a specific group") diff --git a/src/app/api/endpoints/installations.py b/src/app/api/endpoints/installations.py index 14d143d1..26d5ff28 100644 --- a/src/app/api/endpoints/installations.py +++ b/src/app/api/endpoints/installations.py @@ -4,16 +4,17 @@ # See LICENSE or go to for full license details. from datetime import datetime -from typing import List, cast +from typing import List, Optional, cast -from fastapi import APIRouter, Depends, Path, Security, status +from fastapi import APIRouter, Depends, Path, Query, Security, status from sqlalchemy import and_, or_ +from typing_extensions import Annotated from app.api import crud from app.api.crud.authorizations import check_group_read, check_group_update, is_admin_access from app.api.crud.groups import get_entity_group_id from app.api.deps import get_current_access, get_db -from app.db import installations +from app.db import installations, sites from app.models import AccessType, Installation, Site from app.schemas import InstallationIn, InstallationOut, InstallationUpdate @@ -49,19 +50,22 @@ async def get_installation( @router.get("/", response_model=List[InstallationOut], summary="Get the list of all installations") async def fetch_installations( - requester=Security(get_current_access, scopes=[AccessType.admin, AccessType.user]), session=Depends(get_db) + limit: Annotated[int, Query(description="maximum number of items", ge=1, le=1000)] = 50, + offset: Annotated[Optional[int], Query(description="number of items to skip", ge=0)] = None, + requester=Security(get_current_access, scopes=[AccessType.admin, AccessType.user]), + session=Depends(get_db), ): """ Retrieves the list of all installations and their information """ - if await is_admin_access(requester.id): - return await crud.fetch_all(installations) - else: - retrieved_installations = ( - session.query(Installation).join(Site).filter(Site.group_id == requester.group_id).all() - ) - retrieved_installations = [x.__dict__ for x in retrieved_installations] - return retrieved_installations + return await crud.fetch_all( + installations, + query=None + if await is_admin_access(requester.id) + else session.query(Installation).join(Site).filter(Site.group_id == requester.group_id), + limit=limit, + offset=offset, + ) @router.put( @@ -93,29 +97,33 @@ async def delete_installation( @router.get("/site-devices/{site_id}", response_model=List[int], summary="Get all devices related to a specific site") async def get_active_devices_on_site( site_id: int = Path(..., gt=0), + limit: Annotated[int, Query(description="maximum number of items", ge=1, le=1000)] = 50, + offset: Annotated[Optional[int], Query(description="number of items to skip", ge=0)] = None, requester=Security(get_current_access, scopes=[AccessType.admin, AccessType.user]), session=Depends(get_db), ): """ Based on a site_id, retrieves the list of all the related devices and their information """ + requested_group_id = await get_entity_group_id(sites, site_id) + await check_group_read(requester.id, cast(int, requested_group_id)) current_ts = datetime.utcnow() - - query = ( - session.query(Installation) - .join(Site) - .filter( - and_( - Site.id == site_id, - Installation.start_ts <= current_ts, - or_(Installation.end_ts.is_(None), Installation.end_ts >= current_ts), - ) + return [ + item["device_id"] + for item in await crud.fetch_all( + installations, + query=( + session.query(Installation) + .join(Site) + .filter( + and_( + Site.id == site_id, + Installation.start_ts <= current_ts, + or_(Installation.end_ts.is_(None), Installation.end_ts >= current_ts), + ) + ) + ), + limit=limit, + offset=offset, ) - ) - - if not await is_admin_access(requester.id): - # Restrict on the group_id of the requester - query = query.filter(Site.group_id == requester.group_id) - - retrieved_device_ids = [x.__dict__["device_id"] for x in query.all()] - return retrieved_device_ids + ] diff --git a/src/app/api/endpoints/media.py b/src/app/api/endpoints/media.py index 8d7d5f3d..ef776c0d 100644 --- a/src/app/api/endpoints/media.py +++ b/src/app/api/endpoints/media.py @@ -8,7 +8,8 @@ from typing import Any, Dict, List, Optional, cast import magic -from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, Path, Security, UploadFile, status +from fastapi import APIRouter, BackgroundTasks, Depends, File, HTTPException, Path, Query, Security, UploadFile, status +from typing_extensions import Annotated from app.api import crud from app.api.crud.authorizations import check_group_read, is_admin_access @@ -88,19 +89,22 @@ async def get_media( @router.get("/", response_model=List[MediaOut], summary="Get the list of all media") async def fetch_media( - requester=Security(get_current_access, scopes=[AccessType.admin, AccessType.user]), session=Depends(get_db) + limit: Annotated[int, Query(description="maximum number of items", ge=1, le=1000)] = 50, + offset: Annotated[Optional[int], Query(description="number of items to skip", ge=0)] = None, + requester=Security(get_current_access, scopes=[AccessType.admin, AccessType.user]), + session=Depends(get_db), ): """ Retrieves the list of all media and their information """ - if await is_admin_access(requester.id): - return await crud.fetch_all(media) - else: - retrieved_media = ( - session.query(Media).join(Device).join(Access).filter(Access.group_id == requester.group_id).all() - ) - retrieved_media = [x.__dict__ for x in retrieved_media] - return retrieved_media + return await crud.fetch_all( + media, + query=media.select() + if await is_admin_access(requester.id) + else session.query(Media).join(Device).join(Access).where(Access.group_id == requester.group_id), + limit=limit, + offset=offset, + ) @router.delete("/{media_id}/", response_model=MediaOut, summary="Delete a specific media") diff --git a/src/app/api/endpoints/notifications.py b/src/app/api/endpoints/notifications.py index 8751ef5d..fb866a47 100644 --- a/src/app/api/endpoints/notifications.py +++ b/src/app/api/endpoints/notifications.py @@ -3,9 +3,10 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. -from typing import List +from typing import List, Optional -from fastapi import APIRouter, HTTPException, Path, Security, status +from fastapi import APIRouter, HTTPException, Path, Query, Security, status +from typing_extensions import Annotated from app.api import crud from app.api.deps import get_current_access @@ -49,11 +50,14 @@ async def get_notification(notification_id: int = Path(..., gt=0)): @router.get("/", response_model=List[NotificationOut], summary="Get the list of all notifications") -async def fetch_notifications(): +async def fetch_notifications( + limit: Annotated[int, Query(description="maximum number of items", ge=1, le=1000)] = 50, + offset: Annotated[Optional[int], Query(description="number of items to skip", ge=0)] = None, +): """ Retrieves the list of all notifications and their information """ - return await crud.fetch_all(notifications) + return await crud.fetch_all(notifications, limit=limit, offset=offset) @router.delete("/{notification_id}/", response_model=NotificationOut, summary="Delete a specific notification") diff --git a/src/app/api/endpoints/recipients.py b/src/app/api/endpoints/recipients.py index 216bc51e..ce603919 100644 --- a/src/app/api/endpoints/recipients.py +++ b/src/app/api/endpoints/recipients.py @@ -3,10 +3,11 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. -from typing import List +from typing import List, Optional -from fastapi import APIRouter, Security, status +from fastapi import APIRouter, Query, Security, status from pydantic import PositiveInt +from typing_extensions import Annotated from app.api import crud from app.api.deps import get_current_access @@ -39,11 +40,14 @@ async def get_recipient(recipient_id: PositiveInt): @router.get("/", response_model=List[RecipientOut], summary="Get the list of all recipients") -async def fetch_recipients(): +async def fetch_recipients( + limit: Annotated[int, Query(description="maximum number of items", ge=1, le=1000)] = 50, + offset: Annotated[Optional[int], Query(description="number of items to skip", ge=0)] = None, +): """ Retrieves the list of all recipients and their information """ - return await crud.fetch_all(recipients) + return await crud.fetch_all(recipients, limit=limit, offset=offset) @router.get( @@ -51,11 +55,15 @@ async def fetch_recipients(): response_model=List[RecipientOut], summary="Get the list of all recipients for the given group", ) -async def fetch_recipients_for_group(group_id: PositiveInt): +async def fetch_recipients_for_group( + group_id: PositiveInt, + limit: Annotated[int, Query(description="maximum number of items", ge=1, le=1000)] = 50, + offset: Annotated[Optional[int], Query(description="number of items to skip", ge=0)] = None, +): """ Retrieves the list of all recipients for the given group and their information """ - return await crud.fetch_all(recipients, {"group_id": group_id}) + return await crud.fetch_all(recipients, {"group_id": group_id}, limit=limit, offset=offset) @router.put("/{recipient_id}/", response_model=RecipientOut, summary="Update information about a specific recipient") diff --git a/src/app/api/endpoints/sites.py b/src/app/api/endpoints/sites.py index c8c5dd43..b1f3ca93 100644 --- a/src/app/api/endpoints/sites.py +++ b/src/app/api/endpoints/sites.py @@ -3,9 +3,10 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. -from typing import List, cast +from typing import List, Optional, cast -from fastapi import APIRouter, Depends, Path, Security, status +from fastapi import APIRouter, Depends, Path, Query, Security, status +from typing_extensions import Annotated from app.api import crud from app.api.crud.authorizations import check_group_read, check_group_update, is_admin_access @@ -62,15 +63,20 @@ async def get_site( @router.get("/", response_model=List[SiteOut], summary="Get the list of all sites in your group") async def fetch_sites( - requester=Security(get_current_access, scopes=[AccessType.admin, AccessType.user]), session=Depends(get_db) + limit: Annotated[int, Query(description="maximum number of items", ge=1, le=1000)] = 50, + offset: Annotated[Optional[int], Query(description="number of items to skip", ge=0)] = None, + requester=Security(get_current_access, scopes=[AccessType.admin, AccessType.user]), + session=Depends(get_db), ): """ Retrieves the list of all sites and their information """ - if await is_admin_access(requester.id): - return await crud.fetch_all(sites) - else: - return await crud.fetch_all(sites, {"group_id": requester.group_id}) + return await crud.fetch_all( + sites, + query_filters=None if await is_admin_access(requester.id) else {"group_id": requester.group_id}, + limit=limit, + offset=offset, + ) @router.put("/{site_id}/", response_model=SiteOut, summary="Update information about a specific site") diff --git a/src/app/api/endpoints/users.py b/src/app/api/endpoints/users.py index c54a97d7..dfff5b2e 100644 --- a/src/app/api/endpoints/users.py +++ b/src/app/api/endpoints/users.py @@ -3,9 +3,10 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. -from typing import List +from typing import List, Optional -from fastapi import APIRouter, Depends, Path, Security, status +from fastapi import APIRouter, Depends, Path, Query, Security, status +from typing_extensions import Annotated from app.api import crud from app.api.crud.authorizations import is_admin_access @@ -68,17 +69,22 @@ async def get_user(user_id: int = Path(..., gt=0), _=Security(get_current_user, @router.get("/", response_model=List[UserRead], summary="Get the list of all users") async def fetch_users( - requester=Security(get_current_access, scopes=[AccessType.admin, AccessType.user]), session=Depends(get_db) + limit: Annotated[int, Query(description="maximum number of items", ge=1, le=1000)] = 50, + offset: Annotated[Optional[int], Query(description="number of items to skip", ge=0)] = None, + requester=Security(get_current_access, scopes=[AccessType.admin, AccessType.user]), + session=Depends(get_db), ): """ Retrieves the list of all users and their information """ - if await is_admin_access(requester.id): - return await crud.fetch_all(users) - else: - retrieved_users = session.query(User).join(Access).filter(Access.group_id == requester.group_id).all() - retrieved_users = [x.__dict__ for x in retrieved_users] - return retrieved_users + return await crud.fetch_all( + users, + query=None + if await is_admin_access(requester.id) + else session.query(User).join(Access).filter(Access.group_id == requester.group_id), + limit=limit, + offset=offset, + ) @router.put("/{user_id}/", response_model=UserRead, summary="Update information about a specific user") diff --git a/src/app/api/endpoints/webhooks.py b/src/app/api/endpoints/webhooks.py index 1de46233..61515cd2 100644 --- a/src/app/api/endpoints/webhooks.py +++ b/src/app/api/endpoints/webhooks.py @@ -3,9 +3,10 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. -from typing import List +from typing import List, Optional -from fastapi import APIRouter, Path, Security, status +from fastapi import APIRouter, Path, Query, Security, status +from typing_extensions import Annotated from app.api import crud from app.api.deps import get_current_access @@ -35,11 +36,15 @@ async def get_webhook(webhook_id: int = Path(..., gt=0), _=Security(get_current_ @router.get("/", response_model=List[WebhookOut], summary="Get the list of all webhooks") -async def fetch_webhooks(_=Security(get_current_access, scopes=["admin"])): +async def fetch_webhooks( + limit: Annotated[int, Query(description="maximum number of items", ge=1, le=1000)] = 50, + offset: Annotated[Optional[int], Query(description="number of items to skip", ge=0)] = None, + _=Security(get_current_access, scopes=["admin"]), +): """ Retrieves the list of all webhooks and their information """ - return await crud.fetch_all(webhooks) + return await crud.fetch_all(webhooks, limit=limit, offset=offset) @router.put("/{webhook_id}/", response_model=WebhookOut, summary="Update information about a specific webhook") diff --git a/src/tests/routes/test_groups.py b/src/tests/routes/test_groups.py index 89ec1447..67f3b7fa 100644 --- a/src/tests/routes/test_groups.py +++ b/src/tests/routes/test_groups.py @@ -97,12 +97,30 @@ async def test_get_group(test_app_asyncio, init_test_db, group_id, status_code, assert response_json == GROUP_TABLE[group_id - 1] +@pytest.mark.parametrize( + "limit, offset, expected", + [ + (None, None, GROUP_TABLE[-50:]), + (50, None, GROUP_TABLE[-50:]), + (None, 0, GROUP_TABLE[-50:]), + (50, 0, GROUP_TABLE[-50:]), + (10, 0, GROUP_TABLE[-10:]), + (10, 10, GROUP_TABLE[-20:-10]), + ], +) @pytest.mark.asyncio -async def test_fetch_groups(test_app_asyncio, init_test_db): - response = await test_app_asyncio.get("/groups/") +async def test_fetch_groups(test_app_asyncio, init_test_db, limit, offset, expected): + query = [] + if limit is not None: + query.append(f"limit={limit}") + if offset is not None: + query.append(f"offset={offset}") + if not query: + response = await test_app_asyncio.get("/groups/") + else: + response = await test_app_asyncio.get("/groups/?" + "&".join(query)) assert response.status_code == 200 - response_json = response.json() - assert all(result == entry for result, entry in zip(response_json, GROUP_TABLE[-50:])) + assert response.json() == expected @pytest.mark.parametrize( diff --git a/src/tests/routes/test_installations.py b/src/tests/routes/test_installations.py index aed6b4d4..a024a8d4 100644 --- a/src/tests/routes/test_installations.py +++ b/src/tests/routes/test_installations.py @@ -425,9 +425,10 @@ async def test_delete_installation( [0, 1, [1], 200, None], [1, 1, [1], 200, None], [4, 2, [2], 200, None], - [1, 999, [], 200, None], # TODO: this should fail since the site doesn't exist + [1, 999, [], 404, "Table sites has no entry with id=999"], [1, 0, [], 422, None], [2, 1, [], 403, "Your access scope is not compatible with this operation."], + [4, 1, [], 403, "This access can't read resources from group_id=1"] ], ) @pytest.mark.asyncio From a94ceb92c490b3e983fd5bb4b50770b55b35ffdb Mon Sep 17 00:00:00 2001 From: Bruno Lenzi Date: Mon, 7 Aug 2023 20:24:08 +0200 Subject: [PATCH 2/2] fix: black --- src/tests/routes/test_installations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tests/routes/test_installations.py b/src/tests/routes/test_installations.py index a024a8d4..21ded5ad 100644 --- a/src/tests/routes/test_installations.py +++ b/src/tests/routes/test_installations.py @@ -428,7 +428,7 @@ async def test_delete_installation( [1, 999, [], 404, "Table sites has no entry with id=999"], [1, 0, [], 422, None], [2, 1, [], 403, "Your access scope is not compatible with this operation."], - [4, 1, [], 403, "This access can't read resources from group_id=1"] + [4, 1, [], 403, "This access can't read resources from group_id=1"], ], ) @pytest.mark.asyncio