Skip to content

Commit

Permalink
Make Auth0 management token refresh automatically
Browse files Browse the repository at this point in the history
  • Loading branch information
Klavionik committed Aug 30, 2024
1 parent 86c8182 commit b8232d0
Showing 1 changed file with 37 additions and 8 deletions.
45 changes: 37 additions & 8 deletions battleship/server/auth.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
from abc import ABC, abstractmethod
from datetime import datetime, timedelta, timezone
from enum import auto
from functools import partial
from random import choice
Expand All @@ -10,7 +11,7 @@
import auth0 # type: ignore[import-untyped]
import jwt
from auth0.authentication import Database, GetToken # type: ignore[import-untyped]
from auth0.management import Auth0 as _Auth0 # type: ignore[import-untyped]
from auth0.management import Auth0 # type: ignore[import-untyped]
from loguru import logger

from battleship.server.config import Config
Expand Down Expand Up @@ -121,6 +122,9 @@ async def assign_role(self, user_id: str, role: UserRole) -> None:


class Auth0API:
TOKEN_REFRESH_LEEWAY = timedelta(seconds=60)
TOKEN_WATCH_INTERVAL = timedelta(seconds=10)

def __init__(self, domain: str, client_id: str, client_secret: str, realm: str, audience: str):
self.domain = domain
self.client_id = client_id
Expand All @@ -138,7 +142,11 @@ def __init__(self, domain: str, client_id: str, client_secret: str, realm: str,
self.client_id,
self.client_secret,
)
self._mgmt: _Auth0 | None = None

token, expires_at = self._fetch_management_token(self.audience)
self.mgmt = Auth0(self.domain, token)
self.mgmt_token_expires_at = expires_at
self._mgmt_token_watcher_task = asyncio.create_task(self._mgmt_token_watcher())

@classmethod
def from_config(cls, config: Config) -> "Auth0API":
Expand All @@ -151,10 +159,13 @@ def from_config(cls, config: Config) -> "Auth0API":
)

@property
def mgmt(self) -> _Auth0:
if self._mgmt is None:
self._mgmt = _Auth0(self.domain, self._fetch_management_token(self.audience))
return self._mgmt
def mgmt_token_expires_at(self) -> datetime:
return self._mgmt_token_expires_at

@mgmt_token_expires_at.setter
def mgmt_token_expires_at(self, expires_at: datetime) -> None:
logger.info("Set new Auth0 management token. Expires at {0}.", expires_at)
self._mgmt_token_expires_at = expires_at

async def add_roles(self, user_id: str, *roles: str) -> JSONPayload:
func = partial(self.mgmt.users.add_roles, id=user_id, roles=roles)
Expand Down Expand Up @@ -190,9 +201,27 @@ async def refresh_token(self, refresh_token: str) -> JSONPayload:
data = await asyncio.to_thread(func)
return cast(JSONPayload, data)

def _fetch_management_token(self, audience: str) -> str:
def _fetch_management_token(self, audience: str) -> tuple[str, datetime]:
data = self.gettoken.client_credentials(audience)
return cast(str, data["access_token"])
token, expires_in = cast(str, data["access_token"]), cast(int, data["expires_in"])
expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in)
return token, expires_at

@logger.catch
async def _mgmt_token_watcher(self) -> None:
watch_interval = self.TOKEN_WATCH_INTERVAL.total_seconds()
logger.info("Run Auth0 management token watcher every {0} seconds.", watch_interval)

while True:
await asyncio.sleep(watch_interval)

now = datetime.now(timezone.utc)

if now > (self._mgmt_token_expires_at - self.TOKEN_REFRESH_LEEWAY):
logger.info("Auth0 management token expires soon. Update it now.")
token, expires_at = self._fetch_management_token(self.audience)
self.mgmt = Auth0(self.domain, token)
self.mgmt_token_expires_at = expires_at


def _make_random_nickname(postfix_length: int = 7) -> str:
Expand Down

0 comments on commit b8232d0

Please sign in to comment.