Skip to content
This repository has been archived by the owner on Dec 15, 2024. It is now read-only.

Commit

Permalink
feat: add SpringMockUserExtension to fake users
Browse files Browse the repository at this point in the history
Compatible with Spring 2 and Spring security 5.
  • Loading branch information
Ronny Bräunlich committed Nov 30, 2024
1 parent ea1da9b commit 80347a0
Show file tree
Hide file tree
Showing 9 changed files with 312 additions and 1 deletion.
5 changes: 5 additions & 0 deletions build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,16 @@ dependencies {
implementation(libs.spring.test)
implementation(libs.kotlinx.coroutines)
implementation(libs.byteBuddy)
implementation(libs.spring.security.test)

testImplementation(libs.kotest.runner.junit5)
testImplementation(libs.kotest.framework.datatest)
testImplementation(libs.kotest.property)
testImplementation(libs.spring.boot.test)
testImplementation(libs.spring.boot.starter.webflux)
testImplementation(libs.spring.boot.starter.security)
testImplementation(libs.reactor.kotlin.extensions)
testImplementation(libs.kotlinx.coroutines.reactor)
}

tasks.withType<Test> {
Expand Down
2 changes: 1 addition & 1 deletion buildSrc/src/main/kotlin/Ci.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ object Ci {

// this is the version used for building snapshots
// .GITHUB_RUN_NUMBER-snapshot will be appended
private const val snapshotBase = "1.2.0"
private const val snapshotBase = "1.4.0"

private val githubRunNumber = System.getenv("GITHUB_RUN_NUMBER")

Expand Down
8 changes: 8 additions & 0 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@ kotlinx-coroutines = "1.7.2"
kotest = "5.8.1"
spring = "5.3.39"
spring-boot = "2.7.16"
spring-security = "5.3.13.RELEASE"
byte-buddy = "1.14.18"
reactor-kotlin = "1.1.11"

[libraries]
kotlin-reflect = { group = "org.jetbrains.kotlin", name = "kotlin-reflect", version.ref = "kotlin" }
kotlinx-coroutines = { group = "org.jetbrains.kotlinx", name = "kotlinx-coroutines-core", version.ref = "kotlinx-coroutines" }
kotlinx-coroutines-reactor = { group = "org.jetbrains.kotlinx", name = "kotlinx-coroutines-reactor", version.ref = "kotlinx-coroutines" }

kotest-framework-datatest = { group = "io.kotest", name = "kotest-framework-datatest", version.ref = "kotest" }
kotest-framework-api = { group = "io.kotest", name = "kotest-framework-api", version.ref = "kotest" }
Expand All @@ -19,9 +22,14 @@ kotest-runner-junit5 = { group = "io.kotest", name = "kotest-runner-junit5", ver
spring-context = { group = "org.springframework", name = "spring-context", version.ref = "spring" }
spring-test = { group = "org.springframework", name = "spring-test", version.ref = "spring" }
spring-boot-test = { group = "org.springframework.boot", name = "spring-boot-starter-test", version.ref = "spring-boot" }
spring-security-test = { group = "org.springframework.security", name = "spring-security-test", version.ref = "spring-security" }
spring-boot-starter-webflux = {group = "org.springframework.boot", name = "spring-boot-starter-webflux", version.ref = "spring-boot" }
spring-boot-starter-security = {group = "org.springframework.boot", name = "spring-boot-starter-security", version.ref = "spring-boot" }

byteBuddy = { group = "net.bytebuddy", name = "byte-buddy", version.ref = "byte-buddy" }

reactor-kotlin-extensions = { group = "io.projectreactor.kotlin", name = "reactor-kotlin-extensions", version.ref = "reactor-kotlin" }

[plugins]
kotlin-jvm = { id = "org.jetbrains.kotlin.jvm", version.ref = "kotlin" }
kotlin-spring = { id = "org.jetbrains.kotlin.plugin.spring", version.ref = "kotlin" }
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package io.kotest.extensions.spring

import io.kotest.core.extensions.Extension
import io.kotest.core.test.TestCase

internal interface BeforeSpringExtension: Extension {

/**
* Called before testContextManager().beforeTestMethod() in SpringTestExtension
*/
suspend fun beforeSpring(testCase: TestCase): Unit = Unit

}
11 changes: 11 additions & 0 deletions src/main/kotlin/io/kotest/extensions/spring/SpringTestExtension.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

package io.kotest.extensions.spring

import io.kotest.core.extensions.Extension
import io.kotest.core.extensions.SpecExtension
import io.kotest.core.extensions.TestCaseExtension
import io.kotest.core.spec.Spec
import io.kotest.core.spec.functionOverrideCallbacks
import io.kotest.core.test.TestCase
import io.kotest.core.test.TestResult
import io.kotest.core.test.TestType
Expand Down Expand Up @@ -67,6 +69,7 @@ class SpringTestExtension(private val mode: SpringTestLifecycleMode = SpringTest
override suspend fun intercept(testCase: TestCase, execute: suspend (TestCase) -> TestResult): TestResult {
val methodName = method(testCase)
if (testCase.isApplicable()) {
extensions(testCase).filterIsInstance<BeforeSpringExtension>().forEach{ it.beforeSpring(testCase)}
testContextManager().beforeTestMethod(testCase.spec, methodName)
testContextManager().beforeTestExecution(testCase.spec, methodName)
}
Expand Down Expand Up @@ -110,6 +113,14 @@ class SpringTestExtension(private val mode: SpringTestLifecycleMode = SpringTest
fakeSpec.getMethod(methodName)
}

private fun extensions(testCase: TestCase): List<Extension> {
return testCase.spec.extensions() + // overriding the extensions function in the spec
testCase.spec.listeners() + // overriding the listeners function in the spec
testCase.spec.functionOverrideCallbacks() + // spec level dsl eg beforeTest { }
testCase.spec.registeredExtensions() + // added to the spec via register
testCase.config.extensions
}

/**
* Checks for a safe class name and throws if invalid
* https://kotlinlang.org/docs/keyword-reference.html#soft-keywords
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package io.kotest.extensions.spring.security

import io.kotest.core.listeners.AfterTestListener
import io.kotest.core.test.TestCase
import io.kotest.core.test.TestResult
import io.kotest.extensions.spring.BeforeSpringExtension
import io.kotest.extensions.spring.testContextManager
import org.springframework.security.core.GrantedAuthority
import org.springframework.security.core.authority.SimpleGrantedAuthority
import org.springframework.security.core.context.SecurityContext
import org.springframework.security.test.context.TestSecurityContextHolder
import org.springframework.security.test.context.support.TestExecutionEvent
import org.springframework.security.test.context.support.WithMockUser
import org.springframework.security.test.context.support.WithSecurityContext
import org.springframework.security.test.context.support.WithSecurityContextFactory
import org.springframework.test.context.TestContextAnnotationUtils

class SpringMockUserExtension(
private val username: String = "user",
private val password: String = "password",
private val roles: List<String> = listOf("USER"),
private val authorities: List<String> = listOf()
) : BeforeSpringExtension, AfterTestListener {

override suspend fun beforeSpring(testCase: TestCase) {
TestSecurityContextHolder.setContext(createSecurityContext())
}

override suspend fun afterAny(testCase: TestCase, result: TestResult) {
TestSecurityContextHolder.clearContext()
}

/**
* Copied from Spring's WithMockUserSecurityContextFactory
*/
private suspend fun createSecurityContext(): SecurityContext {
// val grantedAuthorities: MutableList<GrantedAuthority> =
// authorities.map { SimpleGrantedAuthority(it) }.toMutableList()
// if (grantedAuthorities.isEmpty()) {
// for (role in roles) {
// require(!role.startsWith("ROLE_")) { "roles cannot start with ROLE_ Got $role" }
// grantedAuthorities.add(SimpleGrantedAuthority("ROLE_$role"))
// }
// } else check(roles.size == 1 && "USER" == roles[0]) {
// ("You cannot define roles attribute " + roles
// + " with authorities attribute " + authorities)
// }
val withSecurityContextAnnotationDescriptor =
TestContextAnnotationUtils.findAnnotationDescriptor(
WithMockUser::class.java,
WithSecurityContext::class.java
)!!.annotation
val factoryClazz: Class<out WithSecurityContextFactory<out Annotation>> =
withSecurityContextAnnotationDescriptor.factory.java
val factory =
testContextManager().testContext.applicationContext.autowireCapableBeanFactory.createBean(
factoryClazz
) as WithSecurityContextFactory<WithMockUser>
val withMockUser = WithMockUser(
value = username,
username = username,
roles = roles.toTypedArray(),
authorities = authorities.toTypedArray(),
password = password,
setupBefore = TestExecutionEvent.TEST_METHOD
)
return factory.createSecurityContext(withMockUser)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package io.kotest.extensions.spring.security

import io.kotest.core.spec.style.DescribeSpec
import io.kotest.core.spec.style.FunSpec
import io.kotest.extensions.spring.Components
import io.kotest.extensions.spring.SpringExtension
import org.springframework.beans.factory.annotation.Autowired
import org.springframework.boot.test.autoconfigure.web.reactive.AutoConfigureWebTestClient
import org.springframework.boot.test.autoconfigure.web.reactive.WebFluxTest
import org.springframework.boot.test.context.SpringBootTest
import org.springframework.security.core.context.SecurityContextHolder
import org.springframework.test.context.ContextConfiguration
import org.springframework.test.web.reactive.server.WebTestClient
import java.util.UUID

@SpringBootTest(classes = [SpringTestApplication::class])
@AutoConfigureWebTestClient
class SpringMockUserExtensionIntegrationTest(
@Autowired private val webTestClient: WebTestClient,
) : DescribeSpec() {

override fun extensions() = listOf(SpringExtension)

init {
describe("ADMIN") {
extensions(SpringMockUserExtension(authorities = listOf("ADMIN")))
it("should provide mock authentication") {
webTestClient
.get()
.uri("/secure")
.exchange()
.expectStatus()
.isOk
}
}
}
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
package io.kotest.extensions.spring.security

import io.kotest.assertions.throwables.shouldThrow
import io.kotest.core.descriptors.append
import io.kotest.core.descriptors.toDescriptor
import io.kotest.core.names.TestName
import io.kotest.core.source.sourceRef
import io.kotest.core.spec.style.FunSpec
import io.kotest.core.test.TestCase
import io.kotest.core.test.TestResult
import io.kotest.core.test.TestType
import io.kotest.extensions.spring.SpringExtension
import io.kotest.matchers.collections.shouldContainOnly
import io.kotest.matchers.equals.shouldBeEqual
import io.kotest.matchers.shouldBe
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken
import org.springframework.security.core.authority.SimpleGrantedAuthority
import org.springframework.security.core.context.SecurityContextHolder
import org.springframework.security.core.userdetails.User
import org.springframework.security.test.context.TestSecurityContextHolder
import org.springframework.security.web.authentication.preauth.PreAuthenticatedAuthenticationToken

class SpringMockUserExtensionTest : FunSpec() {

override fun extensions() = listOf(SpringExtension)

init {

afterAny {
SecurityContextHolder.getContext().authentication = null
}

test("should set user before Spring in Spring security context") {
SpringMockUserExtension().beforeSpring(
TestCase(
descriptor = SpringMockUserExtensionTest::class.toDescriptor().append("aaa"),
name = TestName("name"),
spec = this@SpringMockUserExtensionTest,
test = {},
source = sourceRef(),
type = TestType.Test
)
)

val expectedAuthorities = mutableListOf(SimpleGrantedAuthority("ROLE_USER"))
val expectedPrincipal = User("user", "password", true, true, true, true, expectedAuthorities)
val expectedAuthentication = UsernamePasswordAuthenticationToken.authenticated(
expectedPrincipal,
expectedPrincipal.password,
expectedPrincipal.authorities
)
SecurityContextHolder.getContext().authentication shouldBeEqual expectedAuthentication
TestSecurityContextHolder.getContext().authentication shouldBeEqual expectedAuthentication
}

test("should remove user after any from Spring security context") {
SecurityContextHolder.getContext().authentication =
PreAuthenticatedAuthenticationToken(
null,
null,
emptyList()
)
SpringMockUserExtension().afterAny(
TestCase(
descriptor = SpringMockUserExtensionTest::class.toDescriptor().append("aaa"),
name = TestName("name"),
spec = this@SpringMockUserExtensionTest,
test = {},
source = sourceRef(),
type = TestType.Test
),
TestResult.Ignored
)

SecurityContextHolder.getContext().authentication shouldBe null
TestSecurityContextHolder.getContext().authentication shouldBe null
}

test("should assign roles to authentication") {
SpringMockUserExtension(roles = listOf("TEST")).beforeSpring(
TestCase(
descriptor = SpringMockUserExtensionTest::class.toDescriptor().append("aaa"),
name = TestName("name"),
spec = this@SpringMockUserExtensionTest,
test = {},
source = sourceRef(),
type = TestType.Test
)
)

SecurityContextHolder.getContext().authentication.authorities shouldContainOnly listOf(SimpleGrantedAuthority("ROLE_TEST"))
}


test("should reject roles starting with ROLE") {
shouldThrow<IllegalArgumentException> {
SpringMockUserExtension(roles = listOf("ROLE_TEST")).beforeSpring(
TestCase(
descriptor = SpringMockUserExtensionTest::class.toDescriptor().append("aaa"),
name = TestName("name"),
spec = this@SpringMockUserExtensionTest,
test = {},
source = sourceRef(),
type = TestType.Test
)
)
}
}

test("should assign authorities to authentication") {
SpringMockUserExtension(authorities = listOf("ADMIN")).beforeSpring(
TestCase(
descriptor = SpringMockUserExtensionTest::class.toDescriptor().append("aaa"),
name = TestName("name"),
spec = this@SpringMockUserExtensionTest,
test = {},
source = sourceRef(),
type = TestType.Test
)
)

SecurityContextHolder.getContext().authentication.authorities shouldContainOnly listOf(SimpleGrantedAuthority("ADMIN"))
}

test("should not assign roles and authorities") {
shouldThrow<IllegalStateException> {
SpringMockUserExtension(roles = listOf("ROLE_TEST"), authorities = listOf("ADMIN")).beforeSpring(
TestCase(
descriptor = SpringMockUserExtensionTest::class.toDescriptor().append("aaa"),
name = TestName("name"),
spec = this@SpringMockUserExtensionTest,
test = {},
source = sourceRef(),
type = TestType.Test
)
)
}
}
}
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package io.kotest.extensions.spring.security

import org.springframework.boot.autoconfigure.SpringBootApplication
import org.springframework.http.HttpStatus
import org.springframework.security.access.prepost.PreAuthorize
import org.springframework.security.config.annotation.method.configuration.EnableReactiveMethodSecurity
import org.springframework.security.config.annotation.web.reactive.EnableWebFluxSecurity
import org.springframework.web.bind.annotation.GetMapping
import org.springframework.web.bind.annotation.RequestMapping
import org.springframework.web.bind.annotation.ResponseStatus
import org.springframework.web.bind.annotation.RestController

@SpringBootApplication
@EnableWebFluxSecurity
@EnableReactiveMethodSecurity
class SpringTestApplication

@RestController
@RequestMapping("/secure")
class TestSecurityController {

@PreAuthorize("hasAuthority('ADMIN')")
@GetMapping
@ResponseStatus(HttpStatus.OK)
suspend fun test() = "Ok"
}

0 comments on commit 80347a0

Please sign in to comment.