diff --git a/fastapi_users_db_sqlmodel/__init__.py b/fastapi_users_db_sqlmodel/__init__.py index 695c5e2..ddb3791 100644 --- a/fastapi_users_db_sqlmodel/__init__.py +++ b/fastapi_users_db_sqlmodel/__init__.py @@ -1,26 +1,41 @@ """FastAPI Users database adapter for SQLModel.""" + import uuid -from typing import TYPE_CHECKING, Any, Dict, Generic, Optional, Type +from typing import TYPE_CHECKING, Any, Dict, Generic, Optional, Type, _ProtocolMeta from fastapi_users.db.base import BaseUserDatabase -from fastapi_users.models import ID, OAP, UP -from pydantic import UUID4, EmailStr +from fastapi_users.models import ( + ID, + UP, + OAuthAccountProtocol, + UserProtocol, +) +from pydantic import UUID4, ConfigDict, EmailStr +from pydantic.version import VERSION as PYDANTIC_VERSION from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload -from sqlmodel import Field, Session, SQLModel, func, select +from sqlmodel import AutoString, Field, Session, SQLModel, func, select +from sqlmodel.main import SQLModelMetaclass __version__ = "0.3.0" +PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") + + +class SQLModelProtocolMetaclass(SQLModelMetaclass, _ProtocolMeta): + pass -class SQLModelBaseUserDB(SQLModel): - __tablename__ = "user" +class SQLModelBaseUserDB(SQLModel, UserProtocol, metaclass=SQLModelProtocolMetaclass): + __tablename__ = "user" # type: ignore id: UUID4 = Field(default_factory=uuid.uuid4, primary_key=True, nullable=False) if TYPE_CHECKING: # pragma: no cover email: str else: email: EmailStr = Field( - sa_column_kwargs={"unique": True, "index": True}, nullable=False + sa_column_kwargs={"unique": True, "index": True}, + nullable=False, + sa_type=AutoString, ) hashed_password: str @@ -28,12 +43,18 @@ class SQLModelBaseUserDB(SQLModel): is_superuser: bool = Field(False, nullable=False) is_verified: bool = Field(False, nullable=False) - class Config: - orm_mode = True + if PYDANTIC_V2: # pragma: no cover + model_config = ConfigDict(from_attributes=True) # type: ignore + else: # pragma: no cover + class Config: + orm_mode = True -class SQLModelBaseOAuthAccount(SQLModel): - __tablename__ = "oauthaccount" + +class SQLModelBaseOAuthAccount( + SQLModel, OAuthAccountProtocol, metaclass=SQLModelProtocolMetaclass +): + __tablename__ = "oauthaccount" # type: ignore id: UUID4 = Field(default_factory=uuid.uuid4, primary_key=True) user_id: UUID4 = Field(foreign_key="user.id", nullable=False) @@ -44,8 +65,13 @@ class SQLModelBaseOAuthAccount(SQLModel): account_id: str = Field(index=True, nullable=False) account_email: str = Field(nullable=False) - class Config: - orm_mode = True + if PYDANTIC_V2: + # pragma: no cover + model_config = ConfigDict(from_attributes=True) # type: ignore + else: + + class Config: # pragma: no cover + orm_mode = True class SQLModelUserDatabase(Generic[UP, ID], BaseUserDatabase[UP, ID]): @@ -130,7 +156,7 @@ async def add_oauth_account(self, user: UP, create_dict: Dict[str, Any]) -> UP: return user async def update_oauth_account( - self, user: UP, oauth_account: OAP, update_dict: Dict[str, Any] + self, user: UP, oauth_account: OAuthAccountProtocol, update_dict: Dict[str, Any] ) -> UP: if self.oauth_account_model is None: raise NotImplementedError() @@ -230,7 +256,7 @@ async def add_oauth_account(self, user: UP, create_dict: Dict[str, Any]) -> UP: return user async def update_oauth_account( - self, user: UP, oauth_account: OAP, update_dict: Dict[str, Any] + self, user: UP, oauth_account: OAuthAccountProtocol, update_dict: Dict[str, Any] ) -> UP: if self.oauth_account_model is None: raise NotImplementedError() diff --git a/fastapi_users_db_sqlmodel/access_token.py b/fastapi_users_db_sqlmodel/access_token.py index 8a4519e..db342de 100644 --- a/fastapi_users_db_sqlmodel/access_token.py +++ b/fastapi_users_db_sqlmodel/access_token.py @@ -1,34 +1,70 @@ from datetime import datetime from typing import Any, Dict, Generic, Optional, Type -from fastapi_users.authentication.strategy.db import AP, AccessTokenDatabase -from pydantic import UUID4 -from sqlalchemy import Column, types +from fastapi_users.authentication.strategy.db import ( + AP, + APE, + AccessRefreshTokenDatabase, + AccessTokenDatabase, +) +from fastapi_users.authentication.strategy.db.adapter import BaseAccessTokenDatabase +from fastapi_users.authentication.strategy.db.models import ( + AccessRefreshTokenProtocol, + AccessTokenProtocol, +) +from pydantic import UUID4, ConfigDict +from pydantic.version import VERSION as PYDANTIC_VERSION +from sqlalchemy import types from sqlalchemy.ext.asyncio import AsyncSession from sqlmodel import Field, Session, SQLModel, select from fastapi_users_db_sqlmodel.generics import TIMESTAMPAware, now_utc +from . import SQLModelProtocolMetaclass -class SQLModelBaseAccessToken(SQLModel): - __tablename__ = "accesstoken" +PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") + + +class SQLModelBaseAccessToken( + SQLModel, AccessTokenProtocol, metaclass=SQLModelProtocolMetaclass +): + __tablename__ = "accesstoken" # type: ignore token: str = Field( - sa_column=Column("token", types.String(length=43), primary_key=True) + sa_type=types.String(length=43), # type: ignore + primary_key=True, ) created_at: datetime = Field( default_factory=now_utc, - sa_column=Column( - "created_at", TIMESTAMPAware(timezone=True), nullable=False, index=True - ), + sa_type=TIMESTAMPAware(timezone=True), # type: ignore + nullable=False, + index=True, ) user_id: UUID4 = Field(foreign_key="user.id", nullable=False) - class Config: - orm_mode = True + if PYDANTIC_V2: # pragma: no cover + model_config = ConfigDict(from_attributes=True) # type: ignore + else: # pragma: no cover + + class Config: + orm_mode = True + + +class SQLModelBaseAccessRefreshToken( + SQLModelBaseAccessToken, + AccessRefreshTokenProtocol, + metaclass=SQLModelProtocolMetaclass, +): + __tablename__ = "accessrefreshtoken" + + refresh_token: str = Field( + sa_type=types.String(length=43), # type: ignore + unique=True, + index=True, + ) -class SQLModelAccessTokenDatabase(Generic[AP], AccessTokenDatabase[AP]): +class BaseSQLModelAccessTokenDatabase(Generic[AP], BaseAccessTokenDatabase[str, AP]): """ Access token database adapter for SQLModel. @@ -75,7 +111,47 @@ async def delete(self, access_token: AP) -> None: self.session.commit() -class SQLModelAccessTokenDatabaseAsync(Generic[AP], AccessTokenDatabase[AP]): +class SQLModelAccessTokenDatabase( + Generic[AP], BaseSQLModelAccessTokenDatabase[AP], AccessTokenDatabase[AP] +): + """ + Access token database adapter for SQLModel. + + :param session: SQLAlchemy session. + :param access_token_model: SQLModel access token model. + """ + + +class SQLModelAccessRefreshTokenDatabase( + Generic[APE], BaseSQLModelAccessTokenDatabase[APE], AccessRefreshTokenDatabase[APE] +): + """ + Access token database adapter for SQLModel. + + :param session: SQLAlchemy session. + :param access_token_model: SQLModel access refresh token model. + """ + + async def get_by_refresh_token( + self, refresh_token: str, max_age: Optional[datetime] = None + ) -> Optional[APE]: + statement = select(self.access_token_model).where( # type: ignore + self.access_token_model.refresh_token == refresh_token + ) + if max_age is not None: + statement = statement.where(self.access_token_model.created_at >= max_age) + + results = self.session.exec(statement) + access_token = results.first() + if access_token is None: + return None + + return access_token + + +class BaseSQLModelAccessTokenDatabaseAsync( + Generic[AP], BaseAccessTokenDatabase[str, AP] +): """ Access token database adapter for SQLModel working purely asynchronously. @@ -120,3 +196,31 @@ async def update(self, access_token: AP, update_dict: Dict[str, Any]) -> AP: async def delete(self, access_token: AP) -> None: await self.session.delete(access_token) await self.session.commit() + + +class SQLModelAccessTokenDatabaseAsync( + BaseSQLModelAccessTokenDatabaseAsync[AP], AccessTokenDatabase[AP], Generic[AP] +): + pass + + +class SQLModelAccessRefreshTokenDatabaseAsync( + BaseSQLModelAccessTokenDatabaseAsync[APE], + AccessRefreshTokenDatabase[APE], + Generic[APE], +): + async def get_by_refresh_token( + self, refresh_token: str, max_age: Optional[datetime] = None + ) -> Optional[APE]: + statement = select(self.access_token_model).where( # type: ignore + self.access_token_model.refresh_token == refresh_token + ) + if max_age is not None: + statement = statement.where(self.access_token_model.created_at >= max_age) + + results = await self.session.execute(statement) + access_token = results.first() + if access_token is None: + return None + + return access_token[0] diff --git a/fastapi_users_db_sqlmodel/generics.py b/fastapi_users_db_sqlmodel/generics.py index bfe2fda..630da53 100644 --- a/fastapi_users_db_sqlmodel/generics.py +++ b/fastapi_users_db_sqlmodel/generics.py @@ -1,4 +1,5 @@ from datetime import datetime, timezone +from typing import Optional from sqlalchemy import TIMESTAMP, TypeDecorator @@ -18,7 +19,8 @@ class TIMESTAMPAware(TypeDecorator): # pragma: no cover impl = TIMESTAMP cache_ok = True - def process_result_value(self, value: datetime, dialect): + def process_result_value(self, value: Optional[datetime], dialect): if dialect.name != "postgresql": - return value.replace(tzinfo=timezone.utc) + if value is not None: + return value.replace(tzinfo=timezone.utc) return value diff --git a/pyproject.toml b/pyproject.toml index 971fac3..47745ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,7 +74,7 @@ classifiers = [ ] requires-python = ">=3.7" dependencies = [ - "fastapi-users >= 10.0.2", + "fastapi-users @ git+https://github.com/Ae-Mc/fastapi-users@add-refresh-token", "greenlet", "sqlmodel", ] diff --git a/tests/test_access_token.py b/tests/test_access_token.py index ee93e04..df2e80b 100644 --- a/tests/test_access_token.py +++ b/tests/test_access_token.py @@ -12,8 +12,11 @@ from fastapi_users_db_sqlmodel import SQLModelUserDatabase, SQLModelUserDatabaseAsync from fastapi_users_db_sqlmodel.access_token import ( + SQLModelAccessRefreshTokenDatabase, + SQLModelAccessRefreshTokenDatabaseAsync, SQLModelAccessTokenDatabase, SQLModelAccessTokenDatabaseAsync, + SQLModelBaseAccessRefreshToken, SQLModelBaseAccessToken, ) from tests.conftest import User @@ -23,6 +26,10 @@ class AccessToken(SQLModelBaseAccessToken, table=True): pass +class AccessRefreshToken(SQLModelBaseAccessRefreshToken, table=True): + pass + + @pytest.fixture def user_id() -> UUID4: return uuid.UUID("a9089e5d-2642-406d-a7c0-cbc641aca0ec") @@ -38,10 +45,10 @@ async def init_sync_session(url: str) -> AsyncGenerator[Session, None]: async def init_async_session(url: str) -> AsyncGenerator[AsyncSession, None]: engine = create_async_engine(url, connect_args={"check_same_thread": False}) - make_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) + make_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) # type: ignore async with engine.begin() as conn: await conn.run_sync(SQLModel.metadata.create_all) - async with make_session() as session: + async with make_session() as session: # type: ignore yield session await conn.run_sync(SQLModel.metadata.drop_all) @@ -82,6 +89,42 @@ async def sqlmodel_access_token_db( yield access_token_database_class(session, AccessToken) +@pytest_asyncio.fixture( + params=[ + ( + init_sync_session, + "sqlite:///./test-sqlmodel-access-refresh-token.db", + SQLModelAccessRefreshTokenDatabase, + SQLModelUserDatabase, + ), + ( + init_async_session, + "sqlite+aiosqlite:///./test-sqlmodel-access-refresh-token.db", + SQLModelAccessRefreshTokenDatabaseAsync, + SQLModelUserDatabaseAsync, + ), + ], + ids=["sync_refresh", "async_refresh"], +) +async def sqlmodel_access_refresh_token_db( + request, user_id: UUID4 +) -> AsyncGenerator[SQLModelAccessRefreshTokenDatabase, None]: + create_session = request.param[0] + database_url = request.param[1] + access_token_database_class = request.param[2] + user_database_class = request.param[3] + async for session in create_session(database_url): + user_db = user_database_class(session, User) + await user_db.create( + { + "id": user_id, + "email": "lancelot@camelot.bt", + "hashed_password": "guinevere", + } + ) + yield access_token_database_class(session, AccessRefreshToken) + + @pytest.mark.asyncio async def test_queries( sqlmodel_access_token_db: SQLModelAccessTokenDatabase[AccessToken], @@ -144,3 +187,104 @@ async def test_insert_existing_token( with pytest.raises(exc.IntegrityError): await sqlmodel_access_token_db.create(access_token_create) + + +@pytest.mark.asyncio +async def test_refresh_queries( + sqlmodel_access_refresh_token_db: SQLModelAccessRefreshTokenDatabase[ + AccessRefreshToken + ], + user_id: UUID4, +): + access_token_create = { + "token": "TOKEN", + "refresh_token": "REFRESH", + "user_id": user_id, + } + + # Create + access_token = await sqlmodel_access_refresh_token_db.create(access_token_create) + assert access_token.token == "TOKEN" + assert access_token.refresh_token == "REFRESH" + assert access_token.user_id == user_id + + # Update + update_dict = {"created_at": datetime.now(timezone.utc)} + updated_access_token = await sqlmodel_access_refresh_token_db.update( + access_token, update_dict + ) + assert updated_access_token.created_at.replace(microsecond=0) == update_dict[ + "created_at" + ].replace(microsecond=0) + + # Get by refresh token + access_token_by_token = await sqlmodel_access_refresh_token_db.get_by_refresh_token( + access_token.refresh_token + ) + assert access_token_by_token is not None + + # Get by refresh token expired + access_token_by_token = await sqlmodel_access_refresh_token_db.get_by_refresh_token( + access_token.refresh_token, + max_age=datetime.now(timezone.utc) + timedelta(hours=1), + ) + assert access_token_by_token is None + + # Get by refresh token not expired + access_token_by_token = await sqlmodel_access_refresh_token_db.get_by_refresh_token( + access_token.refresh_token, + max_age=datetime.now(timezone.utc) - timedelta(hours=1), + ) + + # Get by refresh token unknown + access_token_by_token = await sqlmodel_access_refresh_token_db.get_by_refresh_token( + "NOT_EXISTING_TOKEN" + ) + assert access_token_by_token is None + + # Get by token + access_token_by_token = await sqlmodel_access_refresh_token_db.get_by_token( + access_token.token + ) + assert access_token_by_token is not None + + # Get by token expired + access_token_by_token = await sqlmodel_access_refresh_token_db.get_by_token( + access_token.token, max_age=datetime.now(timezone.utc) + timedelta(hours=1) + ) + assert access_token_by_token is None + + # Get by token not expired + access_token_by_token = await sqlmodel_access_refresh_token_db.get_by_token( + access_token.token, max_age=datetime.now(timezone.utc) - timedelta(hours=1) + ) + assert access_token_by_token is not None + + # Get by token unknown + access_token_by_token = await sqlmodel_access_refresh_token_db.get_by_token( + "NOT_EXISTING_TOKEN" + ) + assert access_token_by_token is None + + # Delete token + await sqlmodel_access_refresh_token_db.delete(access_token) + deleted_access_token = await sqlmodel_access_refresh_token_db.get_by_token( + access_token.token + ) + assert deleted_access_token is None + + +@pytest.mark.asyncio +async def test_insert_existing_token_refresh( + sqlmodel_access_token_db: SQLModelAccessRefreshTokenDatabase[AccessRefreshToken], + user_id: UUID4, +): + access_token_create = { + "token": "TOKEN", + "refresh_token": "REFRESH", + "user_id": user_id, + } + await sqlmodel_access_token_db.create(access_token_create) + + with pytest.raises(exc.IntegrityError): + await sqlmodel_access_token_db.create(access_token_create)