Skip to content

Commit

Permalink
Get and revoke third-party tokens
Browse files Browse the repository at this point in the history
Add: GET /auth/token/thirdparty to list third-party tokens

Add: DELETE /auth/token/{token_reference} to revoke token by reference

Rename: GET /auth/info to GET /auth/token/info

Add: Last usage time field "used" to AuthToken
  • Loading branch information
kuyugama committed Aug 12, 2024
1 parent e84a350 commit 8b2c77d
Show file tree
Hide file tree
Showing 15 changed files with 335 additions and 23 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""Add "used" (last usage time) field to AuthToken
Revision ID: b38d06d1364e
Revises: a4d5b6174fb1
Create Date: 2024-08-12 10:36:22.543968
"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "b38d06d1364e"
down_revision = "a4d5b6174fb1"
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"service_auth_tokens",
sa.Column("used", sa.DateTime(), nullable=True),
)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("service_auth_tokens", "used")
# ### end Alembic commands ###
19 changes: 19 additions & 0 deletions app/auth/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
get_user_by_activation,
get_user_by_reset,
get_oauth_by_id,
get_auth_token,
)

from .schemas import (
Expand Down Expand Up @@ -242,3 +243,21 @@ async def validate_auth_token_request(
raise Abort("auth", "invalid-client-credentials")

return request


async def validate_auth_token(
token_reference: uuid.UUID,
session: AsyncSession = Depends(get_session),
user: User = Depends(auth_required()),
):
now = utcnow()
if not (token := await get_auth_token(session, token_reference)):
raise Abort("auth", "invalid-token")

if now > token.expiration:
raise Abort("auth", "token-expired")

if token.user_id != user.id:
raise Abort("auth", "not-token-owner")

return token
55 changes: 52 additions & 3 deletions app/auth/router.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from app.dependencies import check_captcha, auth_required, auth_token_required
from app.models import User, UserOAuth, AuthToken, Client, AuthTokenRequest
from app.utils import pagination, pagination_dict, utcnow
from sqlalchemy.ext.asyncio import AsyncSession
from fastapi import APIRouter, Depends
from app.schemas import UserResponse
Expand All @@ -9,6 +9,13 @@
from . import service
from . import oauth

from app.dependencies import (
auth_token_required,
check_captcha,
auth_required,
get_page,
get_size,
)
from app.service import (
create_activation_token,
create_email,
Expand All @@ -20,6 +27,7 @@
validate_activation_resend,
validate_password_confirm,
validate_password_reset,
validate_auth_token,
validate_activation,
validate_provider,
validate_signup,
Expand All @@ -31,9 +39,10 @@
)

from .schemas import (
AuthTokenInfoPaginationResponse,
AuthTokenInfoResponse,
TokenRequestResponse,
ProviderUrlResponse,
AuthInfoResponse,
TokenResponse,
SignupArgs,
)
Expand Down Expand Up @@ -193,7 +202,9 @@ async def oauth_token(


@router.get(
"/info", summary="Get authorization info", response_model=AuthInfoResponse
"/token/info",
summary="Get token info",
response_model=AuthTokenInfoResponse,
)
async def auth_info(token: AuthToken = Depends(auth_token_required)):
return token
Expand Down Expand Up @@ -230,3 +241,41 @@ async def third_party_auth_token(
{"scope": token_request.scope},
)
return await service.create_auth_token_from_request(session, token_request)


@router.get(
"/token/thirdparty",
summary="List third-party auth tokens",
response_model=AuthTokenInfoPaginationResponse,
)
async def third_party_auth_tokens(
session: AsyncSession = Depends(get_session),
user: User = Depends(auth_required(forbid_thirdparty=True)),
page: int = Depends(get_page),
size: int = Depends(get_size),
):
limit, offset = pagination(page, size)
now = utcnow()

total = await service.count_user_thirdparty_auth_tokens(session, user, now)
tokens = await service.list_user_thirdparty_auth_tokens(
session, user, offset, limit, now
)

return {
"pagination": pagination_dict(total, page, limit),
"list": tokens.all(),
}


@router.delete(
"/token/{token_reference}",
summary="Revoke auth token",
response_model=AuthTokenInfoResponse,
dependencies=[Depends(auth_required(forbid_thirdparty=True))],
)
async def revoke_token(
token: AuthToken = Depends(validate_auth_token),
session: AsyncSession = Depends(get_session),
):
return await service.revoke_auth_token(session, token)
13 changes: 10 additions & 3 deletions app/auth/schemas.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from app import constants
from app.schemas import datetime_pd, ClientResponse
from app.schemas import datetime_pd, ClientResponse, PaginationResponse
from pydantic import Field
from app import constants
import uuid

from app.schemas import (
Expand Down Expand Up @@ -42,13 +42,20 @@ class TokenResponse(CustomModel):
)


class AuthInfoResponse(CustomModel):
class AuthTokenInfoResponse(CustomModel):
reference: str = Field(examples=["c773d0bf-1c42-4c18-aec8-1bdd8cb0a434"])
created: datetime_pd = Field(examples=[1686088809])
client: ClientResponse | None = Field(
description="Information about logged by third-party client"
)
scope: list[str]
expiration: datetime_pd = Field(examples=[1686088809])
used: datetime_pd | None = Field(examples=[1686088809, None])


class AuthTokenInfoPaginationResponse(CustomModel):
list: list[AuthTokenInfoResponse]
pagination: PaginationResponse


class TokenRequestResponse(CustomModel):
Expand Down
54 changes: 52 additions & 2 deletions app/auth/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
from starlette.datastructures import URL

from app.models import User, AuthToken, UserOAuth, AuthTokenRequest, Client
from sqlalchemy import select, func, ScalarResult
from app.utils import hashpwd, new_token, utcnow
from sqlalchemy.ext.asyncio import AsyncSession
from app.service import get_user_by_username
from datetime import timedelta, datetime
from sqlalchemy.orm import selectinload
from .schemas import SignupArgs
from datetime import timedelta
from sqlalchemy import select
from app import constants
import secrets

Expand Down Expand Up @@ -155,6 +155,7 @@ async def create_auth_token(session: AsyncSession, user: User) -> AuthToken:
"expiration": now + timedelta(minutes=30),
"secret": new_token(),
"created": now,
"used": now,
"user": user,
}
)
Expand Down Expand Up @@ -264,3 +265,52 @@ async def create_auth_token_from_request(
await session.commit()

return token


async def count_user_thirdparty_auth_tokens(
session: AsyncSession, user: User, now: datetime
) -> int:
return await session.scalar(
select(func.count(AuthToken.id)).filter(
AuthToken.user_id == user.id,
AuthToken.client_id.is_not(None),
AuthToken.expiration >= now,
)
)


async def list_user_thirdparty_auth_tokens(
session: AsyncSession,
user: User,
offset: int,
limit: int,
now: datetime,
) -> ScalarResult[AuthToken]:
return await session.scalars(
select(AuthToken)
.options(selectinload(AuthToken.client))
.filter(
AuthToken.user_id == user.id,
AuthToken.client_id.is_not(None),
AuthToken.expiration >= now,
)
.offset(offset)
.limit(limit)
)


async def get_auth_token(
session: AsyncSession, reference: str | uuid.UUID
) -> AuthToken:
return await session.scalar(
select(AuthToken)
.filter(AuthToken.id == reference)
.options(selectinload(AuthToken.client), selectinload(AuthToken.user))
)


async def revoke_auth_token(session: AsyncSession, token: AuthToken):
await session.delete(token)
await session.commit()

return token
14 changes: 9 additions & 5 deletions app/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ async def _auth_token_or_abort(
session: AsyncSession = Depends(get_session),
token: str | None = Depends(get_request_auth_token),
) -> Abort | AuthToken:
now = utcnow()

if not token:
return Abort("auth", "missing-token")

Expand All @@ -80,8 +82,13 @@ async def _auth_token_or_abort(
if token.user.banned:
return Abort("auth", "banned")

return token
if now > token.expiration:
return Abort("auth", "token-expired")

token.used = now
await session.commit()

return token

async def auth_token_required(
token: AuthToken | Abort = Depends(_auth_token_or_abort),
Expand All @@ -93,7 +100,7 @@ async def auth_token_required(


async def auth_token_optional(
token: AuthToken | Abort = Depends(_auth_token_or_abort),
token: AuthToken | Abort = Depends(_auth_token_or_abort)
) -> AuthToken | None:
if isinstance(token, Abort):
return None
Expand Down Expand Up @@ -138,9 +145,6 @@ async def auth(

now = utcnow()

if now > token.expiration:
raise Abort("auth", "token-expired")

# Check requested permissions here
if not utils.check_user_permissions(token.user, permissions):
raise Abort("permission", "denied")
Expand Down
1 change: 1 addition & 0 deletions app/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class ErrorResponse(CustomModel):
"token-request-expired": ["Token request has expired", 400],
"activation-invalid": ["Activation token is invalid", 400],
"invalid-token-request": ["Invalid token request", 400],
"not-token-owner": ["User is not token owner", 400],
"oauth-code-required": ["OAuth code required", 400],
"invalid-provider": ["Invalid OAuth provider", 400],
"username-taken": ["Username already taken", 400],
Expand Down
2 changes: 2 additions & 0 deletions app/models/auth/auth_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ class AuthToken(Base):

secret: Mapped[str] = mapped_column(String(64), unique=True, index=True)
expiration: Mapped[datetime]

created: Mapped[datetime]
used: Mapped[datetime] = mapped_column(nullable=True)

user_id = mapped_column(ForeignKey("service_users.id"))

Expand Down
8 changes: 4 additions & 4 deletions tests/auth/test_auth_thirdparty.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

from tests.client_requests import (
request_auth_token_request,
request_auth_token_info,
request_client_create,
request_auth_token,
request_auth_info,
)


Expand Down Expand Up @@ -43,9 +43,9 @@ async def test_auth_thirdparty(client, test_token):

thirdparty_token = response.json()["secret"]

response = await request_auth_info(client, thirdparty_token)
response = await request_auth_token_info(client, thirdparty_token)
assert response.status_code == status.HTTP_200_OK

auth_info = response.json()
token_info = response.json()

assert auth_info["client"]["reference"] == client_reference
assert token_info["client"]["reference"] == client_reference
30 changes: 30 additions & 0 deletions tests/auth/test_list_thirdparty_tokens.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from starlette import status

from tests.client_requests import (
request_list_thirdparty_tokens,
request_auth_token_info,
)


async def test_list_no_thirdparty_tokens(client, test_token):
response = await request_list_thirdparty_tokens(client, test_token)
assert response.status_code == status.HTTP_200_OK

assert response.json()["list"] == []
assert response.json()["pagination"] == {"total": 0, "page": 1, "pages": 0}


async def test_list_thirdparty_tokens(
client, test_token, test_thirdparty_token
):
response = await request_auth_token_info(client, test_thirdparty_token)
assert response.status_code == status.HTTP_200_OK

token_info = response.json()

response = await request_list_thirdparty_tokens(client, test_token)
assert response.status_code == status.HTTP_200_OK

assert response.json()["pagination"] == {"total": 1, "page": 1, "pages": 1}

assert response.json()["list"] == [token_info]
Loading

0 comments on commit 8b2c77d

Please sign in to comment.