Skip to content
This repository has been archived by the owner on Jun 14, 2024. It is now read-only.

feat: add GET /api/v1/users/:user_id new endpoint #166

Merged
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ These are the section headers that we use:
## [Unreleased]()

- Added `POST /api/v1/token` endpoint to generate a new API token for a user. ([#138](https://github.com/argilla-io/argilla-server/pull/138))
- Added `GET /api/v1/me` endpoint to get the current user information. ([#140](https://github.com/argilla-io/argilla-server/pull/140))
- Added `GET /api/v1/users` endpoint to get a list of all users. ([#142](https://github.com/argilla-io/argilla-server/pull/142))
- Added `POST /api/v1/users` endpoint to create a new user. ([#146](https://github.com/argilla-io/argilla-server/pull/146))
- Added `DELETE /api/v1/users` endpoint to delete a user. ([#148](https://github.com/argilla-io/argilla-server/pull/148))
- Added `GET /api/v1/users/:user_id` endpoint to get a specific user. ([#166](https://github.com/argilla-io/argilla-server/pull/166))

## [Unreleased]()

Expand Down
2 changes: 1 addition & 1 deletion src/argilla_server/apis/v0/handlers/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ async def create_user(
raise EntityAlreadyExistsError(name=user_create.username, type=User)

try:
user = await accounts.create_user(db, user_create)
user = await accounts.create_user(db, user_create.dict(), user_create.workspaces)
telemetry.track_user_created(user)
except Exception as e:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(e))
Expand Down
96 changes: 93 additions & 3 deletions src/argilla_server/apis/v1/handlers/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,114 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException, Security, status
from fastapi import APIRouter, Depends, HTTPException, Request, Security, status
from sqlalchemy.ext.asyncio import AsyncSession

from argilla_server import models, telemetry
from argilla_server.contexts import accounts
from argilla_server.database import get_async_db
from argilla_server.models import User
from argilla_server.errors import EntityAlreadyExistsError, EntityNotFoundError
from argilla_server.policies import UserPolicyV1, authorize
from argilla_server.schemas.v1.users import User, UserCreate, Users
from argilla_server.schemas.v1.workspaces import Workspaces
from argilla_server.security import auth

router = APIRouter(tags=["users"])


@router.get("/me", response_model=User)
async def get_current_user(request: Request, current_user: models.User = Security(auth.get_current_user)):
await telemetry.track_login(request, current_user)

return current_user


@router.get("/users/{user_id}", response_model=User)
async def get_user(
*,
db: AsyncSession = Depends(get_async_db),
user_id: UUID,
current_user: models.User = Security(auth.get_current_user),
):
await authorize(current_user, UserPolicyV1.get)

user = await accounts.get_user_by_id(db, user_id)
if user is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"User with id `{user_id}` not found",
)

return user


@router.get("/users", response_model=Users)
async def list_users(
*,
db: AsyncSession = Depends(get_async_db),
current_user: models.User = Security(auth.get_current_user),
):
await authorize(current_user, UserPolicyV1.list)

users = await accounts.list_users(db)

return Users(items=users)


@router.post("/users", response_model=User)
async def create_user(
*,
db: AsyncSession = Depends(get_async_db),
user_create: UserCreate,
current_user: models.User = Security(auth.get_current_user),
):
await authorize(current_user, UserPolicyV1.create)

user = await accounts.get_user_by_username(db, user_create.username)
if user is not None:
raise EntityAlreadyExistsError(name=user_create.username, type=User)

try:
user = await accounts.create_user(db, user_create.dict())

telemetry.track_user_created(user)
except Exception as e:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(e))

return user


@router.delete("/users/{user_id}", response_model=User)
async def delete_user(
*,
db: AsyncSession = Depends(get_async_db),
user_id: UUID,
current_user: models.User = Security(auth.get_current_user),
):
user = await accounts.get_user_by_id(db, user_id)
if user is None:
# TODO: Forcing here user_id to be an string.
# Not casting it is causing a `Object of type UUID is not JSON serializable`.
# Possible solution redefining JSONEncoder.default here:
# https://github.com/jazzband/django-push-notifications/issues/586
raise EntityNotFoundError(name=str(user_id), type=User)

await authorize(current_user, UserPolicyV1.delete)

await accounts.delete_user(db, user)

return user


@router.get("/users/{user_id}/workspaces", response_model=Workspaces)
async def list_user_workspaces(
*, db: AsyncSession = Depends(get_async_db), user_id: UUID, current_user: User = Security(auth.get_current_user)
*,
db: AsyncSession = Depends(get_async_db),
user_id: UUID,
current_user: models.User = Security(auth.get_current_user),
):
await authorize(current_user, UserPolicyV1.list_workspaces)

Expand Down
38 changes: 23 additions & 15 deletions src/argilla_server/contexts/accounts.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ async def get_user_by_api_key(db: AsyncSession, api_key: str) -> Union[User, Non


async def list_users(db: "AsyncSession") -> Sequence[User]:
# TODO: After removing API v0 implementation we can remove the workspaces eager loading
# because is not used in the new API v1 endpoints.
result = await db.execute(select(User).order_by(User.inserted_at.asc()).options(selectinload(User.workspaces)))
return result.scalars().all()

Expand All @@ -119,23 +121,26 @@ async def list_users_by_ids(db: AsyncSession, ids: Iterable[UUID]) -> Sequence[U
return result.scalars().all()


async def create_user(db: "AsyncSession", user_create: UserCreate) -> User:
# TODO: After removing API v0 implementation we can remove the workspaces attribute.
# With API v1 the workspaces will be created doing additional requests to other endpoints for it.
async def create_user(db: AsyncSession, user_attrs: dict, workspaces: Union[List[str], None] = None) -> User:
async with db.begin_nested():
user = await User.create(
db,
first_name=user_create.first_name,
last_name=user_create.last_name,
username=user_create.username,
role=user_create.role,
password_hash=hash_password(user_create.password),
first_name=user_attrs["first_name"],
last_name=user_attrs["last_name"],
username=user_attrs["username"],
role=user_attrs["role"],
password_hash=hash_password(user_attrs["password"]),
autocommit=False,
)

if user_create.workspaces:
for workspace_name in user_create.workspaces:
if workspaces is not None:
for workspace_name in workspaces:
workspace = await get_workspace_by_name(db, workspace_name)
if not workspace:
raise ValueError(f"Workspace '{workspace_name}' does not exist")

await WorkspaceUser.create(
db,
workspace_id=workspace.id,
Expand All @@ -152,15 +157,18 @@ async def create_user_with_random_password(
db,
username: str,
first_name: str,
workspaces: List[str] = None,
role: UserRole = UserRole.annotator,
workspaces: Union[List[str], None] = None,
) -> User:
password = _generate_random_password()

user_create = UserCreate(
first_name=first_name, username=username, role=role, password=password, workspaces=workspaces
)
return await create_user(db, user_create)
user_attrs = {
"first_name": first_name,
"last_name": None,
"username": username,
"role": role,
"password": _generate_random_password(),
}

return await create_user(db, user_attrs, workspaces)


async def delete_user(db: AsyncSession, user: User) -> User:
Expand Down
16 changes: 16 additions & 0 deletions src/argilla_server/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,22 @@ async def is_allowed(actor: User) -> bool:


class UserPolicyV1:
@classmethod
async def get(cls, actor: User) -> bool:
return actor.is_owner

@classmethod
async def list(cls, actor: User) -> bool:
return actor.is_owner

@classmethod
async def create(cls, actor: User) -> bool:
return actor.is_owner

@classmethod
async def delete(cls, actor: User) -> bool:
return actor.is_owner

@classmethod
async def list_workspaces(cls, actor: User) -> bool:
return actor.is_owner
Expand Down
52 changes: 52 additions & 0 deletions src/argilla_server/schemas/v1/users.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from datetime import datetime
from typing import List, Optional
from uuid import UUID

from argilla_server.enums import UserRole
from argilla_server.pydantic_v1 import BaseModel, Field, constr

USER_USERNAME_REGEX = "^(?!-|_)[A-za-z0-9-_]+$"
USER_PASSWORD_MIN_LENGTH = 8
USER_PASSWORD_MAX_LENGTH = 100


class User(BaseModel):
id: UUID
first_name: str
last_name: Optional[str]
username: str
role: UserRole
# TODO: We need to move `api_key` outside of this schema and think about a more
# secure way to expose it, along with ways to expire it and create new API keys.
api_key: str
inserted_at: datetime
updated_at: datetime

class Config:
orm_mode = True


class UserCreate(BaseModel):
first_name: constr(min_length=1, strip_whitespace=True)
last_name: Optional[constr(min_length=1, strip_whitespace=True)]
username: str = Field(regex=USER_USERNAME_REGEX, min_length=1)
role: Optional[UserRole]
password: str = Field(min_length=USER_PASSWORD_MIN_LENGTH, max_length=USER_PASSWORD_MAX_LENGTH)


class Users(BaseModel):
items: List[User]
47 changes: 47 additions & 0 deletions tests/unit/api/v0/test_users.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,53 @@ async def test_create_user_with_non_default_role(
assert response_body["role"] == UserRole.owner.value


@pytest.mark.asyncio
async def test_create_user_with_first_name_including_leading_and_trailing_spaces(
async_client: "AsyncClient", db: "AsyncSession", owner_auth_header: dict
):
response = await async_client.post(
"/api/users",
headers=owner_auth_header,
json={
"first_name": " First name ",
"username": "username",
"password": "12345678",
},
)

assert response.status_code == 200

assert (await db.execute(select(func.count(User.id)))).scalar() == 2
user = (await db.execute(select(User).filter_by(username="username"))).scalar_one()

assert response.json()["first_name"] == "First name"
assert user.first_name == "First name"


@pytest.mark.asyncio
async def test_create_user_with_last_name_including_leading_and_trailing_spaces(
async_client: "AsyncClient", db: "AsyncSession", owner_auth_header: dict
):
response = await async_client.post(
"/api/users",
headers=owner_auth_header,
json={
"first_name": "First name",
"last_name": " Last name ",
"username": "username",
"password": "12345678",
},
)

assert response.status_code == 200

assert (await db.execute(select(func.count(User.id)))).scalar() == 2
user = (await db.execute(select(User).filter_by(username="username"))).scalar_one()

assert response.json()["last_name"] == "Last name"
assert user.last_name == "Last name"


@pytest.mark.asyncio
async def test_create_user_without_authentication(async_client: "AsyncClient", db: "AsyncSession"):
user = {"first_name": "first-name", "username": "username", "password": "12345678"}
Expand Down
13 changes: 13 additions & 0 deletions tests/unit/api/v1/users/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Loading
Loading