diff --git a/battleship/server/auth.py b/battleship/server/auth.py index 6a8a2b6..b4c64b9 100644 --- a/battleship/server/auth.py +++ b/battleship/server/auth.py @@ -1,5 +1,6 @@ import asyncio from abc import ABC, abstractmethod +from datetime import datetime, timedelta, timezone from enum import auto from functools import partial from random import choice @@ -10,7 +11,7 @@ import auth0 # type: ignore[import-untyped] import jwt from auth0.authentication import Database, GetToken # type: ignore[import-untyped] -from auth0.management import Auth0 as _Auth0 # type: ignore[import-untyped] +from auth0.management import Auth0 # type: ignore[import-untyped] from loguru import logger from battleship.server.config import Config @@ -121,6 +122,9 @@ async def assign_role(self, user_id: str, role: UserRole) -> None: class Auth0API: + TOKEN_REFRESH_LEEWAY = timedelta(seconds=60) + TOKEN_WATCH_INTERVAL = timedelta(seconds=10) + def __init__(self, domain: str, client_id: str, client_secret: str, realm: str, audience: str): self.domain = domain self.client_id = client_id @@ -138,7 +142,11 @@ def __init__(self, domain: str, client_id: str, client_secret: str, realm: str, self.client_id, self.client_secret, ) - self._mgmt: _Auth0 | None = None + + token, expires_at = self._fetch_management_token(self.audience) + self.mgmt = Auth0(self.domain, token) + self.mgmt_token_expires_at = expires_at + self._mgmt_token_watcher_task = asyncio.create_task(self._mgmt_token_watcher()) @classmethod def from_config(cls, config: Config) -> "Auth0API": @@ -151,10 +159,13 @@ def from_config(cls, config: Config) -> "Auth0API": ) @property - def mgmt(self) -> _Auth0: - if self._mgmt is None: - self._mgmt = _Auth0(self.domain, self._fetch_management_token(self.audience)) - return self._mgmt + def mgmt_token_expires_at(self) -> datetime: + return self._mgmt_token_expires_at + + @mgmt_token_expires_at.setter + def mgmt_token_expires_at(self, expires_at: datetime) -> None: + logger.info("Set new Auth0 management token. Expires at {0}.", expires_at) + self._mgmt_token_expires_at = expires_at async def add_roles(self, user_id: str, *roles: str) -> JSONPayload: func = partial(self.mgmt.users.add_roles, id=user_id, roles=roles) @@ -190,9 +201,27 @@ async def refresh_token(self, refresh_token: str) -> JSONPayload: data = await asyncio.to_thread(func) return cast(JSONPayload, data) - def _fetch_management_token(self, audience: str) -> str: + def _fetch_management_token(self, audience: str) -> tuple[str, datetime]: data = self.gettoken.client_credentials(audience) - return cast(str, data["access_token"]) + token, expires_in = cast(str, data["access_token"]), cast(int, data["expires_in"]) + expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in) + return token, expires_at + + @logger.catch + async def _mgmt_token_watcher(self) -> None: + watch_interval = self.TOKEN_WATCH_INTERVAL.total_seconds() + logger.info("Run Auth0 management token watcher every {0} seconds.", watch_interval) + + while True: + await asyncio.sleep(watch_interval) + + now = datetime.now(timezone.utc) + + if now > (self._mgmt_token_expires_at - self.TOKEN_REFRESH_LEEWAY): + logger.info("Auth0 management token expires soon. Update it now.") + token, expires_at = self._fetch_management_token(self.audience) + self.mgmt = Auth0(self.domain, token) + self.mgmt_token_expires_at = expires_at def _make_random_nickname(postfix_length: int = 7) -> str: