Skip to content

Commit

Permalink
Merge pull request #484 from SUNET/lundberg_scimapi_interaction_auth
Browse files Browse the repository at this point in the history
scimapi interaction auth
  • Loading branch information
helylle authored Nov 8, 2023
2 parents b632406 + 486a10b commit ee26b4b
Show file tree
Hide file tree
Showing 6 changed files with 268 additions and 33 deletions.
5 changes: 5 additions & 0 deletions src/eduid/scimapi/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ class ScimApiConfig(RootConfig, LoggingConfigMixin, AWSMixin):
scope_sudo: dict[ScopeName, set[ScopeName]] = Field(default={})
# The expected value of the authn JWT claims['requested_access']['type']
requested_access_type: Optional[str] = "scim-api"
# required saml assurance level for authentications with interaction auth_source
required_saml_assurance_level: list[str] = Field(default=["http://www.swamid.se/policy/assurance/al3"])
# group name to match saml entitlement for authorization
account_manager_default_group: str = "Account Managers"
account_manager_group_mapping: dict[DataOwnerName, str] = Field(default={})
# Invite config
invite_url: str = ""
invite_expire: int = 180 * 86400 # 180 days
Expand Down
124 changes: 119 additions & 5 deletions src/eduid/scimapi/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
import logging
import re
from copy import copy
from enum import Enum
from typing import Any, Mapping, Optional

from fastapi import Request, Response
from jwcrypto import jwt
from jwcrypto.common import JWException
from pydantic import BaseModel, Field, StrictInt, ValidationError, validator
from pydantic import BaseModel, Field, StrictInt, ValidationError, root_validator, validator
from starlette.datastructures import URL
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import Message
Expand All @@ -17,13 +18,31 @@
from eduid.scimapi.context import Context
from eduid.scimapi.context_request import ContextRequestMixin
from eduid.scimapi.exceptions import Unauthorized, http_error_detail_handler
from eduid.userdb.scimapi import ScimApiGroupDB

logger = logging.getLogger(__name__)


class AuthSource(str, Enum):
INTERACTION = "interaction"
CONFIG = "config"
MDQ = "mdq"
TLSFED = "tlsfed"


class SudoAccess(BaseModel):
type: str
scope: ScopeName


class AuthenticationError(Exception):
pass


class AuthorizationError(Exception):
pass


class RequestedAccessDenied(Exception):
"""Break out of get_data_owner when requested access (in the token) is not allowed"""

Expand All @@ -37,8 +56,18 @@ class AuthnBearerToken(BaseModel):

scim_config: ScimApiConfig # must be listed first, used in validators
version: StrictInt
auth_source: AuthSource
requested_access: list[SudoAccess] = Field(default=[])
scopes: set[ScopeName] = Field(default=set())
# saml interaction claims
saml_issuer: Optional[str] = None
saml_assurance: Optional[list[str]] = None
saml_entitlement: Optional[list[str]] = None
saml_eppn: Optional[str] = None
saml_unique_id: Optional[str] = None

# class Config:
# validate_assignment = True

def __str__(self):
return f"<{self.__class__.__name__}: scopes={self.scopes}, requested_access={self.requested_access}>"
Expand All @@ -49,6 +78,13 @@ def validate_version(cls, v: int) -> int:
raise ValueError("Unknown version")
return v

@root_validator(pre=True)
def set_scopes_from_saml_data(cls, values: dict[str, Any]):
# Get scope from saml identifier if the auth source is interaction and set it as scopes
if values.get("auth_source") == AuthSource.INTERACTION.value:
values["scopes"] = cls._get_scope_from_saml_data(values=values)
return values

@validator("scopes")
def validate_scopes(cls, v: set[ScopeName], values: Mapping[str, Any]) -> set[ScopeName]:
config = values.get("scim_config")
Expand All @@ -71,7 +107,67 @@ def validate_requested_access(cls, v: list[SudoAccess], values: Mapping[str, Any
new_access += [this]
return new_access

def get_data_owner(self, logger: logging.Logger) -> Optional[DataOwnerName]:
@staticmethod
def _get_scope_from_saml_data(values: Mapping[str, Any]) -> list[ScopeName]:
saml_identifier = values.get("saml_eppn") or values.get("saml_unique_id")
if not saml_identifier:
return []
try:
scope = ScopeName(saml_identifier.split("@")[1])
except IndexError:
return []
logger.info(f"Scope from saml data: {scope}")
return [scope]

def validate_auth_source(self) -> None:
"""
Check if the auth source is any of the one we know of. If the auth source is config, mdq or tlsfed we
can just let it through. If the auth source is interaction we need to check the saml data to make sure
the user is allowed access to the data owner.
"""
if self.auth_source in [AuthSource.CONFIG, AuthSource.MDQ, AuthSource.TLSFED]:
logger.info(f"{self.auth_source} is a trusted auth source")
return

if self.auth_source == AuthSource.INTERACTION:
assurances = self.saml_assurance or []
# validate that the authentication meets the required assurance level
for assurance_level in self.scim_config.required_saml_assurance_level:
if assurance_level in assurances:
logger.info(f"Allowed assurance level {assurance_level} is in saml data: {assurances}")
return
raise AuthenticationError(
f"Asserted SAML assurance level(s) ({assurances}) not in"
f"allow-list: {self.scim_config.required_saml_assurance_level}"
)

raise AuthenticationError(f"Unsupported authentication source: {self.auth_source}")

def validate_saml_entitlements(self, data_owner: DataOwnerName, groupdb: Optional[ScimApiGroupDB] = None) -> None:
if groupdb is None:
raise AuthenticationError("No groupdb provided, cannot validate saml entitlements.")

default_name = self.scim_config.account_manager_default_group
account_manager_group_name = self.scim_config.account_manager_group_mapping.get(data_owner, default_name)
logger.debug(f"Checking for account manager group called {account_manager_group_name}")

account_manager_group = groupdb.get_group_by_display_name(display_name=account_manager_group_name)
if account_manager_group is None:
raise AuthenticationError('No "Account Managers" group found for data owner')
logger.debug(f"Found group {account_manager_group_name} with id {account_manager_group.graph.identifier}")

# TODO: create a helper function to do this for all places where we do this dance in the repo
# create the expected saml group id
saml_group_id = f"{groupdb.graphdb.scope}:group:{account_manager_group.graph.identifier}#eduid-iam"
# match against users entitlements
entitlements = self.saml_entitlement or []
if saml_group_id in entitlements:
logger.debug(f"{saml_group_id} in {entitlements}")
return
logger.error(f"{saml_group_id} NOT in {entitlements}")
raise AuthorizationError(f"Not authorized: {saml_group_id} not in saml entitlements")

def get_data_owner(self) -> Optional[DataOwnerName]:
"""
Get the data owner to use.
Expand All @@ -94,7 +190,7 @@ def get_data_owner(self, logger: logging.Logger) -> Optional[DataOwnerName]:
requested_access: [{'type': 'scim-api', 'scope': 'example.edu'}]}
"""

allowed_scopes = self._get_allowed_scopes(self.scim_config, logger)
allowed_scopes = self._get_allowed_scopes(self.scim_config)
logger.debug(f"Request {self}, allowed scopes: {allowed_scopes}")

# only support one requested access at a time for now and do not fall back to simple scope check if
Expand Down Expand Up @@ -126,7 +222,7 @@ def get_data_owner(self, logger: logging.Logger) -> Optional[DataOwnerName]:

return None

def _get_allowed_scopes(self, config: ScimApiConfig, logger: logging.Logger) -> set[ScopeName]:
def _get_allowed_scopes(self, config: ScimApiConfig) -> set[ScopeName]:
"""
Make a set of all the allowed scopes for the requester.
Expand Down Expand Up @@ -260,7 +356,15 @@ async def dispatch(self, req: Request, call_next) -> Response:
return await http_error_detail_handler(req=req, exc=Unauthorized(detail="Bearer token error"))

try:
data_owner = token.get_data_owner(self.context.logger)
token.validate_auth_source()
except AuthenticationError as exc:
self.context.logger.error(f"Access denied: {exc}")
return await http_error_detail_handler(
req=req, exc=Unauthorized(detail="Authentication source or assurance level invalid")
)

try:
data_owner = token.get_data_owner()
except RequestedAccessDenied as exc:
self.context.logger.error(f"Access denied: {exc}")
return await http_error_detail_handler(
Expand All @@ -278,4 +382,14 @@ async def dispatch(self, req: Request, call_next) -> Response:
req.context.invitedb = self.context.get_invitedb(data_owner)
req.context.eventdb = self.context.get_eventdb(data_owner)

# check authorization for interaction authentications
try:
if token.auth_source == AuthSource.INTERACTION:
token.validate_saml_entitlements(data_owner=data_owner, groupdb=req.context.groupdb)
except AuthorizationError as exc:
self.context.logger.error(f"Access denied: {exc}")
return await http_error_detail_handler(
req=req, exc=Unauthorized(detail="Missing correct entitlement in saml data")
)

return await call_next(req)
2 changes: 2 additions & 0 deletions src/eduid/scimapi/routers/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from eduid.scimapi.api_router import APIRouter
from eduid.scimapi.context_request import ContextRequest
from eduid.scimapi.exceptions import ErrorDetail, NotFound, Unauthorized
from eduid.scimapi.middleware import AuthSource
from eduid.scimapi.models.login import TokenRequest

login_router = APIRouter(
Expand Down Expand Up @@ -33,6 +34,7 @@ async def get_token(req: ContextRequest, resp: Response, token_req: TokenRequest
"exp": expire.timestamp(),
"scopes": [token_req.data_owner],
"version": 1,
"auth_source": AuthSource.CONFIG,
}
token = jwt.JWT(header={"alg": "ES256"}, claims=claims)
token.make_signed_token(signing_key)
Expand Down
18 changes: 14 additions & 4 deletions src/eduid/scimapi/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from eduid.common.config.parsers import load_config
from eduid.common.models.scim_base import SCIMSchema
from eduid.graphdb.groupdb import User as GraphUser
from eduid.graphdb.testing import Neo4jTemporaryInstance
from eduid.queue.db.message import MessageDB
from eduid.scimapi.app import init_api
Expand Down Expand Up @@ -83,10 +84,6 @@ def setUpClass(cls) -> None:
)
super().setUpClass()

def tearDown(self):
super().tearDown()
self.neo4j_instance.purge_db()


class ScimApiTestCase(MongoNeoTestCase):
"""Base test case providing the real API"""
Expand All @@ -107,6 +104,7 @@ def setUp(self) -> None:
# TODO: more tests for scoped groups when that is implemented
self.data_owner = DataOwnerName("eduid.se")
self.userdb = self.context.get_userdb(self.data_owner)
self.groupdb = self.context.get_groupdb(self.data_owner)
self.invitedb = self.context.get_invitedb(self.data_owner)
self.signup_invitedb = SignupInviteDB(db_uri=config.mongo_uri)
self.messagedb = MessageDB(db_uri=config.mongo_uri)
Expand Down Expand Up @@ -143,6 +141,15 @@ def add_user(
self.userdb.save(user)
return self.userdb.get_user_by_scim_id(scim_id=identifier)

def add_group_with_member(
self, group_identifier: str, display_name: str, user_identifier: str
) -> Optional[ScimApiGroup]:
group = ScimApiGroup(scim_id=uuid.UUID(group_identifier), display_name=display_name)
group.add_member(GraphUser(identifier=user_identifier, display_name="Test Member 1"))
assert self.groupdb
self.groupdb.save(group)
return self.groupdb.get_group_by_scim_id(scim_id=group_identifier)

def tearDown(self):
super().tearDown()
if self.userdb:
Expand All @@ -155,6 +162,9 @@ def tearDown(self):
self.signup_invitedb._drop_whole_collection()
if self.messagedb:
self.messagedb._drop_whole_collection()
if self.groupdb:
self.groupdb._drop_whole_collection()
self.neo4j_instance.purge_db()

def _assertScimError(
self,
Expand Down
Loading

0 comments on commit ee26b4b

Please sign in to comment.