Skip to content

Commit

Permalink
refactor: review security module (#4426)
Browse files Browse the repository at this point in the history
<!-- Thanks for your contribution! As part of our Community Growers
initiative 🌱, we're donating Justdiggit bunds in your name to reforest
sub-Saharan Africa. To claim your Community Growers certificate, please
contact David Berenstein in our Slack community or fill in this form
https://tally.so/r/n9XrxK once your PR has been merged. -->

# Description

In this PR, the `argilla.server.security` module is reviewed, by
changing the `local` provider to `db` since is using db for validation.
Also, the base `AuthProvider` methods are decorated as `abstract`, and a
new class method is used to create provider instances, to normalize the
provider creation process.

**Type of change**

(Please delete options that are not relevant. Remember to title the PR
according to the type of change)

- [ ] New feature (non-breaking change which adds functionality)
- [X] Refactor (change restructuring the codebase without changing
functionality)
- [ ] Improvement (change adding some improvement to an existing
functionality)

**How Has This Been Tested**

(Please describe the tests that you ran to verify your changes. And
ideally, reference `tests`)

Test with a local deployment

**Checklist**

- [ ] I added relevant documentation
- [X] I followed the style guidelines of this project
- [X] I did a self-review of my code
- [ ] I made corresponding changes to the documentation
- [X] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK)
(see text above)
- [ ] I have added relevant notes to the `CHANGELOG.md` file (See
https://keepachangelog.com/)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
frascuchon and pre-commit-ci[bot] authored Dec 22, 2023
1 parent 49a2600 commit 0e93e7f
Show file tree
Hide file tree
Showing 13 changed files with 90 additions and 84 deletions.
2 changes: 1 addition & 1 deletion src/argilla/cli/server/database/users/migrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from argilla.cli.server.database.users.utils import get_or_new_workspace
from argilla.server.database import AsyncSessionLocal
from argilla.server.models import User, UserRole
from argilla.server.security.auth_provider.local.settings import settings
from argilla.server.security.auth_provider.db.settings import settings
from argilla.server.security.model import USER_USERNAME_REGEX, WORKSPACE_NAME_REGEX

if TYPE_CHECKING:
Expand Down
3 changes: 1 addition & 2 deletions src/argilla/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,7 @@ async def setup_elasticsearch():


def configure_app_security(app: FastAPI):
if hasattr(auth, "router"):
app.include_router(auth.router)
auth.configure_app(app)


def configure_app_logging(app: FastAPI):
Expand Down
2 changes: 1 addition & 1 deletion src/argilla/server/contexts/accounts.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ async def delete_user(db: "AsyncSession", user: User) -> User:
return await user.delete(db)


async def authenticate_user(db: Session, username: str, password: str):
async def authenticate_user(db: "AsyncSession", username: str, password: str):
user = await get_user_by_username(db, username)

if user and verify_password(password, user.password_hash):
Expand Down
5 changes: 4 additions & 1 deletion src/argilla/server/security/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .auth_provider import DBAuthProvider
from .auth_provider.base import AuthProvider, api_key_header
from .model import User

from .factory import auth
auth = DBAuthProvider.new_instance()
2 changes: 1 addition & 1 deletion src/argilla/server/security/auth_provider/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .local import LocalAuthProvider, create_local_auth_provider
from .db import DBAuthProvider, settings # noqa
21 changes: 16 additions & 5 deletions src/argilla/server/security/auth_provider/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABCMeta, abstractmethod
from typing import Optional

from fastapi import Depends
from fastapi import Depends, FastAPI, Request
from fastapi.security import APIKeyHeader, SecurityScopes

from argilla._constants import API_KEY_HEADER_NAME
Expand All @@ -24,13 +24,24 @@
api_key_header = APIKeyHeader(name=API_KEY_HEADER_NAME, auto_error=False)


class AuthProvider:
class AuthProvider(metaclass=ABCMeta):
"""Base class for auth provider"""

async def get_user(
@classmethod
@abstractmethod
def new_instance(cls):
pass

@abstractmethod
def configure_app(self, app: FastAPI):
pass

@abstractmethod
async def get_current_user(
self,
security_scopes: SecurityScopes,
request: Request,
api_key: Optional[str] = Depends(api_key_header),
**kwargs,
) -> User:
raise NotImplementedError()
pass
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .provider import LocalAuthProvider, create_local_auth_provider
from .provider import DBAuthProvider
from .settings import settings
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from datetime import datetime, timedelta
from typing import Optional

from fastapi import APIRouter, Depends
from fastapi import Depends, FastAPI, Request
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm, SecurityScopes
from jose import JWTError, jwt
from sqlalchemy.ext.asyncio import AsyncSession
Expand All @@ -34,16 +34,18 @@
_oauth2_scheme = OAuth2PasswordBearer(tokenUrl=local_security.public_oauth_token_url, auto_error=False)


class LocalAuthProvider(AuthProvider):
class DBAuthProvider(AuthProvider):
def __init__(self, settings: Settings):
self.router = APIRouter(tags=["security"])
self.settings = settings

# TODO: maybe it's better if we move this endpoint to apis/v0/handlers
@self.router.post(
settings.token_api_url,
response_model=Token,
operation_id="login_for_access_token",
@classmethod
def new_instance(cls) -> "DBAuthProvider":
settings = Settings()
return DBAuthProvider(settings=settings)

def configure_app(self, app: FastAPI):
@app.post(
self.settings.token_api_url, response_model=Token, operation_id="login_for_access_token", tags=["security"]
)
async def login_for_access_token(
db: AsyncSession = Depends(get_async_db),
Expand All @@ -52,36 +54,31 @@ async def login_for_access_token(
user = await accounts.authenticate_user(db, form_data.username, form_data.password)
if not user:
raise UnauthorizedError()

access_token_expires = timedelta(minutes=self.settings.token_expiration_in_minutes)
access_token = self._create_access_token(user.username, expires_delta=access_token_expires)

return Token(access_token=access_token)

def _create_access_token(self, username: str, expires_delta: Optional[timedelta] = None) -> str:
"""
Creates an access token
async def get_current_user(
self,
security_scopes: SecurityScopes,
request: Request,
db: AsyncSession = Depends(get_async_db),
api_key: Optional[str] = Depends(api_key_header),
token: Optional[str] = Depends(_oauth2_scheme),
) -> User:
user = None

Parameters
----------
username:
The user name
expires_delta:
Token expiration
if api_key:
user = await accounts.get_user_by_api_key(db, api_key)
elif token:
user = await self.fetch_token_user(db, token)

Returns
-------
An access token string
"""
to_encode = {
"sub": username,
}
if expires_delta:
to_encode["exp"] = datetime.utcnow() + expires_delta
if user is None:
raise UnauthorizedError()

return jwt.encode(
to_encode,
self.settings.secret_key,
algorithm=self.settings.algorithm,
)
return user

async def fetch_token_user(self, db: AsyncSession, token: str) -> Optional[User]:
"""
Expand All @@ -97,39 +94,35 @@ async def fetch_token_user(self, db: AsyncSession, token: str) -> Optional[User]
An User instance if a valid token was provided. None otherwise
"""
try:
payload = jwt.decode(
token,
self.settings.secret_key,
algorithms=[self.settings.algorithm],
)
payload = jwt.decode(token, self.settings.secret_key, algorithms=[self.settings.algorithm])
username: str = payload.get("sub")
if username:
user = await accounts.get_user_by_username(db, username)
return user
except JWTError:
return None

async def get_current_user(
self,
security_scopes: SecurityScopes,
db: AsyncSession = Depends(get_async_db),
api_key: Optional[str] = Depends(api_key_header),
token: Optional[str] = Depends(_oauth2_scheme),
) -> User:
user = None

if api_key:
user = await accounts.get_user_by_api_key(db, api_key)
elif token:
user = await self.fetch_token_user(db, token)

if user is None:
raise UnauthorizedError()

return user
def _create_access_token(self, username: str, expires_delta: Optional[timedelta] = None) -> str:
"""
Creates an access token
Parameters
----------
username:
The user name
expires_delta:
Token expiration
def create_local_auth_provider():
settings = Settings()
Returns
-------
An access token string
"""
to_encode = {"sub": username}
if expires_delta:
to_encode["exp"] = datetime.utcnow() + expires_delta

return LocalAuthProvider(settings=settings)
return jwt.encode(
to_encode,
self.settings.secret_key,
algorithm=self.settings.algorithm,
)
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -12,7 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from argilla.server.security.auth_provider import create_local_auth_provider

auth = create_local_auth_provider()
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@
from argilla.client.sdk.users.models import UserRole
from argilla.client.workspaces import Workspace
from argilla.server.models import User as ServerUser
from argilla.server.settings import settings
from sqlalchemy.ext.asyncio import AsyncSession

from argilla.server.settings import settings
from tests.factories import (
DatasetFactory,
RecordFactory,
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/cli/server/database/users/test_migrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from typing import TYPE_CHECKING

from argilla.server.models import User, UserRole, Workspace, WorkspaceUser
from argilla.server.security.auth_provider.local.settings import settings
from argilla.server.security.auth_provider.db.settings import settings
from click.testing import CliRunner
from typer import Typer

Expand Down
20 changes: 12 additions & 8 deletions tests/unit/server/security/test_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,35 +16,39 @@

import pytest
from argilla._constants import DEFAULT_API_KEY
from argilla.server.security.auth_provider.local.provider import create_local_auth_provider
from argilla.server.security.auth_provider.db import DBAuthProvider
from fastapi.security import SecurityScopes

if TYPE_CHECKING:
from argilla.server.models import User
from sqlalchemy.ext.asyncio import AsyncSession

localAuth = create_local_auth_provider()
security_Scopes = SecurityScopes
db_auth = DBAuthProvider.new_instance()
security_Scopes = SecurityScopes()


# Tests for function get_user via token and api key
@pytest.mark.asyncio
async def test_get_user_via_token(db: "AsyncSession", argilla_user: "User"):
access_token = localAuth._create_access_token(username=argilla_user.username)
access_token = db_auth._create_access_token(username=argilla_user.username)

user = await localAuth.get_current_user(security_scopes=security_Scopes, db=db, token=access_token, api_key=None)
user = await db_auth.get_current_user(
security_scopes=security_Scopes, request=None, db=db, token=access_token, api_key=None
)
assert user.username == "argilla"


@pytest.mark.asyncio
async def test_get_user_via_api_key(db: "AsyncSession", argilla_user: "User"):
user = await localAuth.get_current_user(security_scopes=security_Scopes, db=db, api_key=DEFAULT_API_KEY, token=None)
user = await db_auth.get_current_user(
security_scopes=security_Scopes, request=None, db=db, api_key=DEFAULT_API_KEY, token=None
)
assert user.username == "argilla"


# Test for function fetch token
@pytest.mark.asyncio
async def test_fetch_token_user(db: "AsyncSession", argilla_user: "User"):
access_token = localAuth._create_access_token(username=argilla_user.username)
user = await localAuth.fetch_token_user(db=db, token=access_token)
access_token = db_auth._create_access_token(username=argilla_user.username)
user = await db_auth.fetch_token_user(db=db, token=access_token)
assert user.username == "argilla"

0 comments on commit 0e93e7f

Please sign in to comment.