Skip to content

Commit

Permalink
refactor: move users and workspaces schemas under schemas module (#4532)
Browse files Browse the repository at this point in the history
<!-- Thanks for your contribution! As part of our Community Growers
initiative 🌱, we're donating Justdiggit bunds in your name to reforest
sub-Saharan Africa. To claim your Community Growers certificate, please
contact David Berenstein in our Slack community or fill in this form
https://tally.so/r/n9XrxK once your PR has been merged. -->

# Description

This PR moves user and workspace schemas defined in the `security`
module to the proper `schemas.v0` module.

Also, the `users_file` setting attribute is removed, but the
corresponding environment variable is still available from the `user
migrate` command. This change does not change the current behaviour.

**Type of change**

(Please delete options that are not relevant. Remember to title the PR
according to the type of change)

- [ ] New feature (non-breaking change which adds functionality)
- [X] Refactor (change restructuring the codebase without changing
functionality)
- [ ] Improvement (change adding some improvement to an existing
functionality)

**How Has This Been Tested**

(Please describe the tests that you ran to verify your changes. And
ideally, reference `tests`)

Running tests locally

**Checklist**

- [ ] I added relevant documentation
- [X] I followed the style guidelines of this project
- [X] I did a self-review of my code
- [ ] I made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK)
(see text above)
- [ ] I have added relevant notes to the `CHANGELOG.md` file (See
https://keepachangelog.com/)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
frascuchon and pre-commit-ci[bot] authored Jan 25, 2024
1 parent 75c8461 commit 123ddf1
Show file tree
Hide file tree
Showing 15 changed files with 192 additions and 186 deletions.
2 changes: 1 addition & 1 deletion src/argilla/server/apis/v0/handlers/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
from argilla.server.errors import EntityAlreadyExistsError, EntityNotFoundError
from argilla.server.policies import UserPolicy, authorize
from argilla.server.pydantic_v1 import parse_obj_as
from argilla.server.schemas.v0.users import User, UserCreate
from argilla.server.security import auth
from argilla.server.security.model import User, UserCreate

router = APIRouter(tags=["users"])

Expand Down
3 changes: 2 additions & 1 deletion src/argilla/server/apis/v0/handlers/workspaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@
from argilla.server.errors import EntityAlreadyExistsError, EntityNotFoundError
from argilla.server.policies import WorkspacePolicy, WorkspaceUserPolicy, authorize
from argilla.server.pydantic_v1 import parse_obj_as
from argilla.server.schemas.v0.users import User
from argilla.server.schemas.v0.workspaces import Workspace, WorkspaceCreate, WorkspaceUserCreate
from argilla.server.security import auth
from argilla.server.security.model import User, Workspace, WorkspaceCreate, WorkspaceUserCreate

router = APIRouter(tags=["workspaces"])

Expand Down
7 changes: 2 additions & 5 deletions src/argilla/server/cli/database/users/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,8 @@
from argilla.server.contexts import accounts
from argilla.server.database import AsyncSessionLocal
from argilla.server.models import User, UserRole
from argilla.server.security.model import (
USER_PASSWORD_MIN_LENGTH,
UserCreate,
WorkspaceCreate,
)
from argilla.server.schemas.v0.users import USER_PASSWORD_MIN_LENGTH, UserCreate
from argilla.server.schemas.v0.workspaces import WorkspaceCreate

from .utils import get_or_new_workspace

Expand Down
11 changes: 7 additions & 4 deletions src/argilla/server/cli/database/users/migrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import os
from typing import TYPE_CHECKING, List, Optional

import typer
import yaml

from argilla.pydantic_v1 import BaseModel, Field, constr
from argilla.server.database import AsyncSessionLocal
from argilla.server.models import User, UserRole
from argilla.server.security.auth_provider.db.settings import settings
from argilla.server.security.model import USER_USERNAME_REGEX, WORKSPACE_NAME_REGEX
from argilla.server.pydantic_v1 import BaseModel, Field, constr
from argilla.server.schemas.v0.users import USER_USERNAME_REGEX
from argilla.server.schemas.v0.workspaces import WORKSPACE_NAME_REGEX

from .utils import get_or_new_workspace

Expand Down Expand Up @@ -107,7 +108,9 @@ def _user_workspace_names(self, user: dict) -> List[str]:

def migrate():
"""Migrate users defined in YAML file to database."""
asyncio.run(UsersMigrator(settings.users_db_file).migrate())

users_db_file: str = os.getenv("ARGILLA_LOCAL_AUTH_USERS_DB_FILE", ".users.yml")
asyncio.run(UsersMigrator(users_db_file).migrate())


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions src/argilla/server/contexts/accounts.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, List, Union
from uuid import UUID

Expand All @@ -20,7 +19,8 @@
from sqlalchemy.orm import Session, selectinload

from argilla.server.models import User, Workspace, WorkspaceUser
from argilla.server.security.model import UserCreate, WorkspaceCreate, WorkspaceUserCreate
from argilla.server.schemas.v0.users import UserCreate
from argilla.server.schemas.v0.workspaces import WorkspaceCreate, WorkspaceUserCreate

if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
Expand Down
2 changes: 1 addition & 1 deletion src/argilla/server/contexts/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
VectorSettings,
)
from argilla.server.models.suggestions import SuggestionCreateWithRecordId
from argilla.server.schemas.v0.users import User
from argilla.server.schemas.v1.datasets import (
DatasetCreate,
)
Expand Down Expand Up @@ -79,7 +80,6 @@
)
from argilla.server.schemas.v1.vectors import Vector as VectorSchema
from argilla.server.search_engine import SearchEngine
from argilla.server.security.model import User

if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
Expand Down
64 changes: 64 additions & 0 deletions src/argilla/server/schemas/v0/users.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from datetime import datetime
from typing import Any, List, Optional
from uuid import UUID

from argilla.server.constants import ES_INDEX_REGEX_PATTERN
from argilla.server.enums import UserRole
from argilla.server.pydantic_v1 import BaseModel, Field, constr
from argilla.server.pydantic_v1.utils import GetterDict

USER_USERNAME_REGEX = ES_INDEX_REGEX_PATTERN
USER_PASSWORD_MIN_LENGTH = 8
USER_PASSWORD_MAX_LENGTH = 100


class UserGetter(GetterDict):
def get(self, key: str, default: Any = None) -> Any:
if key == "full_name":
return f"{self._obj.first_name} {self._obj.last_name}" if self._obj.last_name else self._obj.first_name
elif key == "workspaces":
return [workspace.name for workspace in self._obj.workspaces]
else:
return super().get(key, default)


class User(BaseModel):
"""Base user schema"""

id: UUID
first_name: str
last_name: Optional[str]
full_name: Optional[str] = Field(description="Deprecated. Use `first_name` and `last_name` instead")
username: str = Field()
role: UserRole
workspaces: Optional[List[str]]
api_key: str
inserted_at: datetime
updated_at: datetime

class Config:
orm_mode = True
getter_dict = UserGetter


class UserCreate(BaseModel):
first_name: constr(min_length=1, strip_whitespace=True)
last_name: Optional[constr(min_length=1, strip_whitespace=True)]
username: str = Field(regex=USER_USERNAME_REGEX, min_length=1)
role: Optional[UserRole]
password: constr(min_length=USER_PASSWORD_MIN_LENGTH, max_length=USER_PASSWORD_MAX_LENGTH)
workspaces: Optional[List[str]]
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from argilla.server.errors import ForbiddenOperationError
from argilla.server.security.model import User
from datetime import datetime
from uuid import UUID

from argilla.server.constants import ES_INDEX_REGEX_PATTERN
from argilla.server.pydantic_v1 import BaseModel, Field

def validate_is_super_user(user: User, message: str = None):
"""Common validation to ensure the current user is a admin/superuser"""
if not user.is_superuser():
raise ForbiddenOperationError(message or "Only admin users can apply this change")
WORKSPACE_NAME_REGEX = ES_INDEX_REGEX_PATTERN


class Workspace(BaseModel):
id: UUID
name: str
inserted_at: datetime
updated_at: datetime

class Config:
orm_mode = True


class WorkspaceUserCreate(BaseModel):
user_id: UUID
workspace_id: UUID


class WorkspaceCreate(BaseModel):
name: str = Field(..., regex=WORKSPACE_NAME_REGEX, min_length=1)
1 change: 0 additions & 1 deletion src/argilla/server/security/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,5 @@
# limitations under the License.
from .auth_provider import DBAuthProvider
from .auth_provider.base import AuthProvider, api_key_header
from .model import User

auth = DBAuthProvider.new_instance()
1 change: 0 additions & 1 deletion src/argilla/server/security/auth_provider/db/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ class Settings(BaseSettings):
algorithm: str = "HS256"
token_expiration_in_minutes: int = 30000
token_api_url: str = "/api/security/token"
users_db_file: str = ".users.yml"

@property
def public_oauth_token_url(self):
Expand Down
71 changes: 1 addition & 70 deletions src/argilla/server/security/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,77 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from datetime import datetime
from typing import Any, List, Optional
from uuid import UUID

from argilla.server.constants import ES_INDEX_REGEX_PATTERN
from argilla.server.models import UserRole
from argilla.server.pydantic_v1 import BaseModel, Field, constr
from argilla.server.pydantic_v1.utils import GetterDict

WORKSPACE_NAME_REGEX = ES_INDEX_REGEX_PATTERN

USER_USERNAME_REGEX = ES_INDEX_REGEX_PATTERN
USER_PASSWORD_MIN_LENGTH = 8
USER_PASSWORD_MAX_LENGTH = 100


class WorkspaceUserCreate(BaseModel):
user_id: UUID
workspace_id: UUID


class Workspace(BaseModel):
id: UUID
name: str
inserted_at: datetime
updated_at: datetime

class Config:
orm_mode = True


class WorkspaceCreate(BaseModel):
name: str = Field(..., regex=WORKSPACE_NAME_REGEX, min_length=1)


class UserCreate(BaseModel):
first_name: constr(min_length=1, strip_whitespace=True)
last_name: Optional[constr(min_length=1, strip_whitespace=True)]
username: str = Field(regex=USER_USERNAME_REGEX, min_length=1)
role: Optional[UserRole]
password: constr(min_length=USER_PASSWORD_MIN_LENGTH, max_length=USER_PASSWORD_MAX_LENGTH)
workspaces: Optional[List[str]]


class UserGetter(GetterDict):
def get(self, key: str, default: Any = None) -> Any:
if key == "full_name":
return f"{self._obj.first_name} {self._obj.last_name}" if self._obj.last_name else self._obj.first_name
elif key == "workspaces":
return [workspace.name for workspace in self._obj.workspaces]
else:
return super().get(key, default)


class User(BaseModel):
"""Base user model"""

id: UUID
first_name: str
last_name: Optional[str]
full_name: Optional[str] = Field(description="Deprecated. Use `first_name` and `last_name` instead")
username: str = Field()
role: UserRole
workspaces: Optional[List[str]]
api_key: str
inserted_at: datetime
updated_at: datetime

class Config:
orm_mode = True
getter_dict = UserGetter
from argilla.server.pydantic_v1 import BaseModel


class Token(BaseModel):
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_datasets_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
load_dataset_settings,
)
from argilla.server.contexts import accounts
from argilla.server.security.model import WorkspaceUserCreate
from argilla.server.schemas.v0.workspaces import WorkspaceUserCreate

from tests.integration.utils import delete_ignoring_errors

Expand Down
2 changes: 1 addition & 1 deletion tests/unit/client/sdk/models/test_users.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from argilla.client.sdk.users.models import UserModel as ClientUser
from argilla.client.sdk.users.models import UserRole as ClientUserRole
from argilla.server.models import UserRole as ServerUserRole
from argilla.server.security.model import User as ServerUser
from argilla.server.schemas.v0.users import User as ServerUser

from tests.unit.client.sdk.models.conftest import Helpers

Expand Down
Loading

0 comments on commit 123ddf1

Please sign in to comment.