Skip to content

Commit

Permalink
Add POST /client/all endpoint to list all user clients
Browse files Browse the repository at this point in the history
  • Loading branch information
kuyugama committed Oct 26, 2024
1 parent 8614260 commit 9e1a917
Show file tree
Hide file tree
Showing 8 changed files with 126 additions and 8 deletions.
37 changes: 31 additions & 6 deletions app/client/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,23 @@
from fastapi import APIRouter, Depends

from app.dependencies import auth_required, get_page, get_size
from app.utils import pagination, pagination_dict
from app.utils import pagination, paginated_response
from app.schemas import ClientResponse
from app.database import get_session
from app.models import Client, User
from app.client import service
from app import constants

from app.client.dependencies import (
validate_unverified_client,
validate_client_create,
validate_user_client,
validate_client,
validate_unverified_client,
)
from app.client.schemas import (
ClientPaginationResponse,
ClientFullResponse,
ListAllClientsArgs,
ClientCreate,
ClientUpdate,
)
Expand All @@ -41,10 +42,34 @@ async def list_user_clients(
total = await service.count_user_clients(session, user, offset, limit)
clients = await service.list_user_clients(session, user, offset, limit)

return {
"pagination": pagination_dict(total, page, limit),
"list": clients.all(),
}
return paginated_response(clients.all(), total, page, limit)


@router.post(
"/all",
summary="List all clients",
response_model=ClientPaginationResponse,
dependencies=[
Depends(
auth_required(
scope=[constants.SCOPE_READ_CLIENT_LIST],
permissions=[constants.PERMISSION_CLIENT_LIST_ALL],
)
)
],
)
async def list_all_clients(
args: ListAllClientsArgs,
page: int = Depends(get_page),
size: int = Depends(get_size),
session: AsyncSession = Depends(get_session),
):
limit, offset = pagination(page, size)

total = await service.count_all_clients(session, args)
clients = await service.list_all_clients(session, args, offset, limit)

return paginated_response(clients.all(), total, page, limit)


@router.get(
Expand Down
4 changes: 4 additions & 0 deletions app/client/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,7 @@ def validate_endpoint(cls, v: HttpUrl | None) -> HttpUrl | None:
)

return v


class ListAllClientsArgs(CustomModel):
query: str | None = Field(None, description="Search by name")
36 changes: 34 additions & 2 deletions app/client/service.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import secrets
import uuid

from sqlalchemy import select, func, ScalarResult, Select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from sqlalchemy import select, func, ScalarResult

from app.client.schemas import ClientCreate, ClientUpdate
from app.client.schemas import ClientCreate, ClientUpdate, ListAllClientsArgs
from app.models import User, Client
from app.utils import utcnow

Expand Down Expand Up @@ -122,3 +122,35 @@ async def verify_client(session: AsyncSession, client: Client) -> Client:
client.verified = True
await session.commit()
return client


def apply_all_clients_filters(
query: Select, args: ListAllClientsArgs
) -> Select:
if args.query is not None:
query = query.filter(Client.name.ilike(f"%{args.query}%"))

return query


async def count_all_clients(
session: AsyncSession, args: ListAllClientsArgs
) -> int:
return await session.scalar(
apply_all_clients_filters(select(func.count(Client.id)), args)
)


async def list_all_clients(
session: AsyncSession, args: ListAllClientsArgs, offset: int, limit: int
) -> ScalarResult[Client]:
return await session.scalars(
apply_all_clients_filters(
select(Client)
.options(joinedload(Client.user))
.order_by(Client.created.desc())
.offset(offset)
.limit(limit),
args,
)
)
2 changes: 2 additions & 0 deletions app/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,7 @@
PERMISSION_CLIENT_DELETE = "client:delete"
PERMISSION_CLIENT_VERIFY = "client:verify"
PERMISSION_CLIENT_DELETE_ADMIN = "client:delete_admin"
PERMISSION_CLIENT_LIST_ALL = "client:list_all"

USER_PERMISSIONS = [
PERMISSION_EDIT_CREATE,
Expand Down Expand Up @@ -378,6 +379,7 @@
PERMISSION_COLLECTION_DELETE_MODERATOR,
PERMISSION_EDIT_UPDATE_MODERATOR,
PERMISSION_CLIENT_VERIFY,
PERMISSION_CLIENT_LIST_ALL,
]

ADMIN_PERMISSIONS = [
Expand Down
19 changes: 19 additions & 0 deletions app/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from starlette.middleware.base import BaseHTTPMiddleware
from dateutil.relativedelta import relativedelta
from fastapi.responses import JSONResponse
from sqlalchemy.orm import DeclarativeBase
from datetime import timezone, timedelta
from fastapi import FastAPI, Request
from collections.abc import Sequence
from datetime import datetime, UTC
from app.models import AuthToken
from functools import lru_cache
Expand All @@ -16,10 +18,15 @@
import asyncio
import secrets
import bcrypt
import typing
import math
import re


if typing.TYPE_CHECKING:
from app.schemas import CustomModel


# Timeout middleware (class name is pretty self explanatory)
class TimeoutMiddleware(BaseHTTPMiddleware):
def __init__(self, app: FastAPI, timeout: int) -> None:
Expand Down Expand Up @@ -266,6 +273,18 @@ def pagination_dict(total, page, limit):
}


def paginated_response(
items: Sequence[typing.Union[DeclarativeBase, "CustomModel"]],
total: int,
page: int,
limit: int,
) -> dict[str, dict[str, int] | list]:
return {
"list": items,
"pagination": pagination_dict(total, page, limit),
}


# Convert month to season str
def get_season(date):
# Anime seasons start from first month of the year
Expand Down
24 changes: 24 additions & 0 deletions tests/client/test_list_all.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from tests.client_requests import request_list_all_clients


async def test_normal(client, moderator_token, test_thirdparty_client):
response = await request_list_all_clients(client, moderator_token)
print(response.json())
assert response.status_code == 200

assert response.json()["pagination"] == {
"page": 1,
"pages": 1,
"total": 1,
}

response_client = response.json()["list"][0]
assert response_client["description"] == test_thirdparty_client.description
assert response_client["reference"] == test_thirdparty_client.reference
assert response_client["name"] == test_thirdparty_client.name


async def test_no_permission(client, test_token):
response = await request_list_all_clients(client, test_token)
print(response.json())
assert response.status_code == 403
2 changes: 2 additions & 0 deletions tests/client_requests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .auth import request_login

from .client import request_client_full_info
from .client import request_list_all_clients
from .client import request_client_create
from .client import request_client_update
from .client import request_client_verify
Expand Down Expand Up @@ -131,6 +132,7 @@
"request_signup",
"request_login",
"request_client_full_info",
"request_list_all_clients",
"request_client_create",
"request_client_update",
"request_client_verify",
Expand Down
10 changes: 10 additions & 0 deletions tests/client_requests/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,16 @@ def request_list_clients(client, token: str):
)


def request_list_all_clients(client, token: str, query: str | None = None):
return client.post(
"/client/all",
headers={"Auth": token},
json={
"query": query
}
)


def request_client_verify(client, token: str, client_reference: str):
return client.post(
f"/client/{client_reference}/verify",
Expand Down

0 comments on commit 9e1a917

Please sign in to comment.