From b125321154873ccd4f76b9a8a0fc4050efdf296a Mon Sep 17 00:00:00 2001 From: Guerdon Mukama Date: Mon, 7 Oct 2024 11:15:28 +1100 Subject: [PATCH] Code revision and refactoring --- .gitignore | 3 + fence/blueprints/login/base.py | 86 +++++++--- fence/config-default.yaml | 17 +- fence/error_handler.py | 6 +- fence/job/access_token_updater.py | 37 ++--- fence/resources/openid/idp_oauth2.py | 211 +++++++++++++++++-------- tests/conftest.py | 22 +++ tests/job/test_access_token_updater.py | 56 +++++-- tests/login/test_base.py | 3 - tests/login/test_idp_oauth2.py | 182 +++++++++++---------- tests/test-fence-config.yaml | 17 +- 11 files changed, 416 insertions(+), 224 deletions(-) diff --git a/.gitignore b/.gitignore index 4a76c3a3e..7e18527a4 100644 --- a/.gitignore +++ b/.gitignore @@ -108,3 +108,6 @@ tests/resources/keys/*.pem .DS_Store .vscode .idea + +# snyk +.dccache \ No newline at end of file diff --git a/fence/blueprints/login/base.py b/fence/blueprints/login/base.py index fe07cc517..044407174 100644 --- a/fence/blueprints/login/base.py +++ b/fence/blueprints/login/base.py @@ -12,6 +12,7 @@ from fence.config import config from fence.errors import UserError from fence.metrics import metrics + logger = get_logger(__name__) @@ -96,8 +97,11 @@ def __init__( "OPENID_CONNECT" ].get(self.idp_name, {}) self.app = app - self.check_groups = config.get("CHECK_GROUPS", False) - self.app = app if app is not None else flask.current_app + # this attribute is only applicable to some OAuth clients + # (e.g., not all clients need read_authz_groups_from_tokens) + self.is_read_authz_groups_from_tokens_enabled = getattr( + self.client, "read_authz_groups_from_tokens", False + ) def get(self): # Check if user granted access @@ -145,9 +149,15 @@ def get(self): if expires is None: expires = int(time.time()) + config["REFRESH_TOKEN_EXPIRES_IN"] - # # Store refresh token in db - if self.check_groups: - self.client.store_refresh_token(flask.g.user,refresh_token,expires) + # Store refresh token in db + if self.is_read_authz_groups_from_tokens_enabled: + # Ensure flask.g.user exists to avoid a potential AttributeError + if getattr(flask.g, "user", None): + self.client.store_refresh_token(flask.g.user, refresh_token, expires) + else: + self.logger.error( + "User information is missing from flask.g; cannot store refresh token." + ) self.post_login( user=flask.g.user, @@ -157,22 +167,50 @@ def get(self): return resp - # see if the refresh token is a JWT. if it is decode to get the exp. we do not care about signatures, the - # reason is that the refresh token is checked by the IDP, not us, thus we don't have the key in most circumstances - # Also check exp from introspect results def extract_exp(self, refresh_token): + """ + Extract the expiration time (exp) from a refresh token. + + This function attempts to extract the `exp` (expiration time) from a given refresh token using + three methods: + + 1. Using PyJWT to decode the token (without signature verification). + 2. Introspecting the token (if supported by the identity provider). + 3. Manually base64 decoding the token's payload (if it's a JWT). + + Disclaimer: + ------------ + This function assumes that the refresh token is valid and does not perform any JWT validation. + For any JWT coming from an OpenID Connect (OIDC) provider, validation should be done using the + public keys provided by the IdP (from the JWKS endpoint) before using this function to extract + the expiration time (`exp`). Without validation, the token's integrity and authenticity cannot + be guaranteed, which may expose your system to security risks. + + Ensure validation is handled prior to calling this function, especially in any public or + production-facing contexts. + + Parameters: + ------------ + refresh_token: str + The JWT refresh token to extract the expiration from. + + Returns: + --------- + int or None: + The expiration time (exp) in seconds since the epoch, or None if extraction fails. + """ + # Method 1: PyJWT try: # Skipping keys since we're not verifying the signature decoded_refresh_token = jwt.decode( refresh_token, - options= - { + options={ "verify_aud": False, "verify_at_hash": False, - "verify_signature": False + "verify_signature": False, }, - algorithms=["RS256", "HS512"] + algorithms=["RS256", "HS512"], ) exp = decoded_refresh_token.get("exp") @@ -194,9 +232,9 @@ def extract_exp(self, refresh_token): # Method 3: Manual base64 decoding try: # Assuming the token is a JWT (header.payload.signature) - payload_encoded = refresh_token.split('.')[1] + payload_encoded = refresh_token.split(".")[1] # Add necessary padding for base64 decoding - payload_encoded += '=' * (4 - len(payload_encoded) % 4) + payload_encoded += "=" * (4 - len(payload_encoded) % 4) payload_decoded = base64.urlsafe_b64decode(payload_encoded) payload_json = json.loads(payload_decoded) exp = payload_json.get("exp") @@ -212,16 +250,16 @@ def extract_exp(self, refresh_token): def introspect_token(self, token): try: - introspect_endpoint = self.client.get_value_from_discovery_doc("introspection_endpoint", "") + introspect_endpoint = self.client.get_value_from_discovery_doc( + "introspection_endpoint", "" + ) # Headers and payload for the introspection request - headers = { - "Content-Type": "application/x-www-form-urlencoded" - } + headers = {"Content-Type": "application/x-www-form-urlencoded"} data = { "token": token, - "client_id": self.client.settings.get("client_id"), - "client_secret": self.client.settings.get("client_secret") + "client_id": self.client.client_id, + "client_secret": self.client.client_secret, } response = requests.post(introspect_endpoint, headers=headers, data=data) @@ -247,8 +285,12 @@ def post_login(self, user=None, token_result=None, **kwargs): client_id=flask.session.get("client_id"), ) - if self.check_groups: - self.client.update_user_authorization(user=user,pkey_cache=None,db_session=None,idp_name=self.idp_name) + # this attribute is only applicable to some OAuth clients + # (e.g., not all clients need read_authz_groups_from_tokens) + if self.is_read_authz_groups_from_tokens_enabled: + self.client.update_user_authorization( + user=user, pkey_cache=None, db_session=None, idp_name=self.idp_name + ) if token_result: username = token_result.get(self.username_field) diff --git a/fence/config-default.yaml b/fence/config-default.yaml index b1474fc1f..f3b62f237 100755 --- a/fence/config-default.yaml +++ b/fence/config-default.yaml @@ -94,7 +94,7 @@ DB_MIGRATION_POSTGRES_LOCK_KEY: 100 # - WARNING: Be careful changing the *_ALLOWED_SCOPES as you can break basic # and optional functionality # ////////////////////////////////////////////////////////////////////////////////////// -CHECK_GROUPS: false + OPENID_CONNECT: # any OIDC IDP that does not differ from the generic implementation can be # configured without code changes @@ -116,6 +116,21 @@ OPENID_CONNECT: multifactor_auth_claim_info: # optional, include if you're using arborist to enforce mfa on a per-file level claim: '' # claims field that indicates mfa, either the acr or acm claim. values: [ "" ] # possible values that indicate mfa was used. At least one value configured here is required to be in the token + # is_authz_groups_sync_enabled: A configuration flag that determines whether the application should + # verify and synchronize user group memberships between the identity provider (IdP) + # and the local authorization system (Arborist). When enabled, the system retrieves + # the user's group information from their token issued by the IdP and compares it against + # the groups defined in the local system. Based on the comparison, the user is added to + # or removed from relevant groups in the local system to ensure their group memberships + # remain up-to-date. If this flag is disabled, no group synchronization occurs + is_authz_groups_sync_enabled: true + authz_groups_sync: + # This defines the prefix used to identify authorization groups. + group_prefix: "some_prefix" + # This flag indicates whether the audience (aud) claim in the JWT should be verified during token validation. + verify_aud: true + # This specifies the expected audience (aud) value for the JWT, ensuring that the token is intended for use with the 'fence' service. + audience: fence # These Google values must be obtained from Google's Cloud Console # Follow: https://developers.google.com/identity/protocols/OpenIDConnect # diff --git a/fence/error_handler.py b/fence/error_handler.py index 446da60b4..6ac6f99dc 100644 --- a/fence/error_handler.py +++ b/fence/error_handler.py @@ -28,8 +28,10 @@ def get_error_response(error: Exception): ) ) - - #raise error + # TODO: Issue: Error messages are obfuscated, the line below needs be + # uncommented when troubleshooting errors. + # Breaks tests if not commented out / removed. We need a fix for this. + # raise error # don't include internal details in the public error message # to do this, only include error messages for known http status codes diff --git a/fence/job/access_token_updater.py b/fence/job/access_token_updater.py index 28456803d..7181f4075 100644 --- a/fence/job/access_token_updater.py +++ b/fence/job/access_token_updater.py @@ -45,7 +45,7 @@ def __init__( self.visa_types = config.get("USERSYNC", {}).get("visa_types", {}) - #introduce list on self which contains all clients that need update + # introduce list on self which contains all clients that need update self.oidc_clients_requiring_token_refresh = [] # keep this as a special case, because RAS will not set group information configuration. @@ -54,7 +54,6 @@ def __init__( if "ras" not in oidc: self.logger.error("RAS client not configured") else: - #instead of setting self.ras_client add the RASClient to self.oidc_clients_requiring_token_refresh ras_client = RASClient( oidc["ras"], HTTP_PROXY=config.get("HTTP_PROXY"), @@ -62,20 +61,17 @@ def __init__( ) self.oidc_clients_requiring_token_refresh.append(ras_client) - #initialise a client for each OIDC client in oidc, which does has group information set to true and add them + # Initialise a client for each OIDC client in oidc, which does has gis_authz_groups_sync_enabled set to true and add them # to oidc_clients_requiring_token_refresh - if config["CHECK_GROUPS"]: - for oidc_name in oidc: - if "groups" in oidc.get(oidc_name): - groups = oidc.get(oidc_name).get("groups") - if groups.get("read_group_information", False): - oidc_client = OIDCClient( - settings=oidc[oidc_name], - HTTP_PROXY=config.get("HTTP_PROXY"), - logger=logger, - idp=oidc_name - ) - self.oidc_clients_requiring_token_refresh.append(oidc_client) + for oidc_name in oidc: + if oidc.get(oidc_name).get("is_authz_groups_sync_enabled", False): + oidc_client = OIDCClient( + settings=oidc[oidc_name], + HTTP_PROXY=config.get("HTTP_PROXY"), + logger=logger, + idp=oidc_name, + ) + self.oidc_clients_requiring_token_refresh.append(oidc_client) async def update_tokens(self, db_session): """ @@ -89,7 +85,7 @@ async def update_tokens(self, db_session): """ start_time = time.time() - #Change this line to reflect we are refreshing tokens, not just visas + # Change this line to reflect we are refreshing tokens, not just visas self.logger.info("Initializing Visa Update and Token refreshing Cronjob . . .") self.logger.info("Total concurrency size: {}".format(self.concurrency)) self.logger.info("Total thread pool size: {}".format(self.thread_pool_size)) @@ -181,13 +177,13 @@ async def updater(self, name, updater_queue, db_session): pkey_cache=self.pkey_cache, db_session=db_session, ) + else: self.logger.debug( f"Updater {name} NOT updating authorization for " f"user {user.username} because no client was found for IdP: {user.identity_provider}" ) - # Only mark the task as done if processing succeeded updater_queue.task_done() except Exception as exc: @@ -195,19 +191,20 @@ async def updater(self, name, updater_queue, db_session): f"Updater {name} could not update authorization " f"for {user.username if user else 'unknown user'}. Error: {exc}. Continuing." ) - # Still mark the task as done even if there was an exception + # Ensure task is marked done if exception occurs updater_queue.task_done() def _pick_client(self, user): """ Select OIDC client based on identity provider. """ - # change this logic to return any client which is in self.oidc_clients_requiring_token_refresh (check against "name") self.logger.info(f"Selecting client for user {user.username}") client = None for oidc_client in self.oidc_clients_requiring_token_refresh: if getattr(user.identity_provider, "name") == oidc_client.idp: - self.logger.info(f"Picked client: {oidc_client.idp} for user {user.username}") + self.logger.info( + f"Picked client: {oidc_client.idp} for user {user.username}" + ) client = oidc_client break if not client: diff --git a/fence/resources/openid/idp_oauth2.py b/fence/resources/openid/idp_oauth2.py index 5f0f3ee41..ce92d2d77 100644 --- a/fence/resources/openid/idp_oauth2.py +++ b/fence/resources/openid/idp_oauth2.py @@ -21,7 +21,13 @@ class Oauth2ClientBase(object): """ def __init__( - self, settings, logger, idp, arborist=None, scope=None, discovery_url=None, HTTP_PROXY=None + self, + settings, + logger, + idp, + scope=None, + discovery_url=None, + HTTP_PROXY=None, ): self.logger = logger self.settings = settings @@ -40,38 +46,26 @@ def __init__( ) self.idp = idp # display name for use in logs and error messages self.HTTP_PROXY = HTTP_PROXY - self.check_groups = config.get("CHECK_GROUPS", False) - self.groups = self.settings.get("groups", None) - self.read_group_information = False self.groups_from_idp = [] self.verify_aud = self.settings.get("verify_aud", False) self.audience = self.settings.get("audience", self.settings.get("client_id")) - self.is_mfa_enabled = "multifactor_auth_claim_info" in self.settings + self.client_id = self.settings.get("client_id", "") + self.client_secret = self.settings.get("client_secret", "") self.arborist = ArboristClient( arborist_base_url=config["ARBORIST"], logger=logger, ) - if not self.discovery_url and not settings.get("discovery"): self.logger.warning( f"OAuth2 Client for {self.idp} does not have a valid 'discovery_url'. " f"Some calls for this client may fail if they rely on the OIDC Discovery page. Use 'discovery' to configure clients without a discovery page." ) - # implent boolean setting read from settings here. read_group_information - # if set to yes, then the following needs to happen: - # 1. in the discovery_doc, response_types_supported needs to contain "code" // this seems to be assumed in the implementation - # 2. the discovery_doc (if it provides "claims_supported", then "claims_supported" needs to contain "groups" - # 2.1 groups claim is not standard in claims_supported, i.e. does not exists in keycloak and configurable. - # - # Implement a string setting "group_prefix", this is used to have namespaced groups in case of multi system OIDC - # - # implement a string setting "audience" here, implement a boolean "check_audience" here. - # if the audience is not set, but check_audience is spit out an ERROR that the audience is not set. - if self.groups: - self.read_group_information = self.groups.get("read_group_information", False) + self.read_authz_groups_from_tokens = self.settings.get( + "is_authz_groups_sync_enabled", False + ) @cached_property def discovery_doc(self): @@ -91,7 +85,6 @@ def get_token(self, token_endpoint, code): url=token_endpoint, code=code, proxies=self.get_proxies() ) - def get_jwt_keys(self, jwks_uri): """ Get jwt keys from provider's api @@ -108,23 +101,45 @@ def get_jwt_keys(self, jwks_uri): return None return resp.json()["keys"] - def decode_token(self, token_id, keys): + def decode_token_with_aud(self, token_id, keys): + """ + Decode a given JWT (JSON Web Token) using the provided keys and validate the audience, if enabled. + The subclass can override audience validation if necessary. + + Parameters: + - token_id (str): The JWT token to decode. + - keys (list): The set of keys used for decoding the token, typically retrieved from the IdP (Identity Provider). + + Returns: + - dict: The decoded token containing claims (such as user identity, groups, etc.) if the token is successfully validated. + + Raises: + - JWTClaimsError: If the token's claims (such as audience) do not match the expected values. + - JWTError: If there is a problem with the JWT token structure or verification. + + Notes: + - This function verifies the audience (`aud`) claim if `verify_aud` is set. + - The function expects the token to be signed using the RS256 algorithm. + """ try: - decoded_token = jwt.decode( + decoded_token = jwt.decode( token_id, keys, options={"verify_aud": self.verify_aud, "verify_at_hash": False}, algorithms=["RS256"], - audience=self.audience + audience=self.audience, + ) + self.logger.info( + f"Token decoded successfully for audience: {self.audience}" ) - return decoded_token except JWTClaimsError as e: self.logger.error(f"Claim error: {e}") - raise JWTClaimsError("Invalid audience") + raise JWTClaimsError(f"Invalid audience: {e}") except JWTError as e: - self.logger.error(e) + self.logger.error(f"JWT error: {e}") + raise JWTError(f"JWT error occurred: {e}") def get_jwt_claims_identity(self, token_endpoint, jwks_endpoint, code): """ @@ -139,8 +154,7 @@ def get_jwt_claims_identity(self, token_endpoint, jwks_endpoint, code): # validate audience and hash. also ensure that the algorithm is correctly derived from the token. # hash verification has not been implemented yet - return self.decode_token(token["id_token"], keys), refresh_token - + return self.decode_token_with_aud(token["id_token"], keys), refresh_token def get_value_from_discovery_doc(self, key, default_value): """ @@ -221,14 +235,29 @@ def get_auth_info(self, code): try: token_endpoint = self.get_value_from_discovery_doc("token_endpoint", "") jwks_endpoint = self.get_value_from_discovery_doc("jwks_uri", "") - claims, refresh_token = self.get_jwt_claims_identity(token_endpoint, jwks_endpoint, code) + claims, refresh_token = self.get_jwt_claims_identity( + token_endpoint, jwks_endpoint, code + ) groups = None group_prefix = None - if self.read_group_information: - groups = claims.get("groups") - group_prefix = self.settings.get("groups").get("group_prefix") + if self.read_authz_groups_from_tokens: + try: + groups = claims.get("groups") + group_prefix = self.settings.get("authz_groups_sync", {}).get( + "group_prefix", "" + ) + except (AttributeError, TypeError) as e: + self.logger( + f"Error: is_authz_groups_sync_enabled is enabled, required values not configured: {e}" + ) + raise Exception(e) + except KeyError as e: + self.logger( + f"Error: is_authz_groups_sync_enabled is enabled, however groups not found in claims: {e}" + ) + raise Exception(e) if claims.get(user_id_field): if user_id_field == "email" and not claims.get("email_verified"): @@ -240,7 +269,7 @@ def get_auth_info(self, code): "iat": claims.get("iat"), "exp": claims.get("exp"), "groups": groups, - "group_prefix": group_prefix + "group_prefix": group_prefix, } else: self.logger.exception( @@ -252,7 +281,7 @@ def get_auth_info(self, code): self.logger.exception(f"Can't get user info from {self.idp}: {e}") return {"error": f"Can't get user info from {self.idp}"} - def get_access_token(self, user, token_endpoint, db_session=None): + def get_access_token(self, user, token_endpoint, db_session=None): """ Get access_token using a refresh_token and store new refresh in upstream_refresh_token table. """ @@ -345,7 +374,43 @@ def store_refresh_token(self, user, refresh_token, expires, db_session=None): @backoff.on_exception(backoff.expo, Exception, **DEFAULT_BACKOFF_SETTINGS) def update_user_authorization(self, user, pkey_cache, db_session=None, **kwargs): + """ + Update the user's authorization by refreshing their access token and synchronizing + their group memberships with Arborist. + + This method refreshes the user's access token using an identity provider (IdP), + retrieves and decodes the token, and optionally synchronizes the user's group + memberships between the IdP and Arborist if the `groups` configuration is enabled. + + Args: + user (User): The user object, which contains details like username and identity provider. + pkey_cache (dict): A cache of public keys used for verifying JWT signatures. + db_session (SQLAlchemy Session, optional): A database session object. If not provided, + it defaults to the scoped session of the current application context. + **kwargs: Additional keyword arguments. + + Raises: + Exception: If there is an issue with retrieving the access token, decoding the token, + or synchronizing the user's groups. + + Workflow: + 1. Retrieves the token endpoint and JWKS URI from the identity provider's discovery document. + 2. Uses the user's refresh token to get a new access token and persists it in the database. + 3. Decodes the ID token using the JWKS (JSON Web Key Set) retrieved from the IdP. + 4. If group synchronization is enabled: + a. Retrieves the list of groups from Arborist. + b. Retrieves the user's groups from the IdP. + c. Adds the user to groups in Arborist that match the groups from the IdP. + d. Removes the user from groups in Arborist that they are no longer part of in the IdP. + + Logging: + - Logs the group membership synchronization activities (adding/removing users from groups). + - Logs any issues encountered while refreshing the token or during group synchronization. + + Warnings: + - If groups are not received from the IdP but group synchronization is enabled, logs a warning. + """ db_session = db_session or current_app.scoped_session() expires_at = None @@ -358,52 +423,60 @@ def update_user_authorization(self, user, pkey_cache, db_session=None, **kwargs) jwks_endpoint = self.get_value_from_discovery_doc("jwks_uri", "") keys = self.get_jwt_keys(jwks_endpoint) expires_at = token["expires_at"] - decoded_token_id = self.decode_token(token_id=token["id_token"], keys=keys) + decoded_token_id = self.decode_token_with_aud( + token_id=token["id_token"], keys=keys + ) except Exception as e: err_msg = "Could not refresh token" self.logger.exception("{}: {}".format(err_msg, e)) raise - if self.groups: - if self.read_group_information: - group_prefix = self.groups.get("group_prefix") + if self.read_authz_groups_from_tokens: + group_prefix = self.settings.get("authz_groups_sync", {}).get( + "group_prefix", "" + ) - # grab all groups defined in arborist - arborist_groups = self.arborist.list_groups().get("groups") + # grab all groups defined in arborist + arborist_groups = self.arborist.list_groups().get("groups") - # grab all groups defined in idp - groups_from_idp = decoded_token_id.get("groups") + # grab all groups defined in idp + groups_from_idp = decoded_token_id.get("groups") - exp = datetime.datetime.fromtimestamp( - expires_at, - tz=datetime.timezone.utc - ) + exp = datetime.datetime.fromtimestamp(expires_at, tz=datetime.timezone.utc) - # if group name is in the list from arborist: - if groups_from_idp: - groups_from_idp = [group.removeprefix(group_prefix).lstrip('/') for group in groups_from_idp] + # if group name is in the list from arborist: + if groups_from_idp: + groups_from_idp = [ + group.removeprefix(group_prefix).lstrip("/") + for group in groups_from_idp + ] - idp_group_names = set(groups_from_idp) + idp_group_names = set(groups_from_idp) - # Add user to all matching groups from IDP - for arborist_group in arborist_groups: - if arborist_group['name'] in idp_group_names: - self.logger.info(f"Adding {user.username} to group: {arborist_group['name']}") - self.arborist.add_user_to_group( + # Add user to all matching groups from IDP + for arborist_group in arborist_groups: + if arborist_group["name"] in idp_group_names: + self.logger.info( + f"Adding {user.username} to group: {arborist_group['name']}" + ) + self.arborist.add_user_to_group( + username=user.username, + group_name=arborist_group["name"], + expires_at=exp, + ) + + # Remove user from groups in Arborist that they are not part of in IDP + for arborist_group in arborist_groups: + if arborist_group["name"] not in idp_group_names: + if user.username in arborist_group.get("users", []): + self.logger.info( + f"Removing {user.username} from group: {arborist_group['name']}" + ) + self.arborist.remove_user_from_group( username=user.username, - group_name=arborist_group['name'], - expires_at=exp + group_name=arborist_group["name"], ) - - # Remove user from groups in Arborist that they are not part of in IDP - for arborist_group in arborist_groups: - if arborist_group['name'] not in idp_group_names: - if user.username in arborist_group.get("users", []): - self.logger.info(f"Removing {user.username} from group: {arborist_group['name']}") - self.arborist.remove_user_from_group( - username=user.username, - group_name=arborist_group['name'] - ) - else: - self.logger.warning( - f"Check-groups feature is enabled, however did receive groups from idp for user: {user.username}") \ No newline at end of file + else: + self.logger.warning( + f"Check-groups feature is enabled, however did receive groups from idp for user: {user.username}" + ) diff --git a/tests/conftest.py b/tests/conftest.py index c7e6fef3b..191371a6c 100755 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -479,6 +479,28 @@ def app(kid, rsa_private_key, rsa_public_key): mocker.unmock_functions() +@pytest.fixture +def mock_app(): + return MagicMock() + +@pytest.fixture +def mock_user(): + return MagicMock() + +@pytest.fixture +def mock_db_session(): + """Mock the database session.""" + db_session = MagicMock() + return db_session + +@pytest.fixture +def expired_mock_user(): + """Mock a user object with upstream refresh tokens.""" + user = MagicMock() + user.upstream_refresh_tokens = [ + MagicMock(refresh_token="expired_token", expires=0), # Expired token + ] + return user @pytest.fixture(scope="function") def auth_client(request): diff --git a/tests/job/test_access_token_updater.py b/tests/job/test_access_token_updater.py index 58f2be42c..87d955617 100644 --- a/tests/job/test_access_token_updater.py +++ b/tests/job/test_access_token_updater.py @@ -6,7 +6,8 @@ from fence.resources.openid.ras_oauth2 import RASOauth2Client as RASClient from fence.job.access_token_updater import AccessTokenUpdater -@pytest.fixture(scope='session', autouse=True) + +@pytest.fixture(scope="session", autouse=True) def event_loop(): loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) @@ -17,8 +18,10 @@ def event_loop(): @pytest.fixture def run_async(event_loop): """Run an async coroutine in the current event loop.""" + def _run(coro): return event_loop.run_until_complete(coro) + return _run @@ -57,19 +60,30 @@ def mock_oidc_clients(): @pytest.fixture def access_token_updater_config(mock_oidc_clients): """Fixture to instantiate AccessTokenUpdater with mocked OIDC clients.""" - with patch("fence.config", - {"OPENID_CONNECT": {"ras": {}, "test_oidc": {"groups": {"read_group_information": True}}}, - "CHECK_GROUPS": True}): + with patch( + "fence.config", + { + "OPENID_CONNECT": { + "ras": {}, + "test_oidc": {"groups": {"read_authz_groups_from_tokens": True}}, + }, + "ENABLE_AUTHZ_GROUPS_FROM_OIDC": True, + }, + ): updater = AccessTokenUpdater() updater.oidc_clients_requiring_token_refresh = mock_oidc_clients return updater -def test_get_user_from_db(run_async, access_token_updater_config, mock_db_session, mock_users): +def test_get_user_from_db( + run_async, access_token_updater_config, mock_db_session, mock_users +): """Test the get_user_from_db method.""" mock_db_session.query().slice().all.return_value = mock_users - users = run_async(access_token_updater_config.get_user_from_db(mock_db_session, chunk_idx=0)) + users = run_async( + access_token_updater_config.get_user_from_db(mock_db_session, chunk_idx=0) + ) assert len(users) == 2 assert users[0].username == "testuser1" assert users[1].username == "testuser2" @@ -110,7 +124,14 @@ def test_worker(run_async, access_token_updater_config, mock_users): async def updater_with_timeout(updater, queue, db_session, timeout=5): return await asyncio.wait_for(updater(queue, db_session), timeout) -def test_updater(run_async, access_token_updater_config, mock_users, mock_db_session, mock_oidc_clients): + +def test_updater( + run_async, + access_token_updater_config, + mock_users, + mock_db_session, + mock_oidc_clients, +): """Test the updater method.""" updater_queue = asyncio.Queue() @@ -121,12 +142,18 @@ def test_updater(run_async, access_token_updater_config, mock_users, mock_db_ses mock_oidc_clients[0].update_user_authorization = AsyncMock() # Ensure _pick_client returns the correct client - with patch.object(access_token_updater_config, '_pick_client', return_value=mock_oidc_clients[0]): + with patch.object( + access_token_updater_config, "_pick_client", return_value=mock_oidc_clients[0] + ): # Signal the updater to stop after processing run_async(updater_queue.put(None)) # This should be an awaited call # Run the updater to process the user and update authorization - run_async(access_token_updater_config.updater("updater_1", updater_queue, mock_db_session)) + run_async( + access_token_updater_config.updater( + "updater_1", updater_queue, mock_db_session + ) + ) # Verify that the OIDC client was called with the correct user mock_oidc_clients[0].update_user_authorization.assert_called_once_with( @@ -135,6 +162,7 @@ def test_updater(run_async, access_token_updater_config, mock_users, mock_db_ses db_session=mock_db_session, ) + def test_no_client_found(run_async, access_token_updater_config, mock_users): """Test that updater does not crash if no client is found.""" updater_queue = asyncio.Queue() @@ -146,14 +174,18 @@ def test_no_client_found(run_async, access_token_updater_config, mock_users): run_async(updater_queue.put(None)) # Signal the updater to terminate # Mock the client selection to return None - with patch.object(access_token_updater_config, '_pick_client', return_value=None): + with patch.object(access_token_updater_config, "_pick_client", return_value=None): # Run the updater and ensure it skips the user with no client - run_async(access_token_updater_config.updater("updater_1", updater_queue, MagicMock())) + run_async( + access_token_updater_config.updater("updater_1", updater_queue, MagicMock()) + ) assert updater_queue.empty() # The user should still be dequeued -def test_pick_client(run_async, access_token_updater_config, mock_users, mock_oidc_clients): +def test_pick_client( + run_async, access_token_updater_config, mock_users, mock_oidc_clients +): """Test that the correct OIDC client is selected based on the user's IDP.""" # Pick the client for a RAS user client = access_token_updater_config._pick_client(mock_users[0]) diff --git a/tests/login/test_base.py b/tests/login/test_base.py index 09352945f..bf541f64a 100644 --- a/tests/login/test_base.py +++ b/tests/login/test_base.py @@ -7,9 +7,6 @@ from datetime import datetime, timedelta import time -@pytest.fixture(autouse=True) -def mock_arborist(mock_arborist_requests): - mock_arborist_requests() @patch("fence.blueprints.login.base.prepare_login_log") def test_post_login_set_mfa(app, monkeypatch, mock_authn_user_flask_context): diff --git a/tests/login/test_idp_oauth2.py b/tests/login/test_idp_oauth2.py index e37b1a44e..aaecd3755 100644 --- a/tests/login/test_idp_oauth2.py +++ b/tests/login/test_idp_oauth2.py @@ -45,6 +45,7 @@ def test_has_mfa_claim_acr(oauth_client_acr): has_mfa = oauth_client_acr.has_mfa_claim({"acr": "mfa"}) assert has_mfa + def test_has_mfa_claim_multiple_acr(oauth_client_acr): has_mfa = oauth_client_acr.has_mfa_claim({"acr": "mfa otp duo"}) assert has_mfa @@ -85,13 +86,6 @@ def test_does_not_has_mfa_claim_multiple_amr(oauth_client_amr): has_mfa = oauth_client_amr.has_mfa_claim({"amr": ["pwd, trustme"]}) assert not has_mfa -@pytest.fixture -def mock_app(): - return MagicMock() - -@pytest.fixture -def mock_user(): - return MagicMock() # To test the store_refresh_token method of the Oauth2ClientBase class def test_store_refresh_token(mock_user, mock_app): @@ -105,24 +99,30 @@ def test_store_refresh_token(mock_user, mock_app): "client_secret": "test_client_secret", "redirect_url": "http://localhost/callback", "discovery_url": "http://localhost/.well-known/openid-configuration", - "groups": {"read_group_information": True, "group_prefix": "/"}, + "groups": {"read_authz_groups_from_tokens": True, "group_prefix": "/"}, "user_id_field": "sub", } # Ensure oauth_client is correctly instantiated - oauth_client = Oauth2ClientBase(settings=mock_settings, logger=mock_logger, idp="test_idp") + oauth_client = Oauth2ClientBase( + settings=mock_settings, logger=mock_logger, idp="test_idp" + ) refresh_token = "mock_refresh_token" expires = 1700000000 # Patch the UpstreamRefreshToken to prevent actual database interactions - with patch('fence.resources.openid.idp_oauth2.UpstreamRefreshToken', autospec=True) as MockUpstreamRefreshToken: + with patch( + "fence.resources.openid.idp_oauth2.UpstreamRefreshToken", autospec=True + ) as MockUpstreamRefreshToken: # Mock the db_session's object_session method to return a mocked session object mock_session = MagicMock() mock_app.arborist.object_session.return_value = mock_session # Call the method to test - oauth_client.store_refresh_token(mock_user, refresh_token, expires, db_session=mock_app.arborist) + oauth_client.store_refresh_token( + mock_user, refresh_token, expires, db_session=mock_app.arborist + ) # Check if UpstreamRefreshToken was instantiated correctly MockUpstreamRefreshToken.assert_called_once_with( @@ -136,12 +136,20 @@ def test_store_refresh_token(mock_user, mock_app): mock_session.add.assert_called_once_with(MockUpstreamRefreshToken.return_value) mock_app.arborist.commit.assert_called_once() + # To test if a user is granted access using the get_auth_info method in the Oauth2ClientBase -@patch('fence.resources.openid.idp_oauth2.Oauth2ClientBase.get_jwt_keys') -@patch('fence.resources.openid.idp_oauth2.jwt.decode') -@patch('authlib.integrations.requests_client.OAuth2Session.fetch_token') -@patch('fence.resources.openid.idp_oauth2.Oauth2ClientBase.get_value_from_discovery_doc') -def test_get_auth_info_granted_access(mock_get_value_from_discovery_doc, mock_fetch_token, mock_jwt_decode, mock_get_jwt_keys): +@patch("fence.resources.openid.idp_oauth2.Oauth2ClientBase.get_jwt_keys") +@patch("fence.resources.openid.idp_oauth2.jwt.decode") +@patch("authlib.integrations.requests_client.OAuth2Session.fetch_token") +@patch( + "fence.resources.openid.idp_oauth2.Oauth2ClientBase.get_value_from_discovery_doc" +) +def test_get_auth_info_granted_access( + mock_get_value_from_discovery_doc, + mock_fetch_token, + mock_jwt_decode, + mock_get_jwt_keys, +): """ Test that the `get_auth_info` method correctly retrieves, processes, and decodes an OAuth2 authentication token, including access, refresh, and ID tokens, while also @@ -155,35 +163,33 @@ def test_get_auth_info_granted_access(mock_get_value_from_discovery_doc, mock_fe "client_secret": "test_client_secret", "redirect_url": "http://localhost/callback", "discovery_url": "http://localhost/.well-known/openid-configuration", - "groups": {"read_group_information": True, "group_prefix": "/"}, + "is_authz_groups_sync_enabled": True, + "authz_groups_sync:": {"group_prefix": "/"}, "user_id_field": "sub", } # Mock logger mock_logger = MagicMock() - oauth2_client = Oauth2ClientBase(settings=mock_settings, logger=mock_logger, idp="test_idp") + oauth2_client = Oauth2ClientBase( + settings=mock_settings, logger=mock_logger, idp="test_idp" + ) # Directly mock the return values for token_endpoint and jwks_uri - mock_get_value_from_discovery_doc.side_effect = lambda key, default=None: \ + mock_get_value_from_discovery_doc.side_effect = lambda key, default=None: ( "http://localhost/token" if key == "token_endpoint" else "http://localhost/jwks" + ) # Setup mock response for fetch_token mock_fetch_token.return_value = { "access_token": "mock_access_token", "id_token": "mock_id_token", - "refresh_token": "mock_refresh_token" + "refresh_token": "mock_refresh_token", } # Setup mock JWT keys response mock_get_jwt_keys.return_value = [ - { - "kty": "RSA", - "kid": "1e9gdk7", - "use": "sig", - "n": "example-key", - "e": "AQAB" - } + {"kty": "RSA", "kid": "1e9gdk7", "use": "sig", "n": "example-key", "e": "AQAB"} ] # Setup mock decoded JWT token @@ -192,17 +198,17 @@ def test_get_auth_info_granted_access(mock_get_value_from_discovery_doc, mock_fe "email_verified": True, "iat": 1609459200, "exp": 1609462800, - "groups": ["group1", "group2"] + "groups": ["group1", "group2"], } - # Log mock setups - print(f"Mock token endpoint: {mock_get_value_from_discovery_doc('token_endpoint', '')}") + print( + f"Mock token endpoint: {mock_get_value_from_discovery_doc('token_endpoint', '')}" + ) print(f"Mock jwks_uri: {mock_get_value_from_discovery_doc('jwks_uri', '')}") print(f"Mock fetch_token response: {mock_fetch_token.return_value}") print(f"Mock JWT decode response: {mock_jwt_decode.return_value}") - # Call the method code = "mock_code" auth_info = oauth2_client.get_auth_info(code) @@ -224,21 +230,6 @@ def test_get_auth_info_granted_access(mock_get_value_from_discovery_doc, mock_fe assert auth_info["groups"] == ["group1", "group2"] -@pytest.fixture -def mock_db_session(): - """Mock the database session.""" - db_session = MagicMock() - return db_session - -@pytest.fixture -def expired_mock_user(): - """Mock a user object with upstream refresh tokens.""" - user = MagicMock() - user.upstream_refresh_tokens = [ - MagicMock(refresh_token="expired_token", expires=0), # Expired token - ] - return user - def test_get_access_token_expired(expired_mock_user, mock_db_session): """ Test that attempting to retrieve an access token for a user with an expired refresh token @@ -253,18 +244,24 @@ def test_get_access_token_expired(expired_mock_user, mock_db_session): "client_secret": "test_client_secret", "redirect_url": "http://localhost/callback", "discovery_url": "http://localhost/.well-known/openid-configuration", - "groups": {"read_group_information": True, "group_prefix": "/"}, + "is_authz_groups_sync_enabled": True, + "authz_groups_sync:": {"group_prefix": "/"}, "user_id_field": "sub", } # Initialize the Oauth2 client object - oauth2_client = Oauth2ClientBase(settings=mock_settings, logger=MagicMock(), idp="test_idp") - + oauth2_client = Oauth2ClientBase( + settings=mock_settings, logger=MagicMock(), idp="test_idp" + ) - #Simulate the token expiration and user not having access + # Simulate the token expiration and user not having access with pytest.raises(AuthError) as excinfo: print("get_access_token about to be called") - oauth2_client.get_access_token(expired_mock_user, token_endpoint="https://token.endpoint", db_session=mock_db_session) + oauth2_client.get_access_token( + expired_mock_user, + token_endpoint="https://token.endpoint", + db_session=mock_db_session, + ) print(f"Raised exception message: {excinfo.value}") @@ -274,7 +271,7 @@ def test_get_access_token_expired(expired_mock_user, mock_db_session): mock_db_session.commit.assert_called() -@patch('fence.resources.openid.idp_oauth2.Oauth2ClientBase.get_auth_info') +@patch("fence.resources.openid.idp_oauth2.Oauth2ClientBase.get_auth_info") def test_post_login_with_group_prefix(mock_get_auth_info, app): """ Test the `post_login` method of the `DefaultOAuth2Callback` class, ensuring that user groups @@ -283,7 +280,7 @@ def test_post_login_with_group_prefix(mock_get_auth_info, app): """ with app.app_context(): yield - with patch.dict(config, {"CHECK_GROUPS": True}, clear=False): + with patch.dict(config, {"ENABLE_AUTHZ_GROUPS_FROM_OIDC": True}, clear=False): mock_user = MagicMock() mock_user.username = "test_user" mock_user.id = "user_id" @@ -292,15 +289,9 @@ def test_post_login_with_group_prefix(mock_get_auth_info, app): # Set up mock responses for user info and groups from the IdP mock_get_auth_info.return_value = { "username": "test_user", - "groups": [ - "group1", - "group2", - "covid/group3", - "group4", - "group5" - ], + "groups": ["group1", "group2", "covid/group3", "group4", "group5"], "exp": datetime.datetime.now(tz=datetime.timezone.utc).timestamp(), - "group_prefix": "covid/" + "group_prefix": "covid/", } # Mock the Arborist client and its methods @@ -310,7 +301,7 @@ def test_post_login_with_group_prefix(mock_get_auth_info, app): {"name": "group1"}, {"name": "group2"}, {"name": "group3"}, - {"name": "reviewers"} + {"name": "reviewers"}, ] } mock_arborist.add_user_to_group = MagicMock() @@ -322,9 +313,7 @@ def test_post_login_with_group_prefix(mock_get_auth_info, app): # Create the callback object with the mock app callback = DefaultOAuth2Callback( - idp_name="generic3", - client=MagicMock(), - app=app + idp_name="generic3", client=MagicMock(), app=app ) # Mock user and call post_login @@ -338,35 +327,42 @@ def test_post_login_with_group_prefix(mock_get_auth_info, app): groups_from_idp=mock_get_auth_info.return_value["groups"], group_prefix=mock_get_auth_info.return_value["group_prefix"], expires_at=mock_get_auth_info.return_value["exp"], - username=mock_user.username + username=mock_user.username, ) # Assertions to check if groups were processed with the correct prefix mock_arborist.add_user_to_group.assert_any_call( - username='test_user', - group_name='group1', - expires_at=datetime.datetime.fromtimestamp(mock_get_auth_info.return_value["exp"], tz=datetime.timezone.utc) + username="test_user", + group_name="group1", + expires_at=datetime.datetime.fromtimestamp( + mock_get_auth_info.return_value["exp"], tz=datetime.timezone.utc + ), ) mock_arborist.add_user_to_group.assert_any_call( - username='test_user', - group_name='group2', - expires_at=datetime.datetime.fromtimestamp(mock_get_auth_info.return_value["exp"], tz=datetime.timezone.utc) + username="test_user", + group_name="group2", + expires_at=datetime.datetime.fromtimestamp( + mock_get_auth_info.return_value["exp"], tz=datetime.timezone.utc + ), ) mock_arborist.add_user_to_group.assert_any_call( - username='test_user', - group_name='group3', - expires_at=datetime.datetime.fromtimestamp(mock_get_auth_info.return_value["exp"], tz=datetime.timezone.utc) + username="test_user", + group_name="group3", + expires_at=datetime.datetime.fromtimestamp( + mock_get_auth_info.return_value["exp"], tz=datetime.timezone.utc + ), ) # Ensure the mock was called exactly three times (once for each group that was added) assert mock_arborist.add_user_to_group.call_count == 3 - -@patch('fence.resources.openid.idp_oauth2.Oauth2ClientBase.get_jwt_keys') -@patch('authlib.integrations.requests_client.OAuth2Session.fetch_token') -@patch('fence.resources.openid.idp_oauth2.jwt.decode') # Mock jwt.decode -def test_jwt_audience_verification_fails(mock_jwt_decode, mock_fetch_token, mock_get_jwt_keys): +@patch("fence.resources.openid.idp_oauth2.Oauth2ClientBase.get_jwt_keys") +@patch("authlib.integrations.requests_client.OAuth2Session.fetch_token") +@patch("fence.resources.openid.idp_oauth2.jwt.decode") # Mock jwt.decode +def test_jwt_audience_verification_fails( + mock_jwt_decode, mock_fetch_token, mock_get_jwt_keys +): """ Test the JWT audience verification failure scenario. @@ -383,7 +379,7 @@ def test_jwt_audience_verification_fails(mock_jwt_decode, mock_fetch_token, mock mock_fetch_token.return_value = { "id_token": "mock-id-token", "access_token": "mock_access_token", - "refresh_token": "mock-refresh-token" + "refresh_token": "mock-refresh-token", } # Mock JWKS response @@ -394,7 +390,7 @@ def test_jwt_audience_verification_fails(mock_jwt_decode, mock_fetch_token, mock "kid": "test-key-id", "use": "sig", "n": "mock-n-value", # Simulate RSA public key values - "e": "mock-e-value" + "e": "mock-e-value", } ] } @@ -413,28 +409,30 @@ def test_jwt_audience_verification_fails(mock_jwt_decode, mock_fetch_token, mock "redirect_url": "mock-redirect-url", "discovery_url": "http://localhost/discovery", "audience": "expected-audience", - "verify_aud": True + "verify_aud": True, }, logger=MagicMock(), - idp="mock-idp" + idp="mock-idp", ) # Invoke the method and expect JWTClaimsError to be raised with pytest.raises(JWTClaimsError, match="Invalid audience"): - client.get_jwt_claims_identity(token_endpoint="https://token.endpoint", jwks_endpoint="https://jwks.uri", code="auth_code") + client.get_jwt_claims_identity( + token_endpoint="https://token.endpoint", + jwks_endpoint="https://jwks.uri", + code="auth_code", + ) # Verify fetch_token was called correctly mock_fetch_token.assert_called_once_with( - url="https://token.endpoint", - code="auth_code", - proxies=None + url="https://token.endpoint", code="auth_code", proxies=None ) - #Verify jwt.decode was called with the mock id_token and the mocked JWKS keys + # Verify jwt.decode was called with the mock id_token and the mocked JWKS keys mock_jwt_decode.assert_called_with( "mock-id-token", # The mock token - mock_jwks_response, # The mocked keys + mock_jwks_response, # The mocked keys options={"verify_aud": True, "verify_at_hash": False}, algorithms=["RS256"], - audience="expected-audience" - ) \ No newline at end of file + audience="expected-audience", + ) diff --git a/tests/test-fence-config.yaml b/tests/test-fence-config.yaml index 3ab52a19f..8b3064988 100755 --- a/tests/test-fence-config.yaml +++ b/tests/test-fence-config.yaml @@ -69,7 +69,6 @@ SESSION_COOKIE_SECURE: true ENABLE_CSRF_PROTECTION: false -CHECK_GROUPS: false # ////////////////////////////////////////////////////////////////////////////////////// # OPEN ID CONNECT (OIDC) # - Fully configure at least one client so login works @@ -150,9 +149,21 @@ OPENID_CONNECT: # use `discovery` to configure IDPs that do not expose a discovery # endpoint. One of `discovery_url` or `discovery` should be configured discovery_url: 'http://localhost/realms/generic3/.well-known/openid-configuration' - groups: - read_group_information: true + # is_authz_groups_sync_enabled: A configuration flag that determines whether the application should + # verify and synchronize user group memberships between the identity provider (IdP) + # and the local authorization system (Arborist). When enabled, the system retrieves + # the user's group information from their token issued by the IdP and compares it against + # the groups defined in the local system. Based on the comparison, the user is added to + # or removed from relevant groups in the local system to ensure their group memberships + # remain up-to-date. If this flag is disabled, no group synchronization occurs + is_authz_groups_sync_enabled: true + authz_groups_sync: + # This defines the prefix used to identify authorization groups. group_prefix: /covid + # This flag indicates whether the audience (aud) claim in the JWT should be verified during token validation. + verify_aud: true + # This specifies the expected audience (aud) value for the JWT, ensuring that the token is intended for use with the 'fence' service. + audience: fence # these are the *possible* scopes a client can be given, NOT scopes that are # given to all clients. You can be more restrictive during client creation