Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add refresh token #23

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 41 additions & 15 deletions fastapi_users_db_sqlmodel/__init__.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,60 @@
"""FastAPI Users database adapter for SQLModel."""

import uuid
from typing import TYPE_CHECKING, Any, Dict, Generic, Optional, Type
from typing import TYPE_CHECKING, Any, Dict, Generic, Optional, Type, _ProtocolMeta

from fastapi_users.db.base import BaseUserDatabase
from fastapi_users.models import ID, OAP, UP
from pydantic import UUID4, EmailStr
from fastapi_users.models import (
ID,
UP,
OAuthAccountProtocol,
UserProtocol,
)
from pydantic import UUID4, ConfigDict, EmailStr
from pydantic.version import VERSION as PYDANTIC_VERSION
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from sqlmodel import Field, Session, SQLModel, func, select
from sqlmodel import AutoString, Field, Session, SQLModel, func, select
from sqlmodel.main import SQLModelMetaclass

__version__ = "0.3.0"
PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")


class SQLModelProtocolMetaclass(SQLModelMetaclass, _ProtocolMeta):
pass


class SQLModelBaseUserDB(SQLModel):
__tablename__ = "user"
class SQLModelBaseUserDB(SQLModel, UserProtocol, metaclass=SQLModelProtocolMetaclass):
__tablename__ = "user" # type: ignore

id: UUID4 = Field(default_factory=uuid.uuid4, primary_key=True, nullable=False)
if TYPE_CHECKING: # pragma: no cover
email: str
else:
email: EmailStr = Field(
sa_column_kwargs={"unique": True, "index": True}, nullable=False
sa_column_kwargs={"unique": True, "index": True},
nullable=False,
sa_type=AutoString,
)
hashed_password: str

is_active: bool = Field(True, nullable=False)
is_superuser: bool = Field(False, nullable=False)
is_verified: bool = Field(False, nullable=False)

class Config:
orm_mode = True
if PYDANTIC_V2: # pragma: no cover
model_config = ConfigDict(from_attributes=True) # type: ignore
else: # pragma: no cover

class Config:
orm_mode = True

class SQLModelBaseOAuthAccount(SQLModel):
__tablename__ = "oauthaccount"

class SQLModelBaseOAuthAccount(
SQLModel, OAuthAccountProtocol, metaclass=SQLModelProtocolMetaclass
):
__tablename__ = "oauthaccount" # type: ignore

id: UUID4 = Field(default_factory=uuid.uuid4, primary_key=True)
user_id: UUID4 = Field(foreign_key="user.id", nullable=False)
Expand All @@ -44,8 +65,13 @@ class SQLModelBaseOAuthAccount(SQLModel):
account_id: str = Field(index=True, nullable=False)
account_email: str = Field(nullable=False)

class Config:
orm_mode = True
if PYDANTIC_V2:
# pragma: no cover
model_config = ConfigDict(from_attributes=True) # type: ignore
else:

class Config: # pragma: no cover
orm_mode = True


class SQLModelUserDatabase(Generic[UP, ID], BaseUserDatabase[UP, ID]):
Expand Down Expand Up @@ -130,7 +156,7 @@ async def add_oauth_account(self, user: UP, create_dict: Dict[str, Any]) -> UP:
return user

async def update_oauth_account(
self, user: UP, oauth_account: OAP, update_dict: Dict[str, Any]
self, user: UP, oauth_account: OAuthAccountProtocol, update_dict: Dict[str, Any]
) -> UP:
if self.oauth_account_model is None:
raise NotImplementedError()
Expand Down Expand Up @@ -230,7 +256,7 @@ async def add_oauth_account(self, user: UP, create_dict: Dict[str, Any]) -> UP:
return user

async def update_oauth_account(
self, user: UP, oauth_account: OAP, update_dict: Dict[str, Any]
self, user: UP, oauth_account: OAuthAccountProtocol, update_dict: Dict[str, Any]
) -> UP:
if self.oauth_account_model is None:
raise NotImplementedError()
Expand Down
130 changes: 117 additions & 13 deletions fastapi_users_db_sqlmodel/access_token.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,70 @@
from datetime import datetime
from typing import Any, Dict, Generic, Optional, Type

from fastapi_users.authentication.strategy.db import AP, AccessTokenDatabase
from pydantic import UUID4
from sqlalchemy import Column, types
from fastapi_users.authentication.strategy.db import (
AP,
APE,
AccessRefreshTokenDatabase,
AccessTokenDatabase,
)
from fastapi_users.authentication.strategy.db.adapter import BaseAccessTokenDatabase
from fastapi_users.authentication.strategy.db.models import (
AccessRefreshTokenProtocol,
AccessTokenProtocol,
)
from pydantic import UUID4, ConfigDict
from pydantic.version import VERSION as PYDANTIC_VERSION
from sqlalchemy import types
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import Field, Session, SQLModel, select

from fastapi_users_db_sqlmodel.generics import TIMESTAMPAware, now_utc

from . import SQLModelProtocolMetaclass

class SQLModelBaseAccessToken(SQLModel):
__tablename__ = "accesstoken"
PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")


class SQLModelBaseAccessToken(
SQLModel, AccessTokenProtocol, metaclass=SQLModelProtocolMetaclass
):
__tablename__ = "accesstoken" # type: ignore

token: str = Field(
sa_column=Column("token", types.String(length=43), primary_key=True)
sa_type=types.String(length=43), # type: ignore
primary_key=True,
)
created_at: datetime = Field(
default_factory=now_utc,
sa_column=Column(
"created_at", TIMESTAMPAware(timezone=True), nullable=False, index=True
),
sa_type=TIMESTAMPAware(timezone=True), # type: ignore
nullable=False,
index=True,
)
user_id: UUID4 = Field(foreign_key="user.id", nullable=False)

class Config:
orm_mode = True
if PYDANTIC_V2: # pragma: no cover
model_config = ConfigDict(from_attributes=True) # type: ignore
else: # pragma: no cover

class Config:
orm_mode = True


class SQLModelBaseAccessRefreshToken(
SQLModelBaseAccessToken,
AccessRefreshTokenProtocol,
metaclass=SQLModelProtocolMetaclass,
):
__tablename__ = "accessrefreshtoken"

refresh_token: str = Field(
sa_type=types.String(length=43), # type: ignore
unique=True,
index=True,
)


class SQLModelAccessTokenDatabase(Generic[AP], AccessTokenDatabase[AP]):
class BaseSQLModelAccessTokenDatabase(Generic[AP], BaseAccessTokenDatabase[str, AP]):
"""
Access token database adapter for SQLModel.

Expand Down Expand Up @@ -75,7 +111,47 @@ async def delete(self, access_token: AP) -> None:
self.session.commit()


class SQLModelAccessTokenDatabaseAsync(Generic[AP], AccessTokenDatabase[AP]):
class SQLModelAccessTokenDatabase(
Generic[AP], BaseSQLModelAccessTokenDatabase[AP], AccessTokenDatabase[AP]
):
"""
Access token database adapter for SQLModel.

:param session: SQLAlchemy session.
:param access_token_model: SQLModel access token model.
"""


class SQLModelAccessRefreshTokenDatabase(
Generic[APE], BaseSQLModelAccessTokenDatabase[APE], AccessRefreshTokenDatabase[APE]
):
"""
Access token database adapter for SQLModel.

:param session: SQLAlchemy session.
:param access_token_model: SQLModel access refresh token model.
"""

async def get_by_refresh_token(
self, refresh_token: str, max_age: Optional[datetime] = None
) -> Optional[APE]:
statement = select(self.access_token_model).where( # type: ignore
self.access_token_model.refresh_token == refresh_token
)
if max_age is not None:
statement = statement.where(self.access_token_model.created_at >= max_age)

results = self.session.exec(statement)
access_token = results.first()
if access_token is None:
return None

return access_token


class BaseSQLModelAccessTokenDatabaseAsync(
Generic[AP], BaseAccessTokenDatabase[str, AP]
):
"""
Access token database adapter for SQLModel working purely asynchronously.

Expand Down Expand Up @@ -120,3 +196,31 @@ async def update(self, access_token: AP, update_dict: Dict[str, Any]) -> AP:
async def delete(self, access_token: AP) -> None:
await self.session.delete(access_token)
await self.session.commit()


class SQLModelAccessTokenDatabaseAsync(
BaseSQLModelAccessTokenDatabaseAsync[AP], AccessTokenDatabase[AP], Generic[AP]
):
pass


class SQLModelAccessRefreshTokenDatabaseAsync(
BaseSQLModelAccessTokenDatabaseAsync[APE],
AccessRefreshTokenDatabase[APE],
Generic[APE],
):
async def get_by_refresh_token(
self, refresh_token: str, max_age: Optional[datetime] = None
) -> Optional[APE]:
statement = select(self.access_token_model).where( # type: ignore
self.access_token_model.refresh_token == refresh_token
)
if max_age is not None:
statement = statement.where(self.access_token_model.created_at >= max_age)

results = await self.session.execute(statement)
access_token = results.first()
if access_token is None:
return None

return access_token[0]
6 changes: 4 additions & 2 deletions fastapi_users_db_sqlmodel/generics.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from datetime import datetime, timezone
from typing import Optional

from sqlalchemy import TIMESTAMP, TypeDecorator

Expand All @@ -18,7 +19,8 @@ class TIMESTAMPAware(TypeDecorator): # pragma: no cover
impl = TIMESTAMP
cache_ok = True

def process_result_value(self, value: datetime, dialect):
def process_result_value(self, value: Optional[datetime], dialect):
if dialect.name != "postgresql":
return value.replace(tzinfo=timezone.utc)
if value is not None:
return value.replace(tzinfo=timezone.utc)
return value
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ classifiers = [
]
requires-python = ">=3.7"
dependencies = [
"fastapi-users >= 10.0.2",
"fastapi-users @ git+https://github.com/Ae-Mc/fastapi-users@add-refresh-token",
"greenlet",
"sqlmodel",
]
Expand Down
Loading