From b71b841ea261bb572b9a905e7fa441ef02abf384 Mon Sep 17 00:00:00 2001 From: Jacob Magnusson Date: Fri, 26 Mar 2021 23:14:36 +0100 Subject: [PATCH] v0.3.0 - new features, breaking API changes and 100% test coverage ### Added - OAuth2 and OIDC can now be enabled by just passing an OIDC discovery URL to `FastAPISecurity.init_oauth2_through_oidc` - Cached data is now used for JWKS and OIDC endpoints in case the "refresh requests" fail. ### Changed - `UserPermission` objects are now created via `FastAPISecurity.user_permission`. - `FastAPISecurity.init` was split into three distinct methods: `.init_basic_auth`, `.init_oauth2_through_oidc` and `.init_oauth2_through_jwks`. - Broke out the `permission_overrides` argument from the old `.init` method and added a distinct method for adding new overrides `add_permission_overrides`. This method can be called multiple times. - The dependency `FastAPISecurity.has_permission` and `FastAPISecurity.user_with_permissions` has been replaced by `FastAPISecurity.user_holding`. API is the same (takes a variable number of UserPermission arguments, i.e. compatible with both). ### Removed - Remove `app` argument to the `FastAPISecurity.init...` methods (it wasn't used before) - The global permissions registry has been removed. Now there should be no global mutable state left. --- CHANGELOG.md | 26 ++ examples/app1/README.md | 8 +- examples/app1/app.py | 32 ++- examples/app1/settings.py | 14 +- fastapi_security/__init__.py | 1 - fastapi_security/api.py | 130 ++++++---- fastapi_security/basic.py | 12 +- fastapi_security/entities.py | 35 +-- fastapi_security/oauth2.py | 58 +++-- fastapi_security/oidc.py | 51 ++-- fastapi_security/permissions.py | 17 +- fastapi_security/registry.py | 15 -- pyproject.toml | 4 +- tests/conftest.py | 3 - tests/examples/__init__.py | 0 tests/examples/conftest.py | 0 tests/examples/helpers.py | 41 +++ tests/examples/test_app1.py | 89 +++++++ tests/helpers/__init__.py | 0 tests/{jwks_helpers.py => helpers/jwks.py} | 38 ++- tests/helpers/oidc.py | 2 + tests/integration/__init__.py | 0 .../test_api.py} | 0 tests/integration/test_basic_auth.py | 67 +++++ tests/{ => integration}/test_oauth2.py | 14 +- tests/integration/test_oidc.py | 77 ++++++ .../integration/test_permission_overrides.py | 47 ++++ tests/integration/test_permissions.py | 99 +++++++ tests/integration/test_user_data.py | 243 ++++++++++++++++++ tests/test_basic_auth.py | 37 --- tests/test_permissions.py | 60 ----- tests/test_user_data.py | 115 --------- tests/unit/__init__.py | 0 tests/unit/test_entities.py | 93 +++++++ tests/unit/test_oauth2.py | 174 +++++++++++++ tests/unit/test_oidc.py | 84 ++++++ 36 files changed, 1280 insertions(+), 406 deletions(-) create mode 100644 CHANGELOG.md delete mode 100644 fastapi_security/registry.py create mode 100644 tests/examples/__init__.py create mode 100644 tests/examples/conftest.py create mode 100644 tests/examples/helpers.py create mode 100644 tests/examples/test_app1.py create mode 100644 tests/helpers/__init__.py rename tests/{jwks_helpers.py => helpers/jwks.py} (67%) create mode 100644 tests/helpers/oidc.py create mode 100644 tests/integration/__init__.py rename tests/{test_configuration.py => integration/test_api.py} (100%) create mode 100644 tests/integration/test_basic_auth.py rename tests/{ => integration}/test_oauth2.py (78%) create mode 100644 tests/integration/test_oidc.py create mode 100644 tests/integration/test_permission_overrides.py create mode 100644 tests/integration/test_permissions.py create mode 100644 tests/integration/test_user_data.py delete mode 100644 tests/test_basic_auth.py delete mode 100644 tests/test_permissions.py delete mode 100644 tests/test_user_data.py create mode 100644 tests/unit/__init__.py create mode 100644 tests/unit/test_entities.py create mode 100644 tests/unit/test_oauth2.py create mode 100644 tests/unit/test_oidc.py diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..6cba53f --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,26 @@ +# Changelog +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [Unreleased] + +- Nothing + +## [0.3.0](https://github.com/jmagnusson/fastapi-security/compare/v0.2.0...v0.3.0) - 2021-03-26 + +### Added + +- OAuth2 and OIDC can now be enabled by just passing an OIDC discovery URL to `FastAPISecurity.init_oauth2_through_oidc` +- Cached data is now used for JWKS and OIDC endpoints in case the "refresh requests" fail. + +### Changed +- `UserPermission` objects are now created via `FastAPISecurity.user_permission`. +- `FastAPISecurity.init` was split into three distinct methods: `.init_basic_auth`, `.init_oauth2_through_oidc` and `.init_oauth2_through_jwks`. +- Broke out the `permission_overrides` argument from the old `.init` method and added a distinct method for adding new overrides `add_permission_overrides`. This method can be called multiple times. +- The dependency `FastAPISecurity.has_permission` and `FastAPISecurity.user_with_permissions` has been replaced by `FastAPISecurity.user_holding`. API is the same (takes a variable number of UserPermission arguments, i.e. compatible with both). + +### Removed +- Remove `app` argument to the `FastAPISecurity.init...` methods (it wasn't used before) +- The global permissions registry has been removed. Now there should be no global mutable state left. diff --git a/examples/app1/README.md b/examples/app1/README.md index 97afb9d..9e98765 100644 --- a/examples/app1/README.md +++ b/examples/app1/README.md @@ -1,12 +1,14 @@ -# FastAPI Security Example App +# FastAPI-Security Example App To try out: ```bash pip install fastapi-security uvicorn +export OIDC_DISCOVERY_URL='https://my-auth0-tenant.eu.auth0.com/.well-known/openid-configuration' +export OAUTH2_AUDIENCES='["my-audience"]' export BASIC_AUTH_CREDENTIALS='[{"username": "user1", "password": "test"}]' -export AUTH_JWKS_URL='https://my-auth0-tenant.eu.auth0.com/.well-known/jwks.json' -export AUTH_AUDIENCES='["my-audience"]' export PERMISSION_OVERRIDES='{"user1": ["products:create"]}' uvicorn app1:app ``` + +You would need to replace the `my-auth0-tenant.eu.auth0.com` part to make it work. diff --git a/examples/app1/app.py b/examples/app1/app.py index 1f3464e..1f604f3 100644 --- a/examples/app1/app.py +++ b/examples/app1/app.py @@ -3,7 +3,7 @@ from fastapi import Depends, FastAPI -from fastapi_security import FastAPISecurity, User, UserPermission +from fastapi_security import FastAPISecurity, User from . import db from .models import Product @@ -15,18 +15,26 @@ security = FastAPISecurity() -security.init( - app, - basic_auth_credentials=settings.basic_auth_credentials, - jwks_url=settings.oauth2_jwks_url, - audiences=settings.oauth2_audiences, - oidc_discovery_url=settings.oidc_discovery_url, - permission_overrides=settings.permission_overrides, -) +if settings.basic_auth_credentials: + security.init_basic_auth(settings.basic_auth_credentials) + +if settings.oidc_discovery_url: + security.init_oauth2_through_oidc( + settings.oidc_discovery_url, + audiences=settings.oauth2_audiences, + ) +elif settings.oauth2_jwks_url: + security.init_oauth2_through_jwks( + settings.oauth2_jwks_url, + audiences=settings.oauth2_audiences, + ) + +security.add_permission_overrides(settings.permission_overrides or {}) + logger = logging.getLogger(__name__) -create_product_perm = UserPermission("products:create") +create_product_perm = security.user_permission("products:create") @app.get("/users/me") @@ -41,10 +49,10 @@ def get_user_permissions(user: User = Depends(security.authenticated_user_or_401 return user.permissions -@app.post("/products", response_model=Product) +@app.post("/products", response_model=Product, status_code=201) async def create_product( product: Product, - user: User = Depends(security.user_with_permissions(create_product_perm)), + user: User = Depends(security.user_holding(create_product_perm)), ): """Create product diff --git a/examples/app1/settings.py b/examples/app1/settings.py index eea31a1..6f4521e 100644 --- a/examples/app1/settings.py +++ b/examples/app1/settings.py @@ -1,20 +1,20 @@ from functools import lru_cache -from typing import Dict, List, Optional +from typing import List, Optional -from fastapi.security import HTTPBasicCredentials from pydantic import BaseSettings +from fastapi_security import HTTPBasicCredentials, PermissionOverrides + __all__ = ("get_settings",) class _Settings(BaseSettings): - oauth2_jwks_url: Optional[ - str - ] = None # TODO: This could be retrieved from OIDC discovery URL + # NOTE: You only need to supply `oidc_discovery_url` (preferred) OR `oauth2_jwks_url` + oidc_discovery_url: Optional[str] = None + oauth2_jwks_url: Optional[str] = None oauth2_audiences: Optional[List[str]] = None basic_auth_credentials: Optional[List[HTTPBasicCredentials]] = None - oidc_discovery_url: Optional[str] = None - permission_overrides: Optional[Dict[str, List[str]]] = None + permission_overrides: PermissionOverrides = {} @lru_cache() diff --git a/fastapi_security/__init__.py b/fastapi_security/__init__.py index 1947b91..6b55775 100644 --- a/fastapi_security/__init__.py +++ b/fastapi_security/__init__.py @@ -4,5 +4,4 @@ from .oauth2 import * # noqa from .oidc import * # noqa from .permissions import * # noqa -from .registry import * # noqa from .schemes import * # noqa diff --git a/fastapi_security/api.py b/fastapi_security/api.py index 4c2dcfa..6f8f546 100644 --- a/fastapi_security/api.py +++ b/fastapi_security/api.py @@ -1,17 +1,15 @@ import logging -from typing import Callable, Dict, List, Optional +from typing import Callable, Iterable, List, Optional, Type -from fastapi import Depends, FastAPI, HTTPException -from fastapi.security import HTTPBasicCredentials +from fastapi import Depends, HTTPException from fastapi.security.http import HTTPAuthorizationCredentials -from . import registry -from .basic import BasicAuthValidator +from .basic import BasicAuthValidator, IterableOfHTTPBasicCredentials from .entities import AuthMethod, User, UserAuth, UserInfo from .exceptions import AuthNotConfigured from .oauth2 import Oauth2JwtAccessTokenValidator from .oidc import OpenIdConnectDiscovery -from .permissions import UserPermission +from .permissions import PermissionOverrides, UserPermission from .schemes import http_basic_scheme, jwt_bearer_scheme logger = logging.getLogger(__name__) @@ -25,34 +23,60 @@ class FastAPISecurity: Must be initialized after object creation via the `init()` method. """ - def __init__(self): + def __init__(self, *, user_permission_class: Type[UserPermission] = UserPermission): self.basic_auth = BasicAuthValidator() self.oauth2_jwt = Oauth2JwtAccessTokenValidator() self.oidc_discovery = OpenIdConnectDiscovery() - self._permission_overrides = None - - def init( - self, - app: FastAPI, - basic_auth_credentials: List[HTTPBasicCredentials] = None, - permission_overrides: Dict[str, List[str]] = None, - jwks_url: str = None, - audiences: List[str] = None, - oidc_discovery_url: str = None, + self._permission_overrides: PermissionOverrides = {} + self._user_permission_class = user_permission_class + self._all_permissions: List[UserPermission] = [] + self._oauth2_init_through_oidc = False + self._oauth2_audiences: List[str] = [] + + def init_basic_auth(self, basic_auth_credentials: IterableOfHTTPBasicCredentials): + self.basic_auth.init(basic_auth_credentials) + + def init_oauth2_through_oidc( + self, oidc_discovery_url: str, *, audiences: Iterable[str] = None ): - self._permission_overrides = permission_overrides + """Initialize OIDC and OAuth2 authentication/authorization - if basic_auth_credentials: - # Initialize basic auth (superusers with all permissions) - self.basic_auth.init(basic_auth_credentials) + OAuth2 JWKS URL is lazily fetched from the OIDC endpoint once it's needed for the first time. + + This method is preferred over `init_oauth2_through_jwks` as you get all the + benefits of OIDC, with less configuration supplied. + """ + self._oauth2_audiences.extend(audiences or []) + self.oidc_discovery.init(oidc_discovery_url) - if jwks_url: - # # Initialize OAuth 2.0 - user permissions are required for all flows - # # except Client Credentials - self.oauth2_jwt.init(jwks_url, audiences=audiences or []) + def init_oauth2_through_jwks( + self, jwks_uri: str, *, audiences: Iterable[str] = None + ): + """Initialize OAuth2 + + It's recommended to use `init_oauth2_through_oidc` instead. + """ + self._oauth2_audiences.extend(audiences or []) + self.oauth2_jwt.init(jwks_uri, audiences=self._oauth2_audiences) - if oidc_discovery_url and self.oauth2_jwt.is_configured(): - self.oidc_discovery.init(oidc_discovery_url) + def add_permission_overrides(self, overrides: PermissionOverrides): + """Add wildcard or specific permissions to basic auth and/or OAuth2 users + + Example: + security = FastAPISecurity() + create_product = security.user_permission("products:create") + + # Give all permissions to the user johndoe + security.add_permission_overrides({"johndoe": "*"}) + + # Give the OAuth2 user `7ZmI5ycgNHeZ9fHPZZwTNbIRd9Ectxca@clients` the + # "products:create" permission. + security.add_permission_overrides({ + "7ZmI5ycgNHeZ9fHPZZwTNbIRd9Ectxca@clients": ["products:create"], + }) + + """ + self._permission_overrides.update(overrides) @property def user(self) -> Callable: @@ -79,7 +103,7 @@ def user_with_info(self) -> Callable: """Dependency that returns User object with user info, authenticated or not""" async def dependency(user_auth: UserAuth = Depends(self._user_auth)): - if user_auth.is_oauth2(): + if user_auth.is_oauth2() and user_auth.access_token: info = await self.oidc_discovery.get_user_info(user_auth.access_token) else: info = UserInfo.make_dummy() @@ -94,7 +118,7 @@ def authenticated_user_with_info_or_401(self) -> Callable: """ async def dependency(user_auth: UserAuth = Depends(self._user_auth_or_401)): - if user_auth.is_oauth2(): + if user_auth.is_oauth2() and user_auth.access_token: info = await self.oidc_discovery.get_user_info(user_auth.access_token) else: info = UserInfo.make_dummy() @@ -102,18 +126,12 @@ async def dependency(user_auth: UserAuth = Depends(self._user_auth_or_401)): return dependency - def has_permission(self, permission: UserPermission) -> Callable: - """Dependency that raises HTTP403 if the user is missing the given permission""" + def user_permission(self, identifier: str) -> UserPermission: + perm = self._user_permission_class(identifier) + self._all_permissions.append(perm) + return perm - async def dependency( - user: User = Depends(self.authenticated_user_or_401), - ) -> User: - self._has_permission_or_raise_forbidden(user, permission) - return user - - return dependency - - def user_with_permissions(self, *permissions: UserPermission) -> Callable: + def user_holding(self, *permissions: UserPermission) -> Callable: """Dependency that returns the user if it has the given permissions, otherwise raises HTTP403 """ @@ -137,12 +155,17 @@ async def dependency( ), http_credentials: HTTPAuthorizationCredentials = Depends(http_basic_scheme), ) -> Optional[UserAuth]: - if not any( - [self.oauth2_jwt.is_configured(), self.basic_auth.is_configured()] - ): + oidc_configured = self.oidc_discovery.is_configured() + oauth2_configured = self.oauth2_jwt.is_configured() + basic_auth_configured = self.basic_auth.is_configured() + if not any([oidc_configured, oauth2_configured, basic_auth_configured]): raise AuthNotConfigured() + if oidc_configured and not oauth2_configured: + jwks_uri = await self.oidc_discovery.get_jwks_uri() + self.init_oauth2_through_jwks(jwks_uri) + if bearer_credentials is not None: bearer_token = bearer_credentials.credentials access_token = await self.oauth2_jwt.parse(bearer_token) @@ -199,16 +222,21 @@ def _raise_forbidden(self, required_permission: str): ) def _maybe_override_permissions(self, user_auth: UserAuth) -> UserAuth: - overrides = (self._permission_overrides or {}).get(user_auth.subject) - - if overrides is None: - return user_auth + overrides = self._permission_overrides.get(user_auth.subject) - all_permissions = registry.get_all_permissions() + all_permission_identifiers = [p.identifier for p in self._all_permissions] - if "*" in overrides: - return user_auth.with_permissions(all_permissions) + if overrides is None: + return user_auth.with_permissions( + [ + incoming_id + for incoming_id in user_auth.permissions + if incoming_id in all_permission_identifiers + ] + ) + elif "*" in overrides: + return user_auth.with_permissions(all_permission_identifiers) else: return user_auth.with_permissions( - [p for p in overrides if p in all_permissions] + [p for p in overrides if p in all_permission_identifiers] ) diff --git a/fastapi_security/basic.py b/fastapi_security/basic.py index ae42dbb..b4e8f32 100644 --- a/fastapi_security/basic.py +++ b/fastapi_security/basic.py @@ -1,18 +1,18 @@ import secrets -from typing import Dict, List, Union +from typing import Dict, Iterable, List, Union from fastapi.security.http import HTTPBasicCredentials -__all__ = () +__all__ = ("HTTPBasicCredentials",) -ListOfCredentials = List[Union[HTTPBasicCredentials, Dict]] +IterableOfHTTPBasicCredentials = Iterable[Union[HTTPBasicCredentials, Dict]] class BasicAuthValidator: def __init__(self): self._credentials = [] - def init(self, credentials: ListOfCredentials): + def init(self, credentials: IterableOfHTTPBasicCredentials): self._credentials = self._make_credentials(credentials) def is_configured(self) -> bool: @@ -29,7 +29,9 @@ def validate(self, credentials: HTTPBasicCredentials) -> bool: for c in self._credentials ) - def _make_credentials(self, credentials: ListOfCredentials): + def _make_credentials( + self, credentials: IterableOfHTTPBasicCredentials + ) -> List[HTTPBasicCredentials]: return [ c if isinstance(c, HTTPBasicCredentials) else HTTPBasicCredentials(**c) for c in credentials diff --git a/fastapi_security/entities.py b/fastapi_security/entities.py index 9b99f49..a253df1 100644 --- a/fastapi_security/entities.py +++ b/fastapi_security/entities.py @@ -2,9 +2,7 @@ from enum import Enum from typing import Any, Dict, List, Optional -from pydantic import BaseModel, Field, root_validator, validator - -from . import registry +from pydantic import BaseModel, Field, validator __all__ = ("User",) @@ -29,9 +27,6 @@ class JwtAccessToken(BaseModel): description="Permissions (auth0 specific, intended for first-party app authorization)", ) raw: str = Field(..., description="The raw access token") - _extra: Dict[str, Any] = Field( - {}, description="Any extra fields that were provided in the access token" - ) @validator("aud", pre=True, always=True) def aud_to_list(cls, v): @@ -51,21 +46,6 @@ def permissions_to_list(cls, v): def is_client_credentials(self): return self.gty == "client-credentials" - @root_validator(pre=True) - def set_extra_field(cls, values: Dict[str, Any]) -> Dict[str, Any]: - """Ensure that any additional passed in data is set on the `extra` field""" - extra: Dict[str, Any] = {} - new_values = {"_extra": extra} - model_keys = cls.__fields__.keys() - - for k, v in values.items(): - if k in model_keys: - new_values[k] = v - else: - extra[k] = v - - return new_values - class AuthMethod(str, Enum): none = "none" @@ -107,15 +87,6 @@ class UserAuth(BaseModel): scopes: List[str] = [] permissions: List[str] = [] access_token: Optional[str] = None - _extra: Dict[str, Any] = {} - - @validator("permissions", pre=True, always=True) - def only_add_valid_permissions(cls, v, values): - if v: - all_permissions = registry.get_all_permissions() - return [e for e in v if e in all_permissions] - else: - return v def is_authenticated(self) -> bool: return self.auth_method is not AuthMethod.none @@ -151,7 +122,6 @@ def from_jwt_access_token(cls, access_token: JwtAccessToken) -> "UserAuth": scopes=access_token.scope, permissions=access_token.permissions, access_token=access_token.raw, - _extra=access_token._extra, ) @classmethod @@ -181,6 +151,3 @@ def has_permission(self, permission: str) -> bool: def without_access_token(self) -> "User": return self.copy(deep=True, exclude={"auth": {"access_token"}}) - - def without_extra(self) -> "User": - return self.copy(deep=True, exclude={"auth": {"extra"}}) diff --git a/fastapi_security/oauth2.py b/fastapi_security/oauth2.py index 79409c4..dc28c1e 100644 --- a/fastapi_security/oauth2.py +++ b/fastapi_security/oauth2.py @@ -1,7 +1,7 @@ import json import logging -from datetime import datetime, timedelta -from typing import Any, Dict, Iterable, Optional +import time +from typing import Any, Dict, Iterable, List, Optional import aiohttp import jwt @@ -34,11 +34,11 @@ class Oauth2JwtAccessTokenValidator: """ def __init__(self): - self._jwks_url = None - self._audiences = None - self._jwks_kid_mapping = None - self._jwks_cache_period = DEFAULT_JWKS_RESPONSE_CACHE_PERIOD - self._jwks_cached_at = None + self._jwks_url: Optional[str] = None + self._audiences: Optional[List[str]] = None + self._jwks_kid_mapping: Dict[str, _RSAPublicKey] = None + self._jwks_cache_period: float = float(DEFAULT_JWKS_RESPONSE_CACHE_PERIOD) + self._jwks_cached_at: Optional[float] = None def init( self, @@ -60,7 +60,7 @@ def init( """ self._jwks_url = jwks_url - self._jwks_cache_period = jwks_cache_period + self._jwks_cache_period = float(jwks_cache_period) self._audiences = list(audiences) def is_configured(self) -> bool: @@ -95,7 +95,7 @@ async def parse(self, access_token: str) -> Optional[JwtAccessToken]: try: public_key = await self._get_public_key(token_kid) except KeyError: - logger.debug("No matching kid for JWT token") + logger.debug("No matching `kid` for JWT token") return None try: @@ -107,7 +107,7 @@ async def parse(self, access_token: str) -> Optional[JwtAccessToken]: try: parsed_access_token = JwtAccessToken(**decoded, raw=access_token) except ValidationError as ex: - logger.debug(f"Failed to parse JWT token with validation error: {ex!r}") + logger.debug(f"Failed to parse JWT token with {ex}") return None return parsed_access_token @@ -117,18 +117,29 @@ async def _get_public_key(self, kid: str) -> _RSAPublicKey: return mapping[kid] async def _get_jwks_kid_mapping(self) -> Dict[str, _RSAPublicKey]: - if self._jwks_cached_at is None or ( - (datetime.utcnow() - self._jwks_cached_at) - > timedelta(seconds=self._jwks_cache_period) + if ( + self._jwks_cached_at is None + or (time.monotonic() - self._jwks_cached_at) > self._jwks_cache_period ): - jwks_data = await self._fetch_jwks_data() - self._jwks_kid_mapping = { - k["kid"]: RSAAlgorithm.from_jwk(json.dumps(k)) - for k in jwks_data["keys"] - if k["kty"] == "RSA" and k["alg"] == "RS256" - } - self._jwks_cached_at = datetime.utcnow() - assert len(self._jwks_kid_mapping) > 0 + try: + jwks_data = await self._fetch_jwks_data() + except Exception as ex: + if self._jwks_kid_mapping is None: + raise + else: + logger.info( + f"Failed to refresh JWKS kid mapping, re-using old data. " + f"Exception was: {ex!r}" + ) + self._jwks_cached_at = time.monotonic() + else: + self._jwks_kid_mapping = { + k["kid"]: RSAAlgorithm.from_jwk(json.dumps(k)) + for k in jwks_data["keys"] + if k["kty"] == "RSA" and k["alg"] == "RS256" + } + self._jwks_cached_at = time.monotonic() + assert len(self._jwks_kid_mapping) > 0 return self._jwks_kid_mapping @@ -139,10 +150,7 @@ async def _fetch_jwks_data(self): async with aiohttp.ClientSession(timeout=timeout) as session: async with session.get(self._jwks_url) as response: - try: - return await response.json() - except Exception as ex: - raise RuntimeError(f"Failed to load JWKS data with exception: {ex}") + return await response.json() def _decode_jwt_token( self, public_key: _RSAPublicKey, access_token: str diff --git a/fastapi_security/oidc.py b/fastapi_security/oidc.py index 24b9c7e..b368cbe 100644 --- a/fastapi_security/oidc.py +++ b/fastapi_security/oidc.py @@ -1,5 +1,5 @@ import logging -from datetime import datetime, timedelta +import time from typing import Any, Dict, Optional import aiohttp @@ -16,8 +16,12 @@ class OpenIdConnectDiscovery: """Retrieve info from OpenID Connect (OIDC) endpoints""" def __init__(self): - self._discovery_url = None - self._discovery_cache_period = DEFAULT_DISCOVERY_RESPONSE_CACHE_PERIOD + self._discovery_url: Optional[str] = None + self._discovery_data_cached_at: Optional[float] = None + self._discovery_cache_period: float = float( + DEFAULT_DISCOVERY_RESPONSE_CACHE_PERIOD + ) + self._discovery_data: Optional[Dict[str, Any]] = None def init( self, @@ -35,9 +39,7 @@ def init( How many seconds to cache the OpenID Discovery endpoint response. Defaults to 1 hour. """ self._discovery_url = discovery_url - self._discovery_cache_period = discovery_cache_period - self._discovery_data_cached_at: Optional[datetime] = None - self._discovery_data: Optional[Dict[str, Any]] = None + self._discovery_cache_period = float(discovery_cache_period) def is_configured(self) -> bool: return bool(self._discovery_url) @@ -63,9 +65,14 @@ async def get_user_info(self, access_token: str) -> Optional[UserInfo]: else: return UserInfo.from_oidc_endpoint(user_info) + async def get_jwks_uri(self) -> str: + """Get or fetch the JWKS URI""" + data = await self.get_discovery_data() + return data["jwks_uri"] + async def _fetch_user_info(self, access_token: str) -> Optional[Dict[str, Any]]: timeout = aiohttp.ClientTimeout(total=10) - url = await self._get_user_info_endpoint() + url = await self.get_user_info_endpoint() headers = {"Authorization": f"Bearer {access_token}"} logger.debug(f"Fetching user info from {url}") @@ -80,29 +87,41 @@ async def _fetch_user_info(self, access_token: str) -> Optional[Dict[str, Any]]: ) return None - async def _get_user_info_endpoint(self) -> str: - data = await self._get_discovery_data() + async def get_user_info_endpoint(self) -> str: + data = await self.get_discovery_data() return data["userinfo_endpoint"] - async def _get_discovery_data(self) -> Dict[str, Any]: + async def get_discovery_data(self) -> Dict[str, Any]: if ( self._discovery_data is None or self._discovery_data_cached_at is None or ( - (datetime.utcnow() - self._discovery_data_cached_at) - > timedelta(seconds=self._discovery_cache_period) + (time.monotonic() - self._discovery_data_cached_at) + > self._discovery_cache_period ) ): - self._discovery_data = await self._fetch_discovery_data() - self._discovery_data_cached_at = datetime.utcnow() + try: + self._discovery_data = await self._fetch_discovery_data() + except Exception as ex: + if self._discovery_data is None: + raise + else: + logger.info( + f"Failed to refresh OIDC discovery data, re-using old data. " + f"Exception was: {ex!r}" + ) + self._discovery_data_cached_at = time.monotonic() + else: + self._discovery_data_cached_at = time.monotonic() return self._discovery_data async def _fetch_discovery_data(self) -> Dict[str, Any]: timeout = aiohttp.ClientTimeout(total=10) + assert self._discovery_url, "No OIDC discovery URL specified" logger.debug(f"Fetching OIDC discovery data from {self._discovery_url}") - async with aiohttp.ClientSession(timeout=timeout) as session: - async with session.get(self._discovery_url) as response: + async with aiohttp.ClientSession(timeout=timeout, raise_for_status=True) as s: + async with s.get(self._discovery_url) as response: return await response.json() diff --git a/fastapi_security/permissions.py b/fastapi_security/permissions.py index 297987b..24e03c7 100644 --- a/fastapi_security/permissions.py +++ b/fastapi_security/permissions.py @@ -1,6 +1,9 @@ -from . import registry +from typing import Dict, Iterable, Union -__all__ = ("UserPermission",) +__all__ = ("PermissionOverrides",) + + +PermissionOverrides = Dict[str, Union[str, Iterable[str]]] class UserPermission: @@ -8,13 +11,14 @@ class UserPermission: Creating a new permission is done like this: - create_item_permission = UserPermission("item:create") + security = FastAPISecurity() + create_item_permission = security.user_permission("item:create") Usage: @app.post( "/products", - dependencies=[Depends(security.has_permission(create_item_permission))] + dependencies=[Depends(security.user_holding(create_item_permission))] ) def create_product(...): ... @@ -22,7 +26,7 @@ def create_product(...): Or: @app.post("/products") def create_product( - user: Depends(security.user_with_permissions(create_item_permission)) + user: Depends(security.user_holding(create_item_permission)) ): ... @@ -30,10 +34,9 @@ def create_product( def __init__(self, identifier: str): self.identifier = identifier - registry.add_permission(identifier) def __str__(self): return self.identifier def __repr__(self): - return f"" + return f"{self.__class__.__name__}({self.identifier})" diff --git a/fastapi_security/registry.py b/fastapi_security/registry.py deleted file mode 100644 index 751e719..0000000 --- a/fastapi_security/registry.py +++ /dev/null @@ -1,15 +0,0 @@ -from typing import List - -__all__ = () - - -_permissions_registry: List[str] = [] - - -def add_permission(permission: str): - if permission not in _permissions_registry: - _permissions_registry.append(permission) - - -def get_all_permissions() -> List[str]: - return _permissions_registry.copy() diff --git a/pyproject.toml b/pyproject.toml index 46aa834..aaa033b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "fastapi-security" -version = "0.2.0" +version = "0.3.0" description = "Add authentication and authorization to your FastAPI app via dependencies." authors = ["Jacob Magnusson "] license = "MIT" @@ -49,6 +49,8 @@ requests = "^2.25.1" mypy = "^0.812" flake8 = "^3.9.0" mkdocs-material = "^7.0.6" +pytest-asyncio = "^0.14.0" +uvicorn = "^0.13.4" [build-system] requires = ["poetry-core>=1.0.0"] diff --git a/tests/conftest.py b/tests/conftest.py index 7f36ce2..dce20ba 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,12 +2,9 @@ from fastapi import FastAPI from starlette.testclient import TestClient -from fastapi_security.registry import _permissions_registry - @pytest.fixture def app(): - _permissions_registry.clear() # TODO: Make permissions context local! return FastAPI() diff --git a/tests/examples/__init__.py b/tests/examples/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/examples/conftest.py b/tests/examples/conftest.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/examples/helpers.py b/tests/examples/helpers.py new file mode 100644 index 0000000..638503a --- /dev/null +++ b/tests/examples/helpers.py @@ -0,0 +1,41 @@ +import os +import socket +import subprocess +from contextlib import contextmanager +from typing import Dict + + +def available_port(ip: str = "127.0.0.1"): + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + s.bind((ip, 0)) + port = s.getsockname()[1] + s.close() + return port + + +@contextmanager +def run_example_app(app_path: str, *, env: Dict[str, str] = None): + port = available_port() + proc = subprocess.Popen( + ["uvicorn", app_path, f"--port={port}"], + stderr=subprocess.PIPE, + env={**os.environ, **(env or {})}, + ) + + while True: + if proc.stderr: + line = proc.stderr.readline() + if b"Uvicorn running on" in line: + break + elif b"Traceback" in line: + lines = proc.stderr.read() + proc.terminate() + raise RuntimeError(lines.decode()) + else: + break + + try: + yield f"http://127.0.0.1:{port}" + finally: + proc.terminate() diff --git a/tests/examples/test_app1.py b/tests/examples/test_app1.py new file mode 100644 index 0000000..46a737c --- /dev/null +++ b/tests/examples/test_app1.py @@ -0,0 +1,89 @@ +from pathlib import Path + +import pytest +import requests + +from .helpers import run_example_app + +try: + import uvicorn +except ImportError: + uvicorn = None + +app1_path = Path("./examples/app1") + +pytestmark = [ + pytest.mark.skipif(not app1_path.exists(), reason="app1 example couldn't be found"), + pytest.mark.skipif(uvicorn is None, reason="`uvicorn` isn't installed"), +] + + +basic_auth_env = { + "BASIC_AUTH_CREDENTIALS": '[{"username": "user1", "password": "test"}]' +} + + +def test_users_me_basic_auth_anonymous(): + with run_example_app("examples.app1:app", env=basic_auth_env) as base_url: + resp = requests.get(f"{base_url}/users/me") + assert resp.status_code == 200 + data = resp.json() + assert data["auth"] == { + "subject": "anonymous", + "auth_method": "none", + "issuer": None, + "audience": [], + "issued_at": None, + "expires_at": None, + "scopes": [], + "permissions": [], + } + + +def test_users_me_basic_auth_authenticated(): + with run_example_app("examples.app1:app", env=basic_auth_env) as base_url: + resp = requests.get(f"{base_url}/users/me", auth=("user1", "test")) + assert resp.status_code == 200 + data = resp.json() + assert data["auth"] == { + "subject": "user1", + "auth_method": "basic_auth", + "issuer": None, + "audience": [], + "issued_at": None, + "expires_at": None, + "scopes": [], + "permissions": [], + } + + +def test_user_permissions_basic_auth_authenticated(): + with run_example_app( + "examples.app1:app", + env={**basic_auth_env, "PERMISSION_OVERRIDES": '{"user1": ["*"]}'}, + ) as base_url: + resp = requests.get(f"{base_url}/users/me/permissions", auth=("user1", "test")) + assert resp.status_code == 200 + data = resp.json() + assert data == ["products:create"] + + +def test_create_product_unauthenticated(): + with run_example_app("examples.app1:app", env=basic_auth_env) as base_url: + resp = requests.post(f"{base_url}/products") + assert resp.status_code == 401 + data = resp.json() + assert data == {"detail": "Could not validate credentials"} + + +def test_create_product_authenticated(): + with run_example_app( + "examples.app1:app", + env={**basic_auth_env, "PERMISSION_OVERRIDES": '{"user1": ["*"]}'}, + ) as base_url: + resp = requests.post( + f"{base_url}/products", auth=("user1", "test"), json={"name": "T-shirt"} + ) + assert resp.status_code == 201 + data = resp.json() + assert data == {"name": "T-shirt"} diff --git a/tests/helpers/__init__.py b/tests/helpers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/jwks_helpers.py b/tests/helpers/jwks.py similarity index 67% rename from tests/jwks_helpers.py rename to tests/helpers/jwks.py index fcd2936..ebaaa9c 100644 --- a/tests/jwks_helpers.py +++ b/tests/helpers/jwks.py @@ -1,6 +1,6 @@ import json from datetime import datetime, timedelta, timezone -from typing import Any, Dict +from typing import Any, Dict, Iterable import jwt from cryptography.hazmat.backends import default_backend as crypto_default_backend @@ -26,7 +26,7 @@ dummy_alg = "RS256" dummy_kid = "test123" -dummy_jwks_url = "https://identity-provider/.well-known/jwks.json" +dummy_jwks_uri = "https://identity-provider/.well-known/jwks.json" dummy_audience = "https://some-resource" dummy_jwks_response_data = { "keys": [ @@ -40,12 +40,16 @@ }, ], } -jwt_headers = {"alg": dummy_alg, "typ": "JWT", "kid": dummy_kid} +dummy_jwt_headers = {"alg": dummy_alg, "typ": "JWT", "kid": dummy_kid} -def make_access_token( - *, sub: str, expire_in: int = 3600, **extra: Dict[str, Any] -) -> str: +def make_access_token_data( + *, + sub: str, + expire_in: int = 3600, + delete_fields: Iterable[str] = None, + **extra: Dict[str, Any], +) -> Dict[str, Any]: utcnow = datetime.now(tz=timezone.utc) expire_at = utcnow + timedelta(seconds=expire_in) @@ -56,9 +60,29 @@ def make_access_token( data.setdefault("iss", "https://identity-provider/") data.setdefault("iat", int(utcnow.timestamp())) data.setdefault("exp", int(expire_at.timestamp())) + + for field in delete_fields or []: + if field in data: + del data[field] + + return data + + +def make_access_token( + *, + sub: str, + expire_in: int = 3600, + delete_fields: Iterable[str] = None, + headers: Dict[str, str] = None, + **extra: Dict[str, Any], +) -> str: + data = make_access_token_data( + sub=sub, expire_in=expire_in, delete_fields=delete_fields, **extra + ) + headers = headers or dummy_jwt_headers return jwt.encode( data, private_key.decode(), algorithm=dummy_alg, - headers=jwt_headers, + headers=headers, ) diff --git a/tests/helpers/oidc.py b/tests/helpers/oidc.py new file mode 100644 index 0000000..f248d26 --- /dev/null +++ b/tests/helpers/oidc.py @@ -0,0 +1,2 @@ +dummy_oidc_url = "https://oidc-provider/.well-known/openid-configuration" +dummy_userinfo_endpoint_url = "https://oidc-provider/userinfo" diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_configuration.py b/tests/integration/test_api.py similarity index 100% rename from tests/test_configuration.py rename to tests/integration/test_api.py diff --git a/tests/integration/test_basic_auth.py b/tests/integration/test_basic_auth.py new file mode 100644 index 0000000..1b69c7a --- /dev/null +++ b/tests/integration/test_basic_auth.py @@ -0,0 +1,67 @@ +from fastapi import Depends + +from fastapi_security import FastAPISecurity, HTTPBasicCredentials, User +from fastapi_security.basic import BasicAuthValidator + +from ..helpers.jwks import dummy_audience, dummy_jwks_uri + + +def test_that_basic_auth_doesnt_validate_any_credentials_if_unconfigured(): + validator = BasicAuthValidator() + creds = HTTPBasicCredentials(username="johndoe", password="123") + assert validator.validate(creds) is False + + +def test_that_uninitialized_basic_auth_doesnt_accept_any_credentials(app, client): + security = FastAPISecurity() + + @app.get("/") + def get_products(user: User = Depends(security.authenticated_user_or_401)): + return [] + + # NOTE: Not passing basic_auth_credentials, which means Basic Auth will be disabled + # NOTE: We are passing + security.init_oauth2_through_jwks(dummy_jwks_uri, audiences=[dummy_audience]) + + resp = client.get("/") + assert resp.status_code == 401 + + resp = client.get("/", auth=("username", "password")) + assert resp.status_code == 401 + + +def test_that_basic_auth_rejects_incorrect_credentials(app, client): + security = FastAPISecurity() + + @app.get("/") + def get_products(user: User = Depends(security.authenticated_user_or_401)): + return [] + + credentials = [{"username": "user", "password": "pass"}] + security.init_basic_auth(credentials) + + resp = client.get("/") + assert resp.status_code == 401 + + resp = client.get("/", auth=("user", "")) + assert resp.status_code == 401 + + resp = client.get("/", auth=("", "pass")) + assert resp.status_code == 401 + + resp = client.get("/", auth=("abc", "123")) + assert resp.status_code == 401 + + +def test_that_basic_auth_accepts_correct_credentials(app, client): + security = FastAPISecurity() + + @app.get("/") + def get_products(user: User = Depends(security.authenticated_user_or_401)): + return [] + + credentials = [{"username": "user", "password": "pass"}] + security.init_basic_auth(credentials) + + resp = client.get("/", auth=("user", "pass")) + assert resp.status_code == 200 diff --git a/tests/test_oauth2.py b/tests/integration/test_oauth2.py similarity index 78% rename from tests/test_oauth2.py rename to tests/integration/test_oauth2.py index d69b3c0..177e1bb 100644 --- a/tests/test_oauth2.py +++ b/tests/integration/test_oauth2.py @@ -3,10 +3,10 @@ from fastapi_security import FastAPISecurity, User -from .jwks_helpers import ( +from ..helpers.jwks import ( dummy_audience, dummy_jwks_response_data, - dummy_jwks_url, + dummy_jwks_uri, make_access_token, ) @@ -19,7 +19,7 @@ def test_that_oauth2_rejects_incorrect_token(app, client): def get_products(user: User = Depends(security.authenticated_user_or_401)): return [] - security.init(app, jwks_url=dummy_jwks_url, audiences=[dummy_audience]) + security.init_oauth2_through_jwks(dummy_jwks_uri, audiences=[dummy_audience]) resp = client.get("/") assert resp.status_code == 401 @@ -39,12 +39,12 @@ def test_that_oauth2_accepts_correct_token(app, client): def get_products(user: User = Depends(security.authenticated_user_or_401)): return [] - security.init(app, jwks_url=dummy_jwks_url, audiences=[dummy_audience]) + security.init_oauth2_through_jwks(dummy_jwks_uri, audiences=[dummy_audience]) access_token = make_access_token(sub="test-subject") with aioresponses() as mock: - mock.get(dummy_jwks_url, payload=dummy_jwks_response_data) + mock.get(dummy_jwks_uri, payload=dummy_jwks_response_data) resp = client.get("/", headers={"Authorization": f"Bearer {access_token}"}) @@ -59,12 +59,12 @@ def test_that_oauth2_rejects_expired_token(app, client): def get_products(user: User = Depends(security.authenticated_user_or_401)): return [] - security.init(app, jwks_url=dummy_jwks_url, audiences=[dummy_audience]) + security.init_oauth2_through_jwks(dummy_jwks_uri, audiences=[dummy_audience]) access_token = make_access_token(sub="test-subject", expire_in=-1) with aioresponses() as mock: - mock.get(dummy_jwks_url, payload=dummy_jwks_response_data) + mock.get(dummy_jwks_uri, payload=dummy_jwks_response_data) resp = client.get("/", headers={"Authorization": f"Bearer {access_token}"}) diff --git a/tests/integration/test_oidc.py b/tests/integration/test_oidc.py new file mode 100644 index 0000000..4526814 --- /dev/null +++ b/tests/integration/test_oidc.py @@ -0,0 +1,77 @@ +from aioresponses import aioresponses +from fastapi import Depends + +from fastapi_security import FastAPISecurity, User + +from ..helpers.jwks import ( + dummy_audience, + dummy_jwks_response_data, + dummy_jwks_uri, + make_access_token, +) +from ..helpers.oidc import dummy_oidc_url, dummy_userinfo_endpoint_url + + +def test_that_auth_can_be_enabled_through_oidc(app, client): + + security = FastAPISecurity() + + @app.get("/") + def get_products(user: User = Depends(security.authenticated_user_or_401)): + return [] + + security.init_oauth2_through_oidc(dummy_oidc_url, audiences=[dummy_audience]) + + access_token = make_access_token(sub="test-subject") + + with aioresponses() as mock: + mock.get( + dummy_oidc_url, + payload={ + "userinfo_endpoint": dummy_userinfo_endpoint_url, + "jwks_uri": dummy_jwks_uri, + }, + ) + mock.get(dummy_jwks_uri, payload=dummy_jwks_response_data) + mock.get(dummy_userinfo_endpoint_url, payload={"nickname": "jacobsvante"}) + + unauthenticated_resp = client.get("/") + assert unauthenticated_resp.status_code == 401 + + authenticated_resp = client.get( + "/", headers={"Authorization": f"Bearer {access_token}"} + ) + assert authenticated_resp.status_code == 200 + + +def test_that_oidc_info_is_returned(app, client): + + security = FastAPISecurity() + + @app.get("/users/me") + async def get_user_details(user: User = Depends(security.user_with_info)): + """Return user details, regardless of whether user is authenticated or not""" + return user.without_access_token() + + security.init_oauth2_through_oidc(dummy_oidc_url, audiences=[dummy_audience]) + + access_token = make_access_token(sub="test-subject") + + with aioresponses() as mock: + mock.get( + dummy_oidc_url, + payload={ + "userinfo_endpoint": dummy_userinfo_endpoint_url, + "jwks_uri": dummy_jwks_uri, + }, + ) + mock.get(dummy_jwks_uri, payload=dummy_jwks_response_data) + mock.get(dummy_userinfo_endpoint_url, payload={"nickname": "jacobsvante"}) + + resp = client.get( + "/users/me", headers={"Authorization": f"Bearer {access_token}"} + ) + + assert resp.status_code == 200 + data = resp.json() + assert data["info"]["nickname"] == "jacobsvante" diff --git a/tests/integration/test_permission_overrides.py b/tests/integration/test_permission_overrides.py new file mode 100644 index 0000000..be18c6a --- /dev/null +++ b/tests/integration/test_permission_overrides.py @@ -0,0 +1,47 @@ +from fastapi import Depends + +from fastapi_security import FastAPISecurity, HTTPBasicCredentials, User + + +def test_that_explicit_permission_overrides_are_applied(app, client): + cred = HTTPBasicCredentials(username="johndoe", password="123") + + security = FastAPISecurity() + + create_product_perm = security.user_permission("products:create") + + security.init_basic_auth([cred]) + security.add_permission_overrides({"johndoe": ["products:create"]}) + + @app.post("/products") + def create_product( + user: User = Depends(security.user_holding(create_product_perm)), + ): + return {"ok": True} + + resp = client.post("/products", auth=("johndoe", "123")) + + assert resp.status_code == 200 + assert resp.json() == {"ok": True} + + +def test_that_wildcard_permission_overrides_are_applied(app, client): + cred = HTTPBasicCredentials(username="johndoe", password="123") + + security = FastAPISecurity() + + create_product_perm = security.user_permission("products:create") + + security.init_basic_auth([cred]) + security.add_permission_overrides({"johndoe": "*"}) + + @app.post("/products") + def create_product( + user: User = Depends(security.user_holding(create_product_perm)), + ): + return {"ok": True} + + resp = client.post("/products", auth=("johndoe", "123")) + + assert resp.status_code == 200 + assert resp.json() == {"ok": True} diff --git a/tests/integration/test_permissions.py b/tests/integration/test_permissions.py new file mode 100644 index 0000000..1d95362 --- /dev/null +++ b/tests/integration/test_permissions.py @@ -0,0 +1,99 @@ +from aioresponses import aioresponses +from fastapi import Depends + +from fastapi_security import FastAPISecurity, User +from fastapi_security.permissions import UserPermission + +from ..helpers.jwks import ( + dummy_audience, + dummy_jwks_response_data, + dummy_jwks_uri, + make_access_token, +) + + +def test_user_permission_repr(): + perm = UserPermission("inventory:list") + assert repr(perm) == "UserPermission(inventory:list)" + + +def test_user_permission_str(): + perm = UserPermission("inventory:list") + assert str(perm) == "inventory:list" + + +def test_that_missing_permission_results_in_403(app, client): + + security = FastAPISecurity() + + can_list = security.user_permission("users:list") # noqa + + @app.get("/users") + def get_user_list(user: User = Depends(security.user_holding(can_list))): + return [user] + + security.init_oauth2_through_jwks(dummy_jwks_uri, audiences=[dummy_audience]) + + access_token = make_access_token(sub="test-user", permissions=[]) + + with aioresponses() as mock: + mock.get(dummy_jwks_uri, payload=dummy_jwks_response_data) + + resp = client.get("/users", headers={"Authorization": f"Bearer {access_token}"}) + assert resp.status_code == 403 + assert resp.json() == {"detail": "Missing required permission users:list"} + + +def test_that_assigned_permission_result_in_200(app, client): + + security = FastAPISecurity() + + can_list = security.user_permission("users:list") # noqa + + @app.get("/users") + def get_user_list(user: User = Depends(security.user_holding(can_list))): + return [user] + + security.init_oauth2_through_jwks(dummy_jwks_uri, audiences=[dummy_audience]) + + access_token = make_access_token(sub="test-user", permissions=["users:list"]) + + with aioresponses() as mock: + mock.get(dummy_jwks_uri, payload=dummy_jwks_response_data) + + resp = client.get("/users", headers={"Authorization": f"Bearer {access_token}"}) + assert resp.status_code == 200 + (user1,) = resp.json() + assert user1["auth"]["subject"] == "test-user" + + +def test_that_user_must_have_all_permissions(app, client): + + security = FastAPISecurity() + + can_list = security.user_permission("users:list") # noqa + can_view = security.user_permission("users:view") # noqa + + @app.get("/users") + def get_user_list(user: User = Depends(security.user_holding(can_list, can_view))): + return [user] + + security.init_oauth2_through_jwks(dummy_jwks_uri, audiences=[dummy_audience]) + + bad_token = make_access_token(sub="test-user", permissions=["users:list"]) + valid_token = make_access_token( + sub="JaneDoe", + permissions=["users:list", "users:view"], + ) + + with aioresponses() as mock: + mock.get(dummy_jwks_uri, payload=dummy_jwks_response_data) + + resp = client.get("/users", headers={"Authorization": f"Bearer {bad_token}"}) + assert resp.status_code == 403 + assert resp.json() == {"detail": "Missing required permission users:view"} + + resp = client.get("/users", headers={"Authorization": f"Bearer {valid_token}"}) + assert resp.status_code == 200 + (user1,) = resp.json() + assert user1["auth"]["subject"] == "JaneDoe" diff --git a/tests/integration/test_user_data.py b/tests/integration/test_user_data.py new file mode 100644 index 0000000..f565549 --- /dev/null +++ b/tests/integration/test_user_data.py @@ -0,0 +1,243 @@ +from aioresponses import aioresponses +from fastapi import Depends + +from fastapi_security import FastAPISecurity, User + +from ..helpers.jwks import ( + dummy_audience, + dummy_jwks_response_data, + dummy_jwks_uri, + make_access_token, +) +from ..helpers.oidc import dummy_oidc_url, dummy_userinfo_endpoint_url + + +def test_that_authenticated_user_auth_data_is_returned_as_expected(app, client): + + security = FastAPISecurity() + + @app.get("/users/me") + def get_user_info(user: User = Depends(security.authenticated_user_or_401)): + return user.without_access_token() + + security.init_oauth2_through_jwks(dummy_jwks_uri, audiences=[dummy_audience]) + + access_token = make_access_token(sub="test-subject", scope=["email"]) + + with aioresponses() as mock: + mock.get(dummy_jwks_uri, payload=dummy_jwks_response_data) + + resp = client.get( + "/users/me", headers={"Authorization": f"Bearer {access_token}"} + ) + assert resp.status_code == 200 + data = resp.json()["auth"] + del data["expires_at"] + del data["issued_at"] + assert data == { + "audience": ["https://some-resource"], + "auth_method": "oauth2", + "issuer": "https://identity-provider/", + "permissions": [], + "scopes": ["email"], + "subject": "test-subject", + } + + +def test_that_user_dependency_works_authenticated_or_not(app, client): + + security = FastAPISecurity() + + @app.get("/users/me") + def get_user_info(user: User = Depends(security.user)): + return user.without_access_token() + + security.init_basic_auth([{"username": "JaneDoe", "password": "abc123"}]) + + # Anonymous + resp = client.get("/users/me") + assert resp.status_code == 200 + data = resp.json()["auth"] + del data["expires_at"] + del data["issued_at"] + assert data == { + "audience": [], + "auth_method": "none", + "issuer": None, + "permissions": [], + "scopes": [], + "subject": "anonymous", + } + + # Authenticated + resp = client.get("/users/me", auth=("JaneDoe", "abc123")) + assert resp.status_code == 200 + data = resp.json()["auth"] + del data["expires_at"] + del data["issued_at"] + assert data == { + "audience": [], + "auth_method": "basic_auth", + "issuer": None, + "permissions": [], + "scopes": [], + "subject": "JaneDoe", + } + + +def test_that_user_with_info_dependency_works_unauthenticated(app, client): + + security = FastAPISecurity() + + @app.get("/users/me") + def get_user_info(user: User = Depends(security.user_with_info)): + return user.without_access_token() + + security.init_basic_auth([{"username": "a", "password": "b"}]) + + resp = client.get("/users/me") + assert resp.status_code == 200 + info = resp.json()["info"] + assert info["nickname"] is None + + +def test_that_user_with_info_dependency_works_authenticated(app, client, caplog): + import logging + + caplog.set_level(logging.DEBUG) + security = FastAPISecurity() + + @app.get("/users/me") + def get_user_info(user: User = Depends(security.user_with_info)): + return user.without_access_token() + + security.init_oauth2_through_oidc(dummy_oidc_url, audiences=[dummy_audience]) + + with aioresponses() as mock: + mock.get( + dummy_oidc_url, + payload={ + "userinfo_endpoint": dummy_userinfo_endpoint_url, + "jwks_uri": dummy_jwks_uri, + }, + ) + mock.get(dummy_jwks_uri, payload=dummy_jwks_response_data) + mock.get(dummy_userinfo_endpoint_url, payload={"nickname": "jacobsvante"}) + token = make_access_token(sub="GMqBbybGfBQeR6NgCY4NyXKnpFzaaTAn@clients") + resp = client.get("/users/me", headers={"Authorization": f"Bearer {token}"}) + assert resp.status_code == 200 + data = resp.json() + info = data["info"] + assert info["nickname"] == "jacobsvante" + + +def test_that_authenticated_user_with_info_or_401_works_as_expected(app, client): + security = FastAPISecurity() + + @app.get("/users/me") + def get_user_info( + user: User = Depends(security.authenticated_user_with_info_or_401), + ): + return user.without_access_token() + + security.init_oauth2_through_oidc(dummy_oidc_url, audiences=[dummy_audience]) + security.init_basic_auth([{"username": "a", "password": "b"}]) + + with aioresponses() as mock: + mock.get( + dummy_oidc_url, + payload={ + "userinfo_endpoint": dummy_userinfo_endpoint_url, + "jwks_uri": dummy_jwks_uri, + }, + ) + mock.get(dummy_jwks_uri, payload=dummy_jwks_response_data) + mock.get(dummy_userinfo_endpoint_url, payload={"nickname": "jacobsvante"}) + token = make_access_token(sub="GMqBbybGfBQeR6NgCY4NyXKnpFzaaTAn@clients") + resp = client.get("/users/me", headers={"Authorization": f"Bearer {token}"}) + assert resp.status_code == 200 + info = resp.json()["info"] + assert info["nickname"] == "jacobsvante" + + # Basic auth + resp = client.get("/users/me", auth=("a", "b")) + assert resp.status_code == 200 + info = resp.json()["info"] + assert info["nickname"] is None + + # Unauthenticated + resp = client.get("/users/me") + assert resp.status_code == 401 + assert resp.json() == {"detail": "Could not validate credentials"} + + +def test_that_existing_permissions_are_added(app, client): + + security = FastAPISecurity() + + permission = security.user_permission("users:list") # noqa + + @app.get("/users/me") + def get_user_info(user: User = Depends(security.authenticated_user_or_401)): + return user.without_access_token() + + security.init_oauth2_through_jwks(dummy_jwks_uri, audiences=[dummy_audience]) + + access_token = make_access_token( + sub="test-subject", + permissions=["users:list"], + ) + + with aioresponses() as mock: + mock.get(dummy_jwks_uri, payload=dummy_jwks_response_data) + + resp = client.get( + "/users/me", headers={"Authorization": f"Bearer {access_token}"} + ) + assert resp.status_code == 200 + data = resp.json()["auth"] + del data["expires_at"] + del data["issued_at"] + assert data == { + "audience": ["https://some-resource"], + "auth_method": "oauth2", + "issuer": "https://identity-provider/", + "permissions": ["users:list"], + "scopes": [], + "subject": "test-subject", + } + + +def test_that_nonexisting_permissions_are_ignored(app, client): + + security = FastAPISecurity() + + @app.get("/users/me") + def get_user_info(user: User = Depends(security.authenticated_user_or_401)): + return user.without_access_token() + + security.init_oauth2_through_jwks(dummy_jwks_uri, audiences=[dummy_audience]) + + access_token = make_access_token( + sub="test-subject", + permissions=["users:list"], + ) + + with aioresponses() as mock: + mock.get(dummy_jwks_uri, payload=dummy_jwks_response_data) + + resp = client.get( + "/users/me", headers={"Authorization": f"Bearer {access_token}"} + ) + assert resp.status_code == 200 + data = resp.json()["auth"] + del data["expires_at"] + del data["issued_at"] + assert data == { + "audience": ["https://some-resource"], + "auth_method": "oauth2", + "issuer": "https://identity-provider/", + "permissions": [], + "scopes": [], + "subject": "test-subject", + } diff --git a/tests/test_basic_auth.py b/tests/test_basic_auth.py deleted file mode 100644 index 432dee5..0000000 --- a/tests/test_basic_auth.py +++ /dev/null @@ -1,37 +0,0 @@ -from fastapi import Depends - -from fastapi_security import FastAPISecurity, User - - -def test_that_basic_auth_rejects_incorrect_credentials(app, client): - security = FastAPISecurity() - - @app.get("/") - def get_products(user: User = Depends(security.authenticated_user_or_401)): - return [] - - credentials = [{"username": "user", "password": "pass"}] - security.init(app, basic_auth_credentials=credentials) - - resp = client.get("/") - assert resp.status_code == 401 - - resp = client.get("/", auth=("user", "")) - assert resp.status_code == 401 - - resp = client.get("/", auth=("", "pass")) - assert resp.status_code == 401 - - -def test_that_basic_auth_accepts_correct_credentials(app, client): - security = FastAPISecurity() - - @app.get("/") - def get_products(user: User = Depends(security.authenticated_user_or_401)): - return [] - - credentials = [{"username": "user", "password": "pass"}] - security.init(app, basic_auth_credentials=credentials) - - resp = client.get("/", auth=("user", "pass")) - assert resp.status_code == 200 diff --git a/tests/test_permissions.py b/tests/test_permissions.py deleted file mode 100644 index 3cefe1c..0000000 --- a/tests/test_permissions.py +++ /dev/null @@ -1,60 +0,0 @@ -from aioresponses import aioresponses -from fastapi import Depends - -from fastapi_security import FastAPISecurity, User, UserPermission - -from .jwks_helpers import ( - dummy_audience, - dummy_jwks_response_data, - dummy_jwks_url, - make_access_token, -) - - -def test_that_missing_permission_results_in_403(app, client): - - security = FastAPISecurity() - - can_list = UserPermission("users:list") # noqa - - @app.get("/users/registry") - def get_user_list(user: User = Depends(security.user_with_permissions(can_list))): - return [user] - - security.init(app, jwks_url=dummy_jwks_url, audiences=[dummy_audience]) - - access_token = make_access_token(sub="test-user", permissions=[]) - - with aioresponses() as mock: - mock.get(dummy_jwks_url, payload=dummy_jwks_response_data) - - resp = client.get( - "/users/registry", headers={"Authorization": f"Bearer {access_token}"} - ) - assert resp.status_code == 403 - assert resp.json() == {"detail": "Missing required permission users:list"} - - -def test_that_assigned_permission_result_in_200(app, client): - - security = FastAPISecurity() - - can_list = UserPermission("users:list") # noqa - - @app.get("/users/registry") - def get_user_list(user: User = Depends(security.user_with_permissions(can_list))): - return [user] - - security.init(app, jwks_url=dummy_jwks_url, audiences=[dummy_audience]) - - access_token = make_access_token(sub="test-user", permissions=["users:list"]) - - with aioresponses() as mock: - mock.get(dummy_jwks_url, payload=dummy_jwks_response_data) - - resp = client.get( - "/users/registry", headers={"Authorization": f"Bearer {access_token}"} - ) - assert resp.status_code == 200 - (user1,) = resp.json() - assert user1["auth"]["subject"] == "test-user" diff --git a/tests/test_user_data.py b/tests/test_user_data.py deleted file mode 100644 index a536bdc..0000000 --- a/tests/test_user_data.py +++ /dev/null @@ -1,115 +0,0 @@ -from aioresponses import aioresponses -from fastapi import Depends - -from fastapi_security import FastAPISecurity, User, UserPermission - -from .jwks_helpers import ( - dummy_audience, - dummy_jwks_response_data, - dummy_jwks_url, - make_access_token, -) - - -def test_that_user_auth_data_is_returned_as_expected(app, client): - - security = FastAPISecurity() - - @app.get("/users/me") - def get_user_info(user: User = Depends(security.authenticated_user_or_401)): - return user.without_access_token().without_extra() - - security.init(app, jwks_url=dummy_jwks_url, audiences=[dummy_audience]) - - access_token = make_access_token(sub="test-subject", scope=["email"]) - - with aioresponses() as mock: - mock.get(dummy_jwks_url, payload=dummy_jwks_response_data) - - resp = client.get( - "/users/me", headers={"Authorization": f"Bearer {access_token}"} - ) - assert resp.status_code == 200 - data = resp.json()["auth"] - del data["expires_at"] - del data["issued_at"] - assert data == { - "audience": ["https://some-resource"], - "auth_method": "oauth2", - "issuer": "https://identity-provider/", - "permissions": [], - "scopes": ["email"], - "subject": "test-subject", - } - - -def test_that_existing_permissions_are_added(app, client): - - security = FastAPISecurity() - - permission = UserPermission("users:list") # noqa - - @app.get("/users/me") - def get_user_info(user: User = Depends(security.authenticated_user_or_401)): - return user.without_access_token().without_extra() - - security.init(app, jwks_url=dummy_jwks_url, audiences=[dummy_audience]) - - access_token = make_access_token( - sub="test-subject", - permissions=["users:list"], - ) - - with aioresponses() as mock: - mock.get(dummy_jwks_url, payload=dummy_jwks_response_data) - - resp = client.get( - "/users/me", headers={"Authorization": f"Bearer {access_token}"} - ) - assert resp.status_code == 200 - data = resp.json()["auth"] - del data["expires_at"] - del data["issued_at"] - assert data == { - "audience": ["https://some-resource"], - "auth_method": "oauth2", - "issuer": "https://identity-provider/", - "permissions": ["users:list"], - "scopes": [], - "subject": "test-subject", - } - - -def test_that_nonexisting_permissions_are_ignored(app, client): - - security = FastAPISecurity() - - @app.get("/users/me") - def get_user_info(user: User = Depends(security.authenticated_user_or_401)): - return user.without_access_token().without_extra() - - security.init(app, jwks_url=dummy_jwks_url, audiences=[dummy_audience]) - - access_token = make_access_token( - sub="test-subject", - permissions=["users:list"], - ) - - with aioresponses() as mock: - mock.get(dummy_jwks_url, payload=dummy_jwks_response_data) - - resp = client.get( - "/users/me", headers={"Authorization": f"Bearer {access_token}"} - ) - assert resp.status_code == 200 - data = resp.json()["auth"] - del data["expires_at"] - del data["issued_at"] - assert data == { - "audience": ["https://some-resource"], - "auth_method": "oauth2", - "issuer": "https://identity-provider/", - "permissions": [], - "scopes": [], - "subject": "test-subject", - } diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/test_entities.py b/tests/unit/test_entities.py new file mode 100644 index 0000000..8ba0602 --- /dev/null +++ b/tests/unit/test_entities.py @@ -0,0 +1,93 @@ +import datetime + +from fastapi_security.entities import ( + AuthMethod, + JwtAccessToken, + User, + UserAuth, + UserInfo, +) + + +def test_make_dummy_user_info(): + dummy = UserInfo.make_dummy() + assert dummy.dict() == { + "given_name": None, + "family_name": None, + "nickname": None, + "name": None, + "picture": None, + "locale": None, + "updated_at": None, + "email": None, + "email_verified": None, + } + + +def test_anonymous_user_auth(): + anon = UserAuth.make_anonymous() + assert anon.is_anonymous() + assert anon.dict() == { + "subject": "anonymous", + "auth_method": AuthMethod.none, + "issuer": None, + "audience": [], + "issued_at": None, + "expires_at": None, + "scopes": [], + "permissions": [], + "access_token": None, + } + + +def test_user_auth_get_user_id(): + u = UserAuth(subject="johndoe", auth_method="basic_auth") + assert u.get_user_id() == "johndoe" + + +def test_that_user_auth_accepts_client_credentials_grant_type(): + jwt_token = JwtAccessToken( + iss="a", + sub="johndoe", + aud="a", + iat="2021-03-26 11:25", + exp="2021-03-27 11:25", + raw="", + gty="client-credentials", + ) + assert jwt_token.is_client_credentials() + + auth = UserAuth.from_jwt_access_token(jwt_token) + assert auth.is_oauth2() + + +def test_that_user_methods_work_correctly(): + jwt_token = JwtAccessToken( + iss="a", + sub="johndoe", + aud="a", + iat="2021-03-26 11:25", + exp="2021-03-27 11:25", + raw="", + permissions=["products:create"], + ) + auth = UserAuth.from_jwt_access_token(jwt_token) + user = User(auth=auth) + assert user.permissions == ["products:create"] + # NOTE: Expiry etc is validated in a higher layer + assert user.is_authenticated() + assert not user.is_anonymous() + assert user.get_user_id() == "johndoe" + assert user.has_permission("products:create") + + assert user.dict()["auth"] == { + "access_token": "", + "audience": ["a"], + "auth_method": AuthMethod.oauth2, + "expires_at": datetime.datetime(2021, 3, 27, 11, 25), + "issued_at": datetime.datetime(2021, 3, 26, 11, 25), + "issuer": "a", + "permissions": ["products:create"], + "scopes": [], + "subject": "johndoe", + } diff --git a/tests/unit/test_oauth2.py b/tests/unit/test_oauth2.py new file mode 100644 index 0000000..8f4a145 --- /dev/null +++ b/tests/unit/test_oauth2.py @@ -0,0 +1,174 @@ +import logging + +import aiohttp +import pytest +from aioresponses import aioresponses + +from fastapi_security.entities import JwtAccessToken +from fastapi_security.oauth2 import Oauth2JwtAccessTokenValidator + +from ..helpers.jwks import ( + dummy_audience, + dummy_jwks_response_data, + dummy_jwks_uri, + dummy_jwt_headers, + make_access_token, +) + +pytestmark = pytest.mark.asyncio + + +async def test_that_jwt_cant_be_validated_when_uninitialized(caplog): + caplog.set_level(logging.INFO) + validator = Oauth2JwtAccessTokenValidator() + # validator.init(dummy_jwks_uri, [dummy_audience]) + parsed = await validator.parse("abc") + assert "JWT Access Token validator is not set up!" in caplog.text + assert parsed is None + + +@pytest.mark.parametrize("empty_data", ((None, "", 0))) +async def test_that_empty_jwt_is_invalid(caplog, empty_data): + caplog.set_level(logging.DEBUG) + validator = Oauth2JwtAccessTokenValidator() + validator.init(dummy_jwks_uri, [dummy_audience]) + parsed = await validator.parse(empty_data) + assert "No JWT token provided" in caplog.text + assert parsed is None + + +async def test_that_unparseable_token_is_invalid(caplog): + caplog.set_level(logging.DEBUG) + validator = Oauth2JwtAccessTokenValidator() + validator.init(dummy_jwks_uri, [dummy_audience]) + parsed = await validator.parse("badDATA") + assert ( + "Decoding unverified JWT token failed with error: DecodeError('Not enough segments')" + in caplog.text + ) + assert parsed is None + + +async def test_that_missing_kid_field_is_invalid(caplog): + caplog.set_level(logging.DEBUG) + validator = Oauth2JwtAccessTokenValidator() + validator.init(dummy_jwks_uri, [dummy_audience]) + token = make_access_token(sub="johndoe", headers={"alg": "RS256", "typ": "JWT"}) + parsed = await validator.parse(token) + assert "No `kid` found in JWT token" in caplog.text + assert parsed is None + + +async def test_that_mismatching_kid_field_fails(caplog): + caplog.set_level(logging.DEBUG) + validator = Oauth2JwtAccessTokenValidator() + validator.init(dummy_jwks_uri, [dummy_audience]) + token = make_access_token( + sub="johndoe", + headers={"alg": "RS256", "typ": "JWT", "kid": "someother"}, + ) + + with aioresponses() as mock: + mock.get(dummy_jwks_uri, payload=dummy_jwks_response_data) + parsed = await validator.parse(token) + + assert "No matching `kid` for JWT token" in caplog.text + assert parsed is None + + +async def test_that_hs256_doesnt_work(caplog): + caplog.set_level(logging.DEBUG) + validator = Oauth2JwtAccessTokenValidator() + validator.init(dummy_jwks_uri, [dummy_audience]) + token = make_access_token( + sub="johndoe", + headers={**dummy_jwt_headers, "alg": "HS256"}, + ) + + with aioresponses() as mock: + mock.get(dummy_jwks_uri, payload=dummy_jwks_response_data) + parsed = await validator.parse(token) + + assert ( + "Decoding verified JWT token failed with error: InvalidAlgorithmError('The specified alg value is not allowed')" + in caplog.text + ) + assert parsed is None + + +async def test_that_missing_audience_fails(caplog): + caplog.set_level(logging.DEBUG) + validator = Oauth2JwtAccessTokenValidator() + validator.init(dummy_jwks_uri, [dummy_audience]) + token = make_access_token( + sub="johndoe", + delete_fields=["aud"], + ) + + with aioresponses() as mock: + mock.get(dummy_jwks_uri, payload=dummy_jwks_response_data) + parsed = await validator.parse(token) + + assert ( + "Decoding verified JWT token failed with error: MissingRequiredClaimError('aud')" + in caplog.text + ) + assert parsed is None + + +async def test_that_missing_expiry_date_fails(caplog): + caplog.set_level(logging.DEBUG) + validator = Oauth2JwtAccessTokenValidator() + validator.init(dummy_jwks_uri, [dummy_audience]) + token = make_access_token( + sub="johndoe", + delete_fields=["exp"], + ) + + with aioresponses() as mock: + mock.get(dummy_jwks_uri, payload=dummy_jwks_response_data) + parsed = await validator.parse(token) + + assert ( + "Failed to parse JWT token with 1 validation error for JwtAccessToken\nexp\n field required (type=value_error.missing)" + in caplog.text + ) + assert parsed is None + + +async def test_that_initial_failure_to_get_jwks_kid_data_raises_exception(): + validator = Oauth2JwtAccessTokenValidator() + validator.init(dummy_jwks_uri, [dummy_audience]) + + token = make_access_token(sub="johndoe") + + with pytest.raises( + aiohttp.client_exceptions.ClientConnectorError, + match="Cannot connect to host identity-provider:443", + ): + await validator.parse(token) + + +async def test_that_subsequent_failure_to_fetch_jwks_kid_data_is_handled(caplog): + validator = Oauth2JwtAccessTokenValidator() + validator.init(dummy_jwks_uri, [dummy_audience]) + + token = make_access_token(sub="johndoe") + + with aioresponses() as mock: + mock.get(dummy_jwks_uri, payload=dummy_jwks_response_data) + await validator.parse(token) + + caplog.set_level(logging.INFO) + + # NOTE: Reaching into the internals to trigger JWKS kid data refresh + validator._jwks_cached_at = -3600 + # NOTE: Not mocking JWKS endpoint, which would cause a "Cannot connect" error on + # the first try. + parsed = await validator.parse(token) + + assert isinstance(parsed, JwtAccessToken) + assert ( + "Failed to refresh JWKS kid mapping, re-using old data. Exception was: ClientConnectorError" + in caplog.text + ) diff --git a/tests/unit/test_oidc.py b/tests/unit/test_oidc.py new file mode 100644 index 0000000..6287604 --- /dev/null +++ b/tests/unit/test_oidc.py @@ -0,0 +1,84 @@ +import logging + +import aiohttp +import pytest +from aioresponses import aioresponses + +from fastapi_security.oidc import OpenIdConnectDiscovery + +from ..helpers.jwks import make_access_token +from ..helpers.oidc import dummy_oidc_url, dummy_userinfo_endpoint_url + +pytestmark = pytest.mark.asyncio + + +async def test_that_getting_user_info_doesnt_work_uninitialized(caplog): + caplog.set_level(logging.INFO) + token = make_access_token(sub="janedoe") + oidc = OpenIdConnectDiscovery() + user_info = await oidc.get_user_info(token) + assert user_info is None + assert "OpenID Connect discovery URL is not set up!" in caplog.text + + +async def test_that_getting_user_info_with_empty_access_token_doesnt_work(caplog): + caplog.set_level(logging.DEBUG) + oidc = OpenIdConnectDiscovery() + oidc.init(dummy_oidc_url) + user_info = await oidc.get_user_info("") + assert user_info is None + assert "No access token provided" in caplog.text + + +async def test_that_dummy_user_info_is_returned_when_endpoint_returns_non_200(caplog): + caplog.set_level(logging.DEBUG) + oidc = OpenIdConnectDiscovery() + oidc.init(dummy_oidc_url) + token = make_access_token(sub="JaneDoe") + + with aioresponses() as mock: + mock.get( + dummy_oidc_url, + payload={"userinfo_endpoint": dummy_userinfo_endpoint_url}, + ) + mock.get(dummy_userinfo_endpoint_url, status=503) + + user_info = await oidc.get_user_info(token) + + assert all(v is None for v in user_info.dict().values()) + + +async def test_that_initial_failure_to_fetch_discovery_data_raises_exception(): + oidc = OpenIdConnectDiscovery() + oidc.init(dummy_oidc_url) + + with pytest.raises( + aiohttp.client_exceptions.ClientConnectorError, + match="Cannot connect to host oidc-provider:443", + ): + await oidc.get_discovery_data() + + +async def test_that_subsequent_failure_to_fetch_discovery_data_is_handled(caplog): + oidc = OpenIdConnectDiscovery() + oidc.init(dummy_oidc_url) + + with aioresponses() as mock: + mock.get( + dummy_oidc_url, payload={"userinfo_endpoint": dummy_userinfo_endpoint_url} + ) + await oidc.get_discovery_data() + + caplog.set_level(logging.INFO) + + # NOTE: Reaching into the internals to trigger JWKS kid data refresh + oidc._discovery_data_cached_at = -3600 + # NOTE: Not mocking JWKS endpoint, which would cause a "Cannot connect" error on + # the first try. + parsed = await oidc.get_discovery_data() + + assert parsed == {"userinfo_endpoint": dummy_userinfo_endpoint_url} + assert ( + "Failed to refresh OIDC discovery data, re-using old data. Exception was: ClientConnectorError" + in caplog.text + )