diff --git a/CHANGELOG.md b/CHANGELOG.md index 647a8713..96c7d4bd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ These are the section headers that we use: - Added `POST /api/v1/token` endpoint to generate a new API token for a user. ([#138](https://github.com/argilla-io/argilla-server/pull/138)) - Added `GET /api/v1/me` endpoint to get the current user information. ([#140](https://github.com/argilla-io/argilla-server/pull/140)) +- Added `GET /api/v1/users` endpoint to get a list of all users. ([#142](https://github.com/argilla-io/argilla-server/pull/142)) ## [Unreleased]() diff --git a/src/argilla_server/apis/v1/handlers/users.py b/src/argilla_server/apis/v1/handlers/users.py index 5860631f..94c79808 100644 --- a/src/argilla_server/apis/v1/handlers/users.py +++ b/src/argilla_server/apis/v1/handlers/users.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import List from uuid import UUID from fastapi import APIRouter, Depends, HTTPException, Request, Security, status @@ -20,9 +21,8 @@ from argilla_server import models, telemetry from argilla_server.contexts import accounts from argilla_server.database import get_async_db -from argilla_server.models import User from argilla_server.policies import UserPolicyV1, authorize -from argilla_server.schemas.v1.users import User +from argilla_server.schemas.v1.users import User, Users from argilla_server.schemas.v1.workspaces import Workspaces from argilla_server.security import auth @@ -36,6 +36,19 @@ async def get_current_user(request: Request, current_user: models.User = Securit return current_user +@router.get("/users", response_model=Users) +async def list_users( + *, + db: AsyncSession = Depends(get_async_db), + current_user: models.User = Security(auth.get_current_user), +): + await authorize(current_user, UserPolicyV1.list) + + users = await accounts.list_users(db) + + return Users(items=users) + + @router.get("/users/{user_id}/workspaces", response_model=Workspaces) async def list_user_workspaces( *, diff --git a/src/argilla_server/contexts/accounts.py b/src/argilla_server/contexts/accounts.py index b128f07b..7e0ef28b 100644 --- a/src/argilla_server/contexts/accounts.py +++ b/src/argilla_server/contexts/accounts.py @@ -110,6 +110,8 @@ async def get_user_by_api_key(db: AsyncSession, api_key: str) -> Union[User, Non async def list_users(db: "AsyncSession") -> Sequence[User]: + # TODO: After removing API v0 implementation we can remove the workspaces eager loading + # because is not used in the new API v1 endpoints. result = await db.execute(select(User).order_by(User.inserted_at.asc()).options(selectinload(User.workspaces))) return result.scalars().all() diff --git a/src/argilla_server/policies.py b/src/argilla_server/policies.py index 196ef159..97dbf183 100644 --- a/src/argilla_server/policies.py +++ b/src/argilla_server/policies.py @@ -129,6 +129,10 @@ async def is_allowed(actor: User) -> bool: class UserPolicyV1: + @classmethod + async def list(cls, actor: User) -> bool: + return actor.is_owner + @classmethod async def list_workspaces(cls, actor: User) -> bool: return actor.is_owner diff --git a/src/argilla_server/schemas/v1/users.py b/src/argilla_server/schemas/v1/users.py index f728d7c1..4ad4b133 100644 --- a/src/argilla_server/schemas/v1/users.py +++ b/src/argilla_server/schemas/v1/users.py @@ -13,7 +13,7 @@ # limitations under the License. from datetime import datetime -from typing import Optional +from typing import List, Optional from uuid import UUID from argilla_server.enums import UserRole @@ -32,3 +32,7 @@ class User(BaseModel): class Config: orm_mode = True + + +class Users(BaseModel): + items: List[User] diff --git a/tests/unit/api/v1/users/test_list_users.py b/tests/unit/api/v1/users/test_list_users.py new file mode 100644 index 00000000..358d710d --- /dev/null +++ b/tests/unit/api/v1/users/test_list_users.py @@ -0,0 +1,81 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from argilla_server.constants import API_KEY_HEADER_NAME +from argilla_server.enums import UserRole +from argilla_server.models import User +from httpx import AsyncClient + +from tests.factories import UserFactory + + +@pytest.mark.asyncio +class TestListUsers: + def url(self) -> str: + return "/api/v1/users" + + async def test_list_users(self, async_client: AsyncClient, owner: User, owner_auth_header: dict): + user_a, user_b = await UserFactory.create_batch(2) + + response = await async_client.get(self.url(), headers=owner_auth_header) + + assert response.status_code == 200 + assert response.json() == { + "items": [ + { + "id": str(owner.id), + "first_name": owner.first_name, + "last_name": owner.last_name, + "username": owner.username, + "role": owner.role, + "api_key": owner.api_key, + "inserted_at": owner.inserted_at.isoformat(), + "updated_at": owner.updated_at.isoformat(), + }, + { + "id": str(user_a.id), + "first_name": user_a.first_name, + "last_name": user_a.last_name, + "username": user_a.username, + "role": user_a.role, + "api_key": user_a.api_key, + "inserted_at": user_a.inserted_at.isoformat(), + "updated_at": user_a.updated_at.isoformat(), + }, + { + "id": str(user_b.id), + "first_name": user_b.first_name, + "last_name": user_b.last_name, + "username": user_b.username, + "role": user_b.role, + "api_key": user_b.api_key, + "inserted_at": user_b.inserted_at.isoformat(), + "updated_at": user_b.updated_at.isoformat(), + }, + ] + } + + async def test_list_users_without_authentication(self, async_client: AsyncClient): + response = await async_client.get(self.url()) + + assert response.status_code == 401 + + @pytest.mark.parametrize("user_role", [UserRole.admin, UserRole.annotator]) + async def test_list_users_with_unauthorized_role(self, async_client: AsyncClient, user_role: UserRole): + user = await UserFactory.create(role=user_role) + + response = await async_client.get(self.url(), headers={API_KEY_HEADER_NAME: user.api_key}) + + assert response.status_code == 403