Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LX-77 Add support for JIT provisioning pt 2 #7

Merged
merged 1 commit into from
Feb 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import org.springframework.web.server.ResponseStatusException
/**
* `AuthenticationStoreClient` defines methods for retrieving identity objects from persistent storage.
*/
@SuppressWarnings("TooManyFunctions")
interface AuthenticationStoreClient {

/**
Expand Down Expand Up @@ -81,7 +82,7 @@ interface AuthenticationStoreClient {
* @param lastName last name of the user
* @param email email of the user
* @param userGroups list of user groups where the user belongs to
* Returns created [User]
* @return created [User]
*/
@SuppressWarnings("LongParameterList")
suspend fun createUser(
Expand All @@ -93,6 +94,12 @@ interface AuthenticationStoreClient {
userGroups: List<String>
): User

/**
* Patches [User] in the given `organizationId`
* @return updated [User]
*/
suspend fun patchUser(organizationId: String, user: User): User

/**
*
* Retrieves [List<JWK>] that belongs to given `organizationId`
Expand Down Expand Up @@ -197,4 +204,8 @@ data class User(
val lastLogoutAllTimestamp: Instant? = null,
val usedTokenId: String? = null,
val name: String? = null,
var firstname: String? = null,
var lastname: String? = null,
var email: String? = null,
var userGroups: List<String>? = null,
)
Original file line number Diff line number Diff line change
Expand Up @@ -307,8 +307,8 @@ fun OAuth2AuthenticationToken.getClaim(claimName: String?): String =
(principal.attributes[claimName ?: IdTokenClaimNames.SUB] as String?)
?: throw InvalidBearerTokenException("Token does not contain $claimName claim.")

fun OAuth2AuthenticationToken.getClaimList(claimName: String?): List<String> =
(principal.attributes[claimName] as List<String>?) ?: emptyList()
fun OAuth2AuthenticationToken.getClaimList(claimName: String?): List<String>? =
(principal.attributes[claimName] as List<String>?)

/**
* Detect if character is legal according to OAuth2 specification
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,22 +52,33 @@ class JitProvisioningAuthenticationSuccessHandler(
if (organization.jitEnabled == true) {
checkMandatoryClaims(authenticationToken, organization.id)
logMessage("Initiating JIT provisioning", "started", organization.id)
val user: User? = authenticationStoreClient.getUserByAuthenticationId(
organization.id,
authenticationToken.getClaim(organization.oauthSubjectIdClaim)
)
val subClaim = authenticationToken.getClaim(organization.oauthSubjectIdClaim)
val firstnameClaim = authenticationToken.getClaim(GIVEN_NAME)
val lastnameClaim = authenticationToken.getClaim(FAMILY_NAME)
val emailClaim = authenticationToken.getClaim(EMAIL)
val userGroupsClaim = authenticationToken.getClaimList(GD_USER_GROUPS)
val user: User? = authenticationStoreClient.getUserByAuthenticationId(organization.id, subClaim)
if (user != null) {
logMessage("Checking for user update", "running", organization.id)
// TODO finish user update
if (userDetailsChanged(user, firstnameClaim, lastnameClaim, emailClaim, userGroupsClaim)) {
logMessage("User details changed, patching", "running", organization.id)
user.firstname = firstnameClaim
user.lastname = lastnameClaim
user.email = emailClaim
user.userGroups = userGroupsClaim
authenticationStoreClient.patchUser(organization.id, user)
} else {
logMessage("User not changed, skipping update", "finished", organization.id)
}
} else {
logMessage("Creating user", "running", organization.id)
val provisionedUser = authenticationStoreClient.createUser(
organization.id,
authenticationToken.getClaim(organization.oauthSubjectIdClaim),
authenticationToken.getClaim(GIVEN_NAME),
authenticationToken.getClaim(FAMILY_NAME),
authenticationToken.getClaim(EMAIL),
authenticationToken.getClaimList(GD_USER_GROUPS)
subClaim,
firstnameClaim,
lastnameClaim,
emailClaim,
userGroupsClaim ?: emptyList()
)
logMessage("User ${provisionedUser.id} created in organization", "finished", organization.id)
}
Expand All @@ -81,22 +92,28 @@ class JitProvisioningAuthenticationSuccessHandler(
* Thrown when OAuth2AuthenticationToken is missing mandatory claims.
*/
class MissingMandatoryClaimsException(missingClaims: List<String>) : OAuth2AuthenticationException(
OAuth2Error(
OAuth2ErrorCodes.INVALID_TOKEN,
"Missing mandatory claims: $missingClaims",
null
)
OAuth2Error(OAuth2ErrorCodes.INVALID_TOKEN, "Missing mandatory claims: $missingClaims", null)
)

private fun checkMandatoryClaims(authenticationToken: OAuth2AuthenticationToken, organizationId: String) {
val mandatoryClaims = setOf(GIVEN_NAME, FAMILY_NAME, EMAIL)
val missingClaims = mandatoryClaims.filter { it !in authenticationToken.principal.attributes }
if (missingClaims.isNotEmpty()) {
logMessage("Authentication token is missing mandatory claim(s): $missingClaims", "error", organizationId)
throw MissingMandatoryClaimsException(missingClaims)
}
}

private fun userDetailsChanged(
user: User,
firstname: String,
lastname: String,
email: String,
userGroups: List<String>?
): Boolean {
val userGroupsChanged = userGroups != null && user.userGroups?.equalsIgnoreOrder(userGroups) == false
return user.firstname != firstname || user.lastname != lastname || user.email != email || userGroupsChanged
}

private fun logMessage(message: String, state: String, organizationId: String) {
logger.logInfo {
withMessage { message }
Expand All @@ -106,10 +123,13 @@ class JitProvisioningAuthenticationSuccessHandler(
}
}

private fun <T> List<T>.equalsIgnoreOrder(other: List<T>) = this.size == other.size && this.toSet() == other.toSet()

companion object Claims {
const val GIVEN_NAME = "given_name"
const val FAMILY_NAME = "family_name"
const val EMAIL = "email"
const val GD_USER_GROUPS = "gd_user_groups"
val mandatoryClaims = setOf(GIVEN_NAME, FAMILY_NAME, EMAIL)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@ import io.mockk.coEvery
import io.mockk.coVerify
import io.mockk.every
import io.mockk.mockk
import java.util.stream.Stream
import org.junit.jupiter.api.Test
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.Arguments
import org.junit.jupiter.params.provider.MethodSource
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken
import org.springframework.security.web.server.WebFilterExchange
import strikt.api.expectThat
Expand All @@ -32,13 +36,6 @@ import strikt.assertions.isNull

class JitProvisioningAuthenticationSuccessHandlerTest {

companion object {
private const val ORG_ID = "orgId"
private const val SUB = "sub"
private const val HOST = "gooddata.com"
private const val USER_ID = "userId"
}

private val client: AuthenticationStoreClient = mockk()
private val exchange: WebFilterExchange = mockk {
coEvery { exchange.request.uri.host } returns HOST
Expand Down Expand Up @@ -118,4 +115,79 @@ class JitProvisioningAuthenticationSuccessHandlerTest {
coVerify { client.getUserByAuthenticationId(ORG_ID, SUB) }
coVerify { client.createUser(ORG_ID, SUB, GIVEN_NAME, FAMILY_NAME, EMAIL, emptyList()) }
}

@ParameterizedTest(name = "{0}")
@MethodSource("users")
fun `should test user patching`(
case: String,
user: User,
patchCount: Int
) {
// given
val handler = JitProvisioningAuthenticationSuccessHandler(client)

// when
coEvery { client.getOrganizationByHostname(HOST) }.returns(
Organization(id = ORG_ID, oauthSubjectIdClaim = SUB, jitEnabled = true)
)
coEvery { client.getUserByAuthenticationId(ORG_ID, SUB) }.returns(user)
coEvery { client.patchUser(ORG_ID, any()) } returns mockk()

// then
expectThat(
handler.onAuthenticationSuccess(exchange, authentication)
.block()
).isNull()

coVerify { client.getOrganizationByHostname(HOST) }
coVerify { client.getUserByAuthenticationId(ORG_ID, SUB) }
coVerify(exactly = patchCount) { client.patchUser(ORG_ID, any()) }
}

companion object {

private const val ORG_ID = "orgId"
private const val SUB = "sub"
private const val HOST = "gooddata.com"
private const val USER_ID = "userId"

@JvmStatic
fun users() = Stream.of(
Arguments.of(
"should update user when users lastname is changed",
User(
USER_ID,
null,
firstname = GIVEN_NAME,
lastname = "NewFamilyName",
email = EMAIL,
userGroups = emptyList()
),
1
),
Arguments.of(
"should update user when users userGroups re changed",
User(
USER_ID,
null,
firstname = GIVEN_NAME,
lastname = FAMILY_NAME,
email = EMAIL,
userGroups = listOf("newUserGroup")
),
1
),
Arguments.of(
"should not update user when user details are not changed",
User(
USER_ID,
null,
firstname = GIVEN_NAME,
lastname = FAMILY_NAME,
email = EMAIL
),
0
)
)
}
}
Loading