diff --git a/src/argilla/cli/server/database/users/migrate.py b/src/argilla/cli/server/database/users/migrate.py index 6d18d3a14a..3efb8d54df 100644 --- a/src/argilla/cli/server/database/users/migrate.py +++ b/src/argilla/cli/server/database/users/migrate.py @@ -22,7 +22,7 @@ from argilla.cli.server.database.users.utils import get_or_new_workspace from argilla.server.database import AsyncSessionLocal from argilla.server.models import User, UserRole -from argilla.server.security.auth_provider.local.settings import settings +from argilla.server.security.auth_provider.db.settings import settings from argilla.server.security.model import USER_USERNAME_REGEX, WORKSPACE_NAME_REGEX if TYPE_CHECKING: diff --git a/src/argilla/server/app.py b/src/argilla/server/app.py index f0236ae48f..1303eaada1 100644 --- a/src/argilla/server/app.py +++ b/src/argilla/server/app.py @@ -215,8 +215,7 @@ async def setup_elasticsearch(): def configure_app_security(app: FastAPI): - if hasattr(auth, "router"): - app.include_router(auth.router) + auth.configure_app(app) def configure_app_logging(app: FastAPI): diff --git a/src/argilla/server/contexts/accounts.py b/src/argilla/server/contexts/accounts.py index 89a31e2e10..69a13dbc81 100644 --- a/src/argilla/server/contexts/accounts.py +++ b/src/argilla/server/contexts/accounts.py @@ -145,7 +145,7 @@ async def delete_user(db: "AsyncSession", user: User) -> User: return await user.delete(db) -async def authenticate_user(db: Session, username: str, password: str): +async def authenticate_user(db: "AsyncSession", username: str, password: str): user = await get_user_by_username(db, username) if user and verify_password(password, user.password_hash): diff --git a/src/argilla/server/security/__init__.py b/src/argilla/server/security/__init__.py index b5c14a083f..d914c97247 100644 --- a/src/argilla/server/security/__init__.py +++ b/src/argilla/server/security/__init__.py @@ -12,5 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from .auth_provider import DBAuthProvider +from .auth_provider.base import AuthProvider, api_key_header +from .model import User -from .factory import auth +auth = DBAuthProvider.new_instance() diff --git a/src/argilla/server/security/auth_provider/__init__.py b/src/argilla/server/security/auth_provider/__init__.py index 4954e6e6a0..8e66d8ba32 100644 --- a/src/argilla/server/security/auth_provider/__init__.py +++ b/src/argilla/server/security/auth_provider/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .local import LocalAuthProvider, create_local_auth_provider +from .db import DBAuthProvider, settings # noqa diff --git a/src/argilla/server/security/auth_provider/base.py b/src/argilla/server/security/auth_provider/base.py index 56694259b6..3ae97637bd 100644 --- a/src/argilla/server/security/auth_provider/base.py +++ b/src/argilla/server/security/auth_provider/base.py @@ -12,10 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from abc import ABCMeta, abstractmethod from typing import Optional -from fastapi import Depends +from fastapi import Depends, FastAPI, Request from fastapi.security import APIKeyHeader, SecurityScopes from argilla._constants import API_KEY_HEADER_NAME @@ -24,13 +24,24 @@ api_key_header = APIKeyHeader(name=API_KEY_HEADER_NAME, auto_error=False) -class AuthProvider: +class AuthProvider(metaclass=ABCMeta): """Base class for auth provider""" - async def get_user( + @classmethod + @abstractmethod + def new_instance(cls): + pass + + @abstractmethod + def configure_app(self, app: FastAPI): + pass + + @abstractmethod + async def get_current_user( self, security_scopes: SecurityScopes, + request: Request, api_key: Optional[str] = Depends(api_key_header), **kwargs, ) -> User: - raise NotImplementedError() + pass diff --git a/src/argilla/server/security/auth_provider/local/__init__.py b/src/argilla/server/security/auth_provider/db/__init__.py similarity index 90% rename from src/argilla/server/security/auth_provider/local/__init__.py rename to src/argilla/server/security/auth_provider/db/__init__.py index 5b088bd499..f6e8681809 100644 --- a/src/argilla/server/security/auth_provider/local/__init__.py +++ b/src/argilla/server/security/auth_provider/db/__init__.py @@ -13,4 +13,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .provider import LocalAuthProvider, create_local_auth_provider +from .provider import DBAuthProvider +from .settings import settings diff --git a/src/argilla/server/security/auth_provider/local/provider.py b/src/argilla/server/security/auth_provider/db/provider.py similarity index 84% rename from src/argilla/server/security/auth_provider/local/provider.py rename to src/argilla/server/security/auth_provider/db/provider.py index ae094d58ae..ef151ddc60 100644 --- a/src/argilla/server/security/auth_provider/local/provider.py +++ b/src/argilla/server/security/auth_provider/db/provider.py @@ -16,7 +16,7 @@ from datetime import datetime, timedelta from typing import Optional -from fastapi import APIRouter, Depends +from fastapi import Depends, FastAPI, Request from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm, SecurityScopes from jose import JWTError, jwt from sqlalchemy.ext.asyncio import AsyncSession @@ -34,16 +34,18 @@ _oauth2_scheme = OAuth2PasswordBearer(tokenUrl=local_security.public_oauth_token_url, auto_error=False) -class LocalAuthProvider(AuthProvider): +class DBAuthProvider(AuthProvider): def __init__(self, settings: Settings): - self.router = APIRouter(tags=["security"]) self.settings = settings - # TODO: maybe it's better if we move this endpoint to apis/v0/handlers - @self.router.post( - settings.token_api_url, - response_model=Token, - operation_id="login_for_access_token", + @classmethod + def new_instance(cls) -> "DBAuthProvider": + settings = Settings() + return DBAuthProvider(settings=settings) + + def configure_app(self, app: FastAPI): + @app.post( + self.settings.token_api_url, response_model=Token, operation_id="login_for_access_token", tags=["security"] ) async def login_for_access_token( db: AsyncSession = Depends(get_async_db), @@ -52,36 +54,31 @@ async def login_for_access_token( user = await accounts.authenticate_user(db, form_data.username, form_data.password) if not user: raise UnauthorizedError() + access_token_expires = timedelta(minutes=self.settings.token_expiration_in_minutes) access_token = self._create_access_token(user.username, expires_delta=access_token_expires) + return Token(access_token=access_token) - def _create_access_token(self, username: str, expires_delta: Optional[timedelta] = None) -> str: - """ - Creates an access token + async def get_current_user( + self, + security_scopes: SecurityScopes, + request: Request, + db: AsyncSession = Depends(get_async_db), + api_key: Optional[str] = Depends(api_key_header), + token: Optional[str] = Depends(_oauth2_scheme), + ) -> User: + user = None - Parameters - ---------- - username: - The user name - expires_delta: - Token expiration + if api_key: + user = await accounts.get_user_by_api_key(db, api_key) + elif token: + user = await self.fetch_token_user(db, token) - Returns - ------- - An access token string - """ - to_encode = { - "sub": username, - } - if expires_delta: - to_encode["exp"] = datetime.utcnow() + expires_delta + if user is None: + raise UnauthorizedError() - return jwt.encode( - to_encode, - self.settings.secret_key, - algorithm=self.settings.algorithm, - ) + return user async def fetch_token_user(self, db: AsyncSession, token: str) -> Optional[User]: """ @@ -97,11 +94,7 @@ async def fetch_token_user(self, db: AsyncSession, token: str) -> Optional[User] An User instance if a valid token was provided. None otherwise """ try: - payload = jwt.decode( - token, - self.settings.secret_key, - algorithms=[self.settings.algorithm], - ) + payload = jwt.decode(token, self.settings.secret_key, algorithms=[self.settings.algorithm]) username: str = payload.get("sub") if username: user = await accounts.get_user_by_username(db, username) @@ -109,27 +102,27 @@ async def fetch_token_user(self, db: AsyncSession, token: str) -> Optional[User] except JWTError: return None - async def get_current_user( - self, - security_scopes: SecurityScopes, - db: AsyncSession = Depends(get_async_db), - api_key: Optional[str] = Depends(api_key_header), - token: Optional[str] = Depends(_oauth2_scheme), - ) -> User: - user = None - - if api_key: - user = await accounts.get_user_by_api_key(db, api_key) - elif token: - user = await self.fetch_token_user(db, token) - - if user is None: - raise UnauthorizedError() - - return user + def _create_access_token(self, username: str, expires_delta: Optional[timedelta] = None) -> str: + """ + Creates an access token + Parameters + ---------- + username: + The user name + expires_delta: + Token expiration -def create_local_auth_provider(): - settings = Settings() + Returns + ------- + An access token string + """ + to_encode = {"sub": username} + if expires_delta: + to_encode["exp"] = datetime.utcnow() + expires_delta - return LocalAuthProvider(settings=settings) + return jwt.encode( + to_encode, + self.settings.secret_key, + algorithm=self.settings.algorithm, + ) diff --git a/src/argilla/server/security/auth_provider/local/settings.py b/src/argilla/server/security/auth_provider/db/settings.py similarity index 100% rename from src/argilla/server/security/auth_provider/local/settings.py rename to src/argilla/server/security/auth_provider/db/settings.py diff --git a/src/argilla/server/security/factory.py b/src/argilla/server/server.py similarity index 82% rename from src/argilla/server/security/factory.py rename to src/argilla/server/server.py index 6724e5194b..55be41799b 100644 --- a/src/argilla/server/security/factory.py +++ b/src/argilla/server/server.py @@ -1,4 +1,3 @@ -# coding=utf-8 # Copyright 2021-present, the Recognai S.L. team. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,7 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from argilla.server.security.auth_provider import create_local_auth_provider - -auth = create_local_auth_provider() diff --git a/tests/integration/client/feedback/dataset/remote/test_dataset.py b/tests/integration/client/feedback/dataset/remote/test_dataset.py index df06a84afa..c697825ecd 100644 --- a/tests/integration/client/feedback/dataset/remote/test_dataset.py +++ b/tests/integration/client/feedback/dataset/remote/test_dataset.py @@ -49,9 +49,9 @@ from argilla.client.sdk.users.models import UserRole from argilla.client.workspaces import Workspace from argilla.server.models import User as ServerUser +from argilla.server.settings import settings from sqlalchemy.ext.asyncio import AsyncSession -from argilla.server.settings import settings from tests.factories import ( DatasetFactory, RecordFactory, diff --git a/tests/unit/cli/server/database/users/test_migrate.py b/tests/unit/cli/server/database/users/test_migrate.py index 6c7b1388c9..b475bc3212 100644 --- a/tests/unit/cli/server/database/users/test_migrate.py +++ b/tests/unit/cli/server/database/users/test_migrate.py @@ -16,7 +16,7 @@ from typing import TYPE_CHECKING from argilla.server.models import User, UserRole, Workspace, WorkspaceUser -from argilla.server.security.auth_provider.local.settings import settings +from argilla.server.security.auth_provider.db.settings import settings from click.testing import CliRunner from typer import Typer diff --git a/tests/unit/server/security/test_provider.py b/tests/unit/server/security/test_provider.py index 5f4bcaa1a1..7cd7a589bb 100644 --- a/tests/unit/server/security/test_provider.py +++ b/tests/unit/server/security/test_provider.py @@ -16,35 +16,39 @@ import pytest from argilla._constants import DEFAULT_API_KEY -from argilla.server.security.auth_provider.local.provider import create_local_auth_provider +from argilla.server.security.auth_provider.db import DBAuthProvider from fastapi.security import SecurityScopes if TYPE_CHECKING: from argilla.server.models import User from sqlalchemy.ext.asyncio import AsyncSession -localAuth = create_local_auth_provider() -security_Scopes = SecurityScopes +db_auth = DBAuthProvider.new_instance() +security_Scopes = SecurityScopes() # Tests for function get_user via token and api key @pytest.mark.asyncio async def test_get_user_via_token(db: "AsyncSession", argilla_user: "User"): - access_token = localAuth._create_access_token(username=argilla_user.username) + access_token = db_auth._create_access_token(username=argilla_user.username) - user = await localAuth.get_current_user(security_scopes=security_Scopes, db=db, token=access_token, api_key=None) + user = await db_auth.get_current_user( + security_scopes=security_Scopes, request=None, db=db, token=access_token, api_key=None + ) assert user.username == "argilla" @pytest.mark.asyncio async def test_get_user_via_api_key(db: "AsyncSession", argilla_user: "User"): - user = await localAuth.get_current_user(security_scopes=security_Scopes, db=db, api_key=DEFAULT_API_KEY, token=None) + user = await db_auth.get_current_user( + security_scopes=security_Scopes, request=None, db=db, api_key=DEFAULT_API_KEY, token=None + ) assert user.username == "argilla" # Test for function fetch token @pytest.mark.asyncio async def test_fetch_token_user(db: "AsyncSession", argilla_user: "User"): - access_token = localAuth._create_access_token(username=argilla_user.username) - user = await localAuth.fetch_token_user(db=db, token=access_token) + access_token = db_auth._create_access_token(username=argilla_user.username) + user = await db_auth.fetch_token_user(db=db, token=access_token) assert user.username == "argilla"