From 42790e18c2ddd2a4769f8751e554eca2db126af5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Francisco=20Calvo?= Date: Mon, 6 May 2024 16:32:54 +0200 Subject: [PATCH 1/6] feat: migrate endpoint /me to /api/v1/me --- src/argilla_server/apis/v1/handlers/users.py | 8 ++++ src/argilla_server/schemas/v1/users.py | 34 +++++++++++++++ tests/unit/api/v1/users/__init__.py | 13 ++++++ .../api/v1/users/test_get_current_user.py | 43 +++++++++++++++++++ 4 files changed, 98 insertions(+) create mode 100644 src/argilla_server/schemas/v1/users.py create mode 100644 tests/unit/api/v1/users/__init__.py create mode 100644 tests/unit/api/v1/users/test_get_current_user.py diff --git a/src/argilla_server/apis/v1/handlers/users.py b/src/argilla_server/apis/v1/handlers/users.py index f0578f65..daf0c1bb 100644 --- a/src/argilla_server/apis/v1/handlers/users.py +++ b/src/argilla_server/apis/v1/handlers/users.py @@ -21,12 +21,20 @@ from argilla_server.database import get_async_db from argilla_server.models import User from argilla_server.policies import UserPolicyV1, authorize +from argilla_server.schemas.v1.users import User as UserSchema from argilla_server.schemas.v1.workspaces import Workspaces from argilla_server.security import auth router = APIRouter(tags=["users"]) +@router.get("/me", response_model=UserSchema) +async def get_current_user(current_user: User = Security(auth.get_current_user)): + # TODO: Should we add telemetry.track_login? + + return current_user + + @router.get("/users/{user_id}/workspaces", response_model=Workspaces) async def list_user_workspaces( *, db: AsyncSession = Depends(get_async_db), user_id: UUID, current_user: User = Security(auth.get_current_user) diff --git a/src/argilla_server/schemas/v1/users.py b/src/argilla_server/schemas/v1/users.py new file mode 100644 index 00000000..f728d7c1 --- /dev/null +++ b/src/argilla_server/schemas/v1/users.py @@ -0,0 +1,34 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datetime import datetime +from typing import Optional +from uuid import UUID + +from argilla_server.enums import UserRole +from argilla_server.pydantic_v1 import BaseModel + + +class User(BaseModel): + id: UUID + first_name: str + last_name: Optional[str] + username: str + role: UserRole + api_key: str + inserted_at: datetime + updated_at: datetime + + class Config: + orm_mode = True diff --git a/tests/unit/api/v1/users/__init__.py b/tests/unit/api/v1/users/__init__.py new file mode 100644 index 00000000..55be4179 --- /dev/null +++ b/tests/unit/api/v1/users/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/api/v1/users/test_get_current_user.py b/tests/unit/api/v1/users/test_get_current_user.py new file mode 100644 index 00000000..e0d29f93 --- /dev/null +++ b/tests/unit/api/v1/users/test_get_current_user.py @@ -0,0 +1,43 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from argilla_server.models import User +from httpx import AsyncClient + + +@pytest.mark.asyncio +class TestGetCurrentUser: + def url(self) -> str: + return "/api/v1/me" + + async def test_get_current_user(self, async_client: AsyncClient, owner: User, owner_auth_header: dict): + response = await async_client.get(self.url(), headers=owner_auth_header) + + assert response.status_code == 200 + assert response.json() == { + "id": str(owner.id), + "first_name": owner.first_name, + "last_name": owner.last_name, + "username": owner.username, + "role": owner.role, + "api_key": owner.api_key, + "inserted_at": owner.inserted_at.isoformat(), + "updated_at": owner.updated_at.isoformat(), + } + + async def test_get_current_user_without_authentication(self, async_client: AsyncClient): + response = await async_client.get(self.url()) + + assert response.status_code == 401 From 6e4b910adb7f6d0eccb7fbafbe329edb5ca96806 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Francisco=20Calvo?= Date: Mon, 6 May 2024 16:40:55 +0200 Subject: [PATCH 2/6] chore: update CHANGELOG.md --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 04cdcff4..647a8713 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ These are the section headers that we use: ## [Unreleased]() - Added `POST /api/v1/token` endpoint to generate a new API token for a user. ([#138](https://github.com/argilla-io/argilla-server/pull/138)) +- Added `GET /api/v1/me` endpoint to get the current user information. ([#140](https://github.com/argilla-io/argilla-server/pull/140)) ## [Unreleased]() From 06bd5201c684744550007583f1d325b44f733277 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Francisco=20Calvo?= Date: Mon, 6 May 2024 17:00:02 +0200 Subject: [PATCH 3/6] feat: add telemetry.track_login call --- src/argilla_server/apis/v1/handlers/users.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/argilla_server/apis/v1/handlers/users.py b/src/argilla_server/apis/v1/handlers/users.py index daf0c1bb..5860631f 100644 --- a/src/argilla_server/apis/v1/handlers/users.py +++ b/src/argilla_server/apis/v1/handlers/users.py @@ -14,30 +14,34 @@ from uuid import UUID -from fastapi import APIRouter, Depends, HTTPException, Security, status +from fastapi import APIRouter, Depends, HTTPException, Request, Security, status from sqlalchemy.ext.asyncio import AsyncSession +from argilla_server import models, telemetry from argilla_server.contexts import accounts from argilla_server.database import get_async_db from argilla_server.models import User from argilla_server.policies import UserPolicyV1, authorize -from argilla_server.schemas.v1.users import User as UserSchema +from argilla_server.schemas.v1.users import User from argilla_server.schemas.v1.workspaces import Workspaces from argilla_server.security import auth router = APIRouter(tags=["users"]) -@router.get("/me", response_model=UserSchema) -async def get_current_user(current_user: User = Security(auth.get_current_user)): - # TODO: Should we add telemetry.track_login? +@router.get("/me", response_model=User) +async def get_current_user(request: Request, current_user: models.User = Security(auth.get_current_user)): + await telemetry.track_login(request, current_user) return current_user @router.get("/users/{user_id}/workspaces", response_model=Workspaces) async def list_user_workspaces( - *, db: AsyncSession = Depends(get_async_db), user_id: UUID, current_user: User = Security(auth.get_current_user) + *, + db: AsyncSession = Depends(get_async_db), + user_id: UUID, + current_user: models.User = Security(auth.get_current_user), ): await authorize(current_user, UserPolicyV1.list_workspaces) From 4fb7e7361f11b8f9d86bcbf0243d2b7d0ad1c3c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Francisco=20Calvo?= Date: Mon, 6 May 2024 18:04:27 +0200 Subject: [PATCH 4/6] feat: migrate /users endpoint to /api/v1/users --- src/argilla_server/apis/v1/handlers/users.py | 17 +++- src/argilla_server/contexts/accounts.py | 2 + src/argilla_server/policies.py | 4 + src/argilla_server/schemas/v1/users.py | 6 +- tests/unit/api/v1/users/test_list_users.py | 81 ++++++++++++++++++++ 5 files changed, 107 insertions(+), 3 deletions(-) create mode 100644 tests/unit/api/v1/users/test_list_users.py diff --git a/src/argilla_server/apis/v1/handlers/users.py b/src/argilla_server/apis/v1/handlers/users.py index 5860631f..94c79808 100644 --- a/src/argilla_server/apis/v1/handlers/users.py +++ b/src/argilla_server/apis/v1/handlers/users.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import List from uuid import UUID from fastapi import APIRouter, Depends, HTTPException, Request, Security, status @@ -20,9 +21,8 @@ from argilla_server import models, telemetry from argilla_server.contexts import accounts from argilla_server.database import get_async_db -from argilla_server.models import User from argilla_server.policies import UserPolicyV1, authorize -from argilla_server.schemas.v1.users import User +from argilla_server.schemas.v1.users import User, Users from argilla_server.schemas.v1.workspaces import Workspaces from argilla_server.security import auth @@ -36,6 +36,19 @@ async def get_current_user(request: Request, current_user: models.User = Securit return current_user +@router.get("/users", response_model=Users) +async def list_users( + *, + db: AsyncSession = Depends(get_async_db), + current_user: models.User = Security(auth.get_current_user), +): + await authorize(current_user, UserPolicyV1.list) + + users = await accounts.list_users(db) + + return Users(items=users) + + @router.get("/users/{user_id}/workspaces", response_model=Workspaces) async def list_user_workspaces( *, diff --git a/src/argilla_server/contexts/accounts.py b/src/argilla_server/contexts/accounts.py index b128f07b..7e0ef28b 100644 --- a/src/argilla_server/contexts/accounts.py +++ b/src/argilla_server/contexts/accounts.py @@ -110,6 +110,8 @@ async def get_user_by_api_key(db: AsyncSession, api_key: str) -> Union[User, Non async def list_users(db: "AsyncSession") -> Sequence[User]: + # TODO: After removing API v0 implementation we can remove the workspaces eager loading + # because is not used in the new API v1 endpoints. result = await db.execute(select(User).order_by(User.inserted_at.asc()).options(selectinload(User.workspaces))) return result.scalars().all() diff --git a/src/argilla_server/policies.py b/src/argilla_server/policies.py index 196ef159..97dbf183 100644 --- a/src/argilla_server/policies.py +++ b/src/argilla_server/policies.py @@ -129,6 +129,10 @@ async def is_allowed(actor: User) -> bool: class UserPolicyV1: + @classmethod + async def list(cls, actor: User) -> bool: + return actor.is_owner + @classmethod async def list_workspaces(cls, actor: User) -> bool: return actor.is_owner diff --git a/src/argilla_server/schemas/v1/users.py b/src/argilla_server/schemas/v1/users.py index f728d7c1..4ad4b133 100644 --- a/src/argilla_server/schemas/v1/users.py +++ b/src/argilla_server/schemas/v1/users.py @@ -13,7 +13,7 @@ # limitations under the License. from datetime import datetime -from typing import Optional +from typing import List, Optional from uuid import UUID from argilla_server.enums import UserRole @@ -32,3 +32,7 @@ class User(BaseModel): class Config: orm_mode = True + + +class Users(BaseModel): + items: List[User] diff --git a/tests/unit/api/v1/users/test_list_users.py b/tests/unit/api/v1/users/test_list_users.py new file mode 100644 index 00000000..43a09bfe --- /dev/null +++ b/tests/unit/api/v1/users/test_list_users.py @@ -0,0 +1,81 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from argilla_server.constants import API_KEY_HEADER_NAME +from argilla_server.enums import UserRole +from argilla_server.models import User +from httpx import AsyncClient + +from tests.factories import UserFactory + + +@pytest.mark.asyncio +class TestListUsers: + def url(self) -> str: + return "/api/v1/users" + + async def test_list_users(self, async_client: AsyncClient, owner: User, owner_auth_header: dict): + user_a, user_b = await UserFactory.create_batch(2) + + response = await async_client.get(self.url(), headers=owner_auth_header) + + assert response.status_code == 200 + assert response.json() == { + "items": [ + { + "id": str(owner.id), + "first_name": owner.first_name, + "last_name": owner.last_name, + "username": owner.username, + "role": owner.role, + "api_key": owner.api_key, + "inserted_at": owner.inserted_at.isoformat(), + "updated_at": owner.updated_at.isoformat(), + }, + { + "id": str(user_a.id), + "first_name": user_a.first_name, + "last_name": user_a.last_name, + "username": user_a.username, + "role": user_a.role, + "api_key": user_a.api_key, + "inserted_at": user_a.inserted_at.isoformat(), + "updated_at": user_a.updated_at.isoformat(), + }, + { + "id": str(user_b.id), + "first_name": user_b.first_name, + "last_name": user_b.last_name, + "username": user_b.username, + "role": user_b.role, + "api_key": user_b.api_key, + "inserted_at": user_b.inserted_at.isoformat(), + "updated_at": user_b.updated_at.isoformat(), + }, + ] + } + + @pytest.mark.parametrize("user_role", [UserRole.admin, UserRole.annotator]) + async def test_list_users_with_invalid_role(self, async_client: AsyncClient, user_role: UserRole): + user = await UserFactory.create(role=user_role) + + response = await async_client.get(self.url(), headers={API_KEY_HEADER_NAME: user.api_key}) + + assert response.status_code == 403 + + async def test_list_users_without_authentication(self, async_client: AsyncClient): + response = await async_client.get(self.url()) + + assert response.status_code == 401 From 045342099fe2cc7dce566587ca32f4b06541f841 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Francisco=20Calvo?= Date: Mon, 6 May 2024 18:16:16 +0200 Subject: [PATCH 5/6] chore: update CHANGELOG.md --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 647a8713..96c7d4bd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ These are the section headers that we use: - Added `POST /api/v1/token` endpoint to generate a new API token for a user. ([#138](https://github.com/argilla-io/argilla-server/pull/138)) - Added `GET /api/v1/me` endpoint to get the current user information. ([#140](https://github.com/argilla-io/argilla-server/pull/140)) +- Added `GET /api/v1/users` endpoint to get a list of all users. ([#142](https://github.com/argilla-io/argilla-server/pull/142)) ## [Unreleased]() From 085607e5265b639aa4e0a18c79f5181d121ae158 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Francisco=20Calvo?= Date: Fri, 10 May 2024 13:27:26 +0200 Subject: [PATCH 6/6] feat: small tests improvement --- tests/unit/api/v1/users/test_list_users.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/unit/api/v1/users/test_list_users.py b/tests/unit/api/v1/users/test_list_users.py index 43a09bfe..358d710d 100644 --- a/tests/unit/api/v1/users/test_list_users.py +++ b/tests/unit/api/v1/users/test_list_users.py @@ -67,15 +67,15 @@ async def test_list_users(self, async_client: AsyncClient, owner: User, owner_au ] } + async def test_list_users_without_authentication(self, async_client: AsyncClient): + response = await async_client.get(self.url()) + + assert response.status_code == 401 + @pytest.mark.parametrize("user_role", [UserRole.admin, UserRole.annotator]) - async def test_list_users_with_invalid_role(self, async_client: AsyncClient, user_role: UserRole): + async def test_list_users_with_unauthorized_role(self, async_client: AsyncClient, user_role: UserRole): user = await UserFactory.create(role=user_role) response = await async_client.get(self.url(), headers={API_KEY_HEADER_NAME: user.api_key}) assert response.status_code == 403 - - async def test_list_users_without_authentication(self, async_client: AsyncClient): - response = await async_client.get(self.url()) - - assert response.status_code == 401