From 9855f6182f1dc11222a9af11ccf875ccbe9c4dce Mon Sep 17 00:00:00 2001 From: Petr Jeske Date: Mon, 19 Feb 2024 10:41:28 +0100 Subject: [PATCH] LX-77 Add support for JIT provisioning pt 2 Adding JIT user update --- .../main/kotlin/AuthenticationStoreClient.kt | 13 ++- .../src/main/kotlin/AuthenticationUtils.kt | 4 +- ...rovisioningAuthenticationSuccessHandler.kt | 52 +++++++---- ...sioningAuthenticationSuccessHandlerTest.kt | 86 +++++++++++++++++-- 4 files changed, 129 insertions(+), 26 deletions(-) diff --git a/gooddata-server-oauth2-autoconfigure/src/main/kotlin/AuthenticationStoreClient.kt b/gooddata-server-oauth2-autoconfigure/src/main/kotlin/AuthenticationStoreClient.kt index 95db925..d0ac9e1 100644 --- a/gooddata-server-oauth2-autoconfigure/src/main/kotlin/AuthenticationStoreClient.kt +++ b/gooddata-server-oauth2-autoconfigure/src/main/kotlin/AuthenticationStoreClient.kt @@ -26,6 +26,7 @@ import org.springframework.web.server.ResponseStatusException /** * `AuthenticationStoreClient` defines methods for retrieving identity objects from persistent storage. */ +@SuppressWarnings("TooManyFunctions") interface AuthenticationStoreClient { /** @@ -81,7 +82,7 @@ interface AuthenticationStoreClient { * @param lastName last name of the user * @param email email of the user * @param userGroups list of user groups where the user belongs to - * Returns created [User] + * @return created [User] */ @SuppressWarnings("LongParameterList") suspend fun createUser( @@ -93,6 +94,12 @@ interface AuthenticationStoreClient { userGroups: List ): User + /** + * Patches [User] in the given `organizationId` + * @return updated [User] + */ + suspend fun patchUser(organizationId: String, user: User): User + /** * * Retrieves [List] that belongs to given `organizationId` @@ -197,4 +204,8 @@ data class User( val lastLogoutAllTimestamp: Instant? = null, val usedTokenId: String? = null, val name: String? = null, + var firstname: String? = null, + var lastname: String? = null, + var email: String? = null, + var userGroups: List? = null, ) diff --git a/gooddata-server-oauth2-autoconfigure/src/main/kotlin/AuthenticationUtils.kt b/gooddata-server-oauth2-autoconfigure/src/main/kotlin/AuthenticationUtils.kt index abdcf0a..5b56996 100644 --- a/gooddata-server-oauth2-autoconfigure/src/main/kotlin/AuthenticationUtils.kt +++ b/gooddata-server-oauth2-autoconfigure/src/main/kotlin/AuthenticationUtils.kt @@ -307,8 +307,8 @@ fun OAuth2AuthenticationToken.getClaim(claimName: String?): String = (principal.attributes[claimName ?: IdTokenClaimNames.SUB] as String?) ?: throw InvalidBearerTokenException("Token does not contain $claimName claim.") -fun OAuth2AuthenticationToken.getClaimList(claimName: String?): List = - (principal.attributes[claimName] as List?) ?: emptyList() +fun OAuth2AuthenticationToken.getClaimList(claimName: String?): List? = + (principal.attributes[claimName] as List?) /** * Detect if character is legal according to OAuth2 specification diff --git a/gooddata-server-oauth2-autoconfigure/src/main/kotlin/JitProvisioningAuthenticationSuccessHandler.kt b/gooddata-server-oauth2-autoconfigure/src/main/kotlin/JitProvisioningAuthenticationSuccessHandler.kt index 1fa38fd..290675d 100644 --- a/gooddata-server-oauth2-autoconfigure/src/main/kotlin/JitProvisioningAuthenticationSuccessHandler.kt +++ b/gooddata-server-oauth2-autoconfigure/src/main/kotlin/JitProvisioningAuthenticationSuccessHandler.kt @@ -52,22 +52,33 @@ class JitProvisioningAuthenticationSuccessHandler( if (organization.jitEnabled == true) { checkMandatoryClaims(authenticationToken, organization.id) logMessage("Initiating JIT provisioning", "started", organization.id) - val user: User? = authenticationStoreClient.getUserByAuthenticationId( - organization.id, - authenticationToken.getClaim(organization.oauthSubjectIdClaim) - ) + val subClaim = authenticationToken.getClaim(organization.oauthSubjectIdClaim) + val firstnameClaim = authenticationToken.getClaim(GIVEN_NAME) + val lastnameClaim = authenticationToken.getClaim(FAMILY_NAME) + val emailClaim = authenticationToken.getClaim(EMAIL) + val userGroupsClaim = authenticationToken.getClaimList(GD_USER_GROUPS) + val user: User? = authenticationStoreClient.getUserByAuthenticationId(organization.id, subClaim) if (user != null) { logMessage("Checking for user update", "running", organization.id) - // TODO finish user update + if (userDetailsChanged(user, firstnameClaim, lastnameClaim, emailClaim, userGroupsClaim)) { + logMessage("User details changed, patching", "running", organization.id) + user.firstname = firstnameClaim + user.lastname = lastnameClaim + user.email = emailClaim + user.userGroups = userGroupsClaim + authenticationStoreClient.patchUser(organization.id, user) + } else { + logMessage("User not changed, skipping update", "finished", organization.id) + } } else { logMessage("Creating user", "running", organization.id) val provisionedUser = authenticationStoreClient.createUser( organization.id, - authenticationToken.getClaim(organization.oauthSubjectIdClaim), - authenticationToken.getClaim(GIVEN_NAME), - authenticationToken.getClaim(FAMILY_NAME), - authenticationToken.getClaim(EMAIL), - authenticationToken.getClaimList(GD_USER_GROUPS) + subClaim, + firstnameClaim, + lastnameClaim, + emailClaim, + userGroupsClaim ?: emptyList() ) logMessage("User ${provisionedUser.id} created in organization", "finished", organization.id) } @@ -81,15 +92,10 @@ class JitProvisioningAuthenticationSuccessHandler( * Thrown when OAuth2AuthenticationToken is missing mandatory claims. */ class MissingMandatoryClaimsException(missingClaims: List) : OAuth2AuthenticationException( - OAuth2Error( - OAuth2ErrorCodes.INVALID_TOKEN, - "Missing mandatory claims: $missingClaims", - null - ) + OAuth2Error(OAuth2ErrorCodes.INVALID_TOKEN, "Missing mandatory claims: $missingClaims", null) ) private fun checkMandatoryClaims(authenticationToken: OAuth2AuthenticationToken, organizationId: String) { - val mandatoryClaims = setOf(GIVEN_NAME, FAMILY_NAME, EMAIL) val missingClaims = mandatoryClaims.filter { it !in authenticationToken.principal.attributes } if (missingClaims.isNotEmpty()) { logMessage("Authentication token is missing mandatory claim(s): $missingClaims", "error", organizationId) @@ -97,6 +103,17 @@ class JitProvisioningAuthenticationSuccessHandler( } } + private fun userDetailsChanged( + user: User, + firstname: String, + lastname: String, + email: String, + userGroups: List? + ): Boolean { + val userGroupsChanged = userGroups != null && user.userGroups?.equalsIgnoreOrder(userGroups) == false + return user.firstname != firstname || user.lastname != lastname || user.email != email || userGroupsChanged + } + private fun logMessage(message: String, state: String, organizationId: String) { logger.logInfo { withMessage { message } @@ -106,10 +123,13 @@ class JitProvisioningAuthenticationSuccessHandler( } } + private fun List.equalsIgnoreOrder(other: List) = this.size == other.size && this.toSet() == other.toSet() + companion object Claims { const val GIVEN_NAME = "given_name" const val FAMILY_NAME = "family_name" const val EMAIL = "email" const val GD_USER_GROUPS = "gd_user_groups" + val mandatoryClaims = setOf(GIVEN_NAME, FAMILY_NAME, EMAIL) } } diff --git a/gooddata-server-oauth2-autoconfigure/src/test/kotlin/JitProvisioningAuthenticationSuccessHandlerTest.kt b/gooddata-server-oauth2-autoconfigure/src/test/kotlin/JitProvisioningAuthenticationSuccessHandlerTest.kt index 640ecec..a5ebe5c 100644 --- a/gooddata-server-oauth2-autoconfigure/src/test/kotlin/JitProvisioningAuthenticationSuccessHandlerTest.kt +++ b/gooddata-server-oauth2-autoconfigure/src/test/kotlin/JitProvisioningAuthenticationSuccessHandlerTest.kt @@ -23,7 +23,11 @@ import io.mockk.coEvery import io.mockk.coVerify import io.mockk.every import io.mockk.mockk +import java.util.stream.Stream import org.junit.jupiter.api.Test +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.Arguments +import org.junit.jupiter.params.provider.MethodSource import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken import org.springframework.security.web.server.WebFilterExchange import strikt.api.expectThat @@ -32,13 +36,6 @@ import strikt.assertions.isNull class JitProvisioningAuthenticationSuccessHandlerTest { - companion object { - private const val ORG_ID = "orgId" - private const val SUB = "sub" - private const val HOST = "gooddata.com" - private const val USER_ID = "userId" - } - private val client: AuthenticationStoreClient = mockk() private val exchange: WebFilterExchange = mockk { coEvery { exchange.request.uri.host } returns HOST @@ -118,4 +115,79 @@ class JitProvisioningAuthenticationSuccessHandlerTest { coVerify { client.getUserByAuthenticationId(ORG_ID, SUB) } coVerify { client.createUser(ORG_ID, SUB, GIVEN_NAME, FAMILY_NAME, EMAIL, emptyList()) } } + + @ParameterizedTest(name = "{0}") + @MethodSource("users") + fun `should test user patching`( + case: String, + user: User, + patchCount: Int + ) { + // given + val handler = JitProvisioningAuthenticationSuccessHandler(client) + + // when + coEvery { client.getOrganizationByHostname(HOST) }.returns( + Organization(id = ORG_ID, oauthSubjectIdClaim = SUB, jitEnabled = true) + ) + coEvery { client.getUserByAuthenticationId(ORG_ID, SUB) }.returns(user) + coEvery { client.patchUser(ORG_ID, any()) } returns mockk() + + // then + expectThat( + handler.onAuthenticationSuccess(exchange, authentication) + .block() + ).isNull() + + coVerify { client.getOrganizationByHostname(HOST) } + coVerify { client.getUserByAuthenticationId(ORG_ID, SUB) } + coVerify(exactly = patchCount) { client.patchUser(ORG_ID, any()) } + } + + companion object { + + private const val ORG_ID = "orgId" + private const val SUB = "sub" + private const val HOST = "gooddata.com" + private const val USER_ID = "userId" + + @JvmStatic + fun users() = Stream.of( + Arguments.of( + "should update user when users lastname is changed", + User( + USER_ID, + null, + firstname = GIVEN_NAME, + lastname = "NewFamilyName", + email = EMAIL, + userGroups = emptyList() + ), + 1 + ), + Arguments.of( + "should update user when users userGroups re changed", + User( + USER_ID, + null, + firstname = GIVEN_NAME, + lastname = FAMILY_NAME, + email = EMAIL, + userGroups = listOf("newUserGroup") + ), + 1 + ), + Arguments.of( + "should not update user when user details are not changed", + User( + USER_ID, + null, + firstname = GIVEN_NAME, + lastname = FAMILY_NAME, + email = EMAIL + ), + 0 + ) + ) + } }