diff --git a/CHANGELOG.md b/CHANGELOG.md index a3f8233b..c210e2cf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ These are the section headers that we use: - Added `POST /api/v1/token` endpoint to generate a new API token for a user. ([#138](https://github.com/argilla-io/argilla-server/pull/138)) - Added `GET /api/v1/me` endpoint to get the current user information. ([#140](https://github.com/argilla-io/argilla-server/pull/140)) - Added `GET /api/v1/users` endpoint to get a list of all users. ([#142](https://github.com/argilla-io/argilla-server/pull/142)) +- Added `GET /api/v1/users/:user_id` endpoint to get a specific user. ([#166](https://github.com/argilla-io/argilla-server/pull/166)) - Added `POST /api/v1/users` endpoint to create a new user. ([#146](https://github.com/argilla-io/argilla-server/pull/146)) - Added `DELETE /api/v1/users` endpoint to delete a user. ([#148](https://github.com/argilla-io/argilla-server/pull/148)) - Added `POST /api/v1/workspaces` endpoint to create a new workspace. ([#150](https://github.com/argilla-io/argilla-server/pull/150)) diff --git a/src/argilla_server/apis/v1/handlers/users.py b/src/argilla_server/apis/v1/handlers/users.py index 6cfce784..23c53b52 100644 --- a/src/argilla_server/apis/v1/handlers/users.py +++ b/src/argilla_server/apis/v1/handlers/users.py @@ -37,6 +37,25 @@ async def get_current_user(request: Request, current_user: models.User = Securit return current_user +@router.get("/users/{user_id}", response_model=User) +async def get_user( + *, + db: AsyncSession = Depends(get_async_db), + user_id: UUID, + current_user: models.User = Security(auth.get_current_user), +): + await authorize(current_user, UserPolicyV1.get) + + user = await accounts.get_user_by_id(db, user_id) + if user is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"User with id `{user_id}` not found", + ) + + return user + + @router.get("/users", response_model=Users) async def list_users( *, diff --git a/src/argilla_server/policies.py b/src/argilla_server/policies.py index aeba9dbc..6164ee72 100644 --- a/src/argilla_server/policies.py +++ b/src/argilla_server/policies.py @@ -158,6 +158,10 @@ async def is_allowed(actor: User) -> bool: class UserPolicyV1: + @classmethod + async def get(cls, actor: User) -> bool: + return actor.is_owner + @classmethod async def list(cls, actor: User) -> bool: return actor.is_owner diff --git a/src/argilla_server/schemas/v1/users.py b/src/argilla_server/schemas/v1/users.py index b04b7be5..1b93d7f6 100644 --- a/src/argilla_server/schemas/v1/users.py +++ b/src/argilla_server/schemas/v1/users.py @@ -30,6 +30,8 @@ class User(BaseModel): last_name: Optional[str] username: str role: UserRole + # TODO: We need to move `api_key` outside of this schema and think about a more + # secure way to expose it, along with ways to expire it and create new API keys. api_key: str inserted_at: datetime updated_at: datetime diff --git a/tests/unit/api/v1/users/test_get_user.py b/tests/unit/api/v1/users/test_get_user.py new file mode 100644 index 00000000..da509891 --- /dev/null +++ b/tests/unit/api/v1/users/test_get_user.py @@ -0,0 +1,71 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from uuid import UUID, uuid4 + +import pytest +from argilla_server.constants import API_KEY_HEADER_NAME +from argilla_server.enums import UserRole +from httpx import AsyncClient + +from tests.factories import UserFactory + + +@pytest.mark.asyncio +class TestGetUser: + def url(self, user_id: UUID) -> str: + return f"/api/v1/users/{user_id}" + + async def test_get_user(self, async_client: AsyncClient, owner_auth_header: dict): + user = await UserFactory.create() + + response = await async_client.get(self.url(user.id), headers=owner_auth_header) + + assert response.status_code == 200 + assert response.json() == { + "id": str(user.id), + "first_name": user.first_name, + "last_name": user.last_name, + "username": user.username, + "role": UserRole.annotator, + "api_key": user.api_key, + "inserted_at": user.inserted_at.isoformat(), + "updated_at": user.updated_at.isoformat(), + } + + async def test_get_user_without_authentication(self, async_client: AsyncClient): + user = await UserFactory.create() + + response = await async_client.get(self.url(user.id)) + + assert response.status_code == 401 + + @pytest.mark.parametrize("user_role", [UserRole.admin, UserRole.annotator]) + async def test_get_user_with_unauthorized_role(self, async_client: AsyncClient, user_role: UserRole): + user = await UserFactory.create(role=user_role) + + response = await async_client.get( + self.url(user.id), + headers={API_KEY_HEADER_NAME: user.api_key}, + ) + + assert response.status_code == 403 + + async def test_get_user_with_nonexistent_user_id(self, async_client: AsyncClient, owner_auth_header: dict): + user_id = uuid4() + + response = await async_client.get(self.url(user_id), headers=owner_auth_header) + + assert response.status_code == 404 + assert response.json() == {"detail": f"User with id `{user_id}` not found"}