Skip to content

Commit

Permalink
improved jwt token error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
jamcunha committed Aug 26, 2023
1 parent ac49380 commit 9a9fa6c
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import org.springframework.http.HttpStatus
import org.springframework.http.converter.HttpMessageNotReadableException
import org.springframework.security.access.AccessDeniedException
import org.springframework.security.core.AuthenticationException
import org.springframework.security.oauth2.jwt.BadJwtException
import org.springframework.security.oauth2.jwt.JwtValidationException
import org.springframework.web.bind.MethodArgumentNotValidException
import org.springframework.web.bind.annotation.ExceptionHandler
import org.springframework.web.bind.annotation.RequestMapping
Expand All @@ -18,6 +20,7 @@ import org.springframework.web.bind.annotation.RestController
import org.springframework.web.bind.annotation.RestControllerAdvice
import org.springframework.web.multipart.MaxUploadSizeExceededException
import pt.up.fe.ni.website.backend.config.Logging
import pt.up.fe.ni.website.backend.service.ErrorMessages

data class SimpleError(
val message: String,
Expand Down Expand Up @@ -134,6 +137,22 @@ class ErrorController(private val objectMapper: ObjectMapper) : ErrorController,
return wrapSimpleError(e.message ?: "invalid authentication")
}

@ExceptionHandler(JwtValidationException::class)
@ResponseStatus(HttpStatus.UNAUTHORIZED)
fun invalidAuthentication(e: JwtValidationException): CustomError {
return if (e.message?.contains("expired") == true) {
wrapSimpleError(ErrorMessages.expiredToken)
} else {
wrapSimpleError(ErrorMessages.invalidToken)
}
}

@ExceptionHandler(BadJwtException::class)
@ResponseStatus(HttpStatus.UNAUTHORIZED)
fun invalidAuthentication(e: BadJwtException): CustomError {
return wrapSimpleError(ErrorMessages.invalidToken)
}

fun wrapSimpleError(msg: String, param: String? = null, value: Any? = null) = CustomError(
mutableListOf(SimpleError(msg, param, value))
)
Expand Down
25 changes: 4 additions & 21 deletions src/main/kotlin/pt/up/fe/ni/website/backend/service/AuthService.kt
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,7 @@ class AuthService(
}

fun refreshAccessToken(refreshToken: String): String {
val jwt =
try {
jwtDecoder.decode(refreshToken)
} catch (e: Exception) {
throw InvalidBearerTokenException(ErrorMessages.invalidRefreshToken)
}
if (jwt.expiresAt?.isBefore(Instant.now()) != false) {
throw InvalidBearerTokenException(ErrorMessages.expiredRefreshToken)
}
val jwt = jwtDecoder.decode(refreshToken)
val account = accountService.getAccountByEmail(jwt.subject)
return generateAccessToken(account)
}
Expand All @@ -72,23 +64,14 @@ class AuthService(
}

fun confirmRecoveryToken(recoveryToken: String, dto: PasswordRecoveryConfirmDto): Account {
val jwt =
try {
jwtDecoder.decode(recoveryToken)
} catch (e: Exception) {
throw InvalidBearerTokenException(ErrorMessages.invalidRecoveryToken)
}

if (jwt.expiresAt?.isBefore(Instant.now()) != false) {
throw InvalidBearerTokenException(ErrorMessages.expiredRecoveryToken)
}
val jwt = jwtDecoder.decode(recoveryToken)
val account = accountService.getAccountByEmail(jwt.subject)

val tokenPasswordHash = jwt.getClaim<String>("passwordHash")
?: throw InvalidBearerTokenException(ErrorMessages.invalidRecoveryToken)
?: throw InvalidBearerTokenException(ErrorMessages.invalidToken)

if (account.password != tokenPasswordHash) {
throw InvalidBearerTokenException(ErrorMessages.expiredRecoveryToken)
throw InvalidBearerTokenException(ErrorMessages.invalidToken)
}

account.password = passwordEncoder.encode(dto.password)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,9 @@ object ErrorMessages {

const val invalidCredentials = "invalid credentials"

const val invalidRefreshToken = "invalid refresh token"
const val invalidToken = "invalid token"

const val expiredRefreshToken = "refresh token has expired"

const val invalidRecoveryToken = "invalid password recovery token"

const val expiredRecoveryToken = "password recovery token has expired"
const val expiredToken = "token has expired"

const val noGenerations = "no generations created yet"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import com.epages.restdocs.apispec.HeaderDescriptorWithType
import com.epages.restdocs.apispec.ResourceDocumentation
import com.epages.restdocs.apispec.ResourceDocumentation.headerWithName
import com.fasterxml.jackson.databind.ObjectMapper
import java.time.Instant
import java.time.temporal.ChronoUnit
import java.util.Calendar
import org.hamcrest.Matchers.startsWith
import org.junit.jupiter.api.BeforeAll
Expand All @@ -18,6 +20,10 @@ import org.springframework.restdocs.mockmvc.RestDocumentationRequestBuilders.get
import org.springframework.restdocs.mockmvc.RestDocumentationRequestBuilders.post
import org.springframework.restdocs.payload.JsonFieldType
import org.springframework.security.crypto.password.PasswordEncoder
import org.springframework.security.oauth2.jwt.JwtClaimsSet
import org.springframework.security.oauth2.jwt.JwtDecoder
import org.springframework.security.oauth2.jwt.JwtEncoder
import org.springframework.security.oauth2.jwt.JwtEncoderParameters
import org.springframework.test.web.servlet.MockMvc
import org.springframework.test.web.servlet.post
import org.springframework.test.web.servlet.result.MockMvcResultMatchers.content
Expand All @@ -29,7 +35,6 @@ import pt.up.fe.ni.website.backend.model.Account
import pt.up.fe.ni.website.backend.model.CustomWebsite
import pt.up.fe.ni.website.backend.model.constants.AccountConstants
import pt.up.fe.ni.website.backend.repository.AccountRepository
import pt.up.fe.ni.website.backend.service.ErrorMessages
import pt.up.fe.ni.website.backend.utils.TestUtils
import pt.up.fe.ni.website.backend.utils.ValidationTester
import pt.up.fe.ni.website.backend.utils.annotations.ControllerTest
Expand All @@ -52,6 +57,8 @@ class AuthControllerTest @Autowired constructor(
val repository: AccountRepository,
val mockMvc: MockMvc,
val objectMapper: ObjectMapper,
val jwtEncoder: JwtEncoder,
val jwtDecoder: JwtDecoder,
passwordEncoder: PasswordEncoder
) {
final val testPassword = "testPassword"
Expand Down Expand Up @@ -162,7 +169,7 @@ class AuthControllerTest @Autowired constructor(
)
.andExpectAll(
status().isUnauthorized,
jsonPath("$.errors[0].message").value("invalid refresh token")
jsonPath("$.errors[0].message").value("invalid token")
)
.andDocumentErrorResponse(documentation, hasRequestPayload = true)
}
Expand Down Expand Up @@ -336,7 +343,7 @@ class AuthControllerTest @Autowired constructor(
).andExpectAll(
status().isUnauthorized(),
jsonPath("$.errors.length()").value(1),
jsonPath("$.errors[0].message").value("invalid password recovery token")
jsonPath("$.errors[0].message").value("invalid token")
).andDocumentCustomRequestSchemaErrorResponse(
documentation,
passwordRecoveryPayload,
Expand All @@ -357,6 +364,106 @@ class AuthControllerTest @Autowired constructor(
}.andExpect { status { isUnauthorized() } }
}

@Test
fun `should fail when token is expired`() {
mockMvc.perform(
post("/auth/password/recovery")
.contentType(MediaType.APPLICATION_JSON)
.content(
objectMapper.writeValueAsString(
mapOf(
"email" to testAccount.email
)
)
)
)
.andReturn().response.let { authResponse ->
val token = objectMapper.readTree(authResponse.contentAsString)["recovery_url"].asText()
.removePrefix("$recoverPasswordPage/")
.removeSuffix("/confirm")

val decoded = jwtDecoder.decode(token)
val newClaims = mutableMapOf<String, Any>()
newClaims.putAll(decoded.claims)

val claimsBuilder = JwtClaimsSet
.builder()
.issuer("self")
.issuedAt(Instant.now().minus(2, ChronoUnit.DAYS))
.expiresAt(Instant.now().minus(1, ChronoUnit.DAYS))
.subject(decoded.subject)
.claim("scope", decoded.claims["scope"])

val newToken = jwtEncoder.encode(JwtEncoderParameters.from(claimsBuilder.build())).tokenValue

mockMvc.perform(
post("/auth/password/recovery/{token}/confirm", newToken)
.contentType(MediaType.APPLICATION_JSON)
.content(
objectMapper.writeValueAsString(
mapOf(
"password" to newPassword
)
)
)
).andExpectAll(
status().isUnauthorized(),
jsonPath("$.errors.length()").value(1),
jsonPath("$.errors[0].message").value("token has expired")
)
}
}

@Test
fun `should fail when password hash claim is missing`() {
mockMvc.perform(
post("/auth/password/recovery")
.contentType(MediaType.APPLICATION_JSON)
.content(
objectMapper.writeValueAsString(
mapOf(
"email" to testAccount.email
)
)
)
)
.andReturn().response.let { authResponse ->
val token = objectMapper.readTree(authResponse.contentAsString)["recovery_url"].asText()
.removePrefix("$recoverPasswordPage/")
.removeSuffix("/confirm")

val decoded = jwtDecoder.decode(token)
val newClaims = mutableMapOf<String, Any>()
newClaims.putAll(decoded.claims)

val claimsBuilder = JwtClaimsSet
.builder()
.issuer("self")
.issuedAt(decoded.issuedAt)
.expiresAt(decoded.expiresAt)
.subject(decoded.subject)
.claim("scope", decoded.claims["scope"])

val newToken = jwtEncoder.encode(JwtEncoderParameters.from(claimsBuilder.build())).tokenValue

mockMvc.perform(
post("/auth/password/recovery/{token}/confirm", newToken)
.contentType(MediaType.APPLICATION_JSON)
.content(
objectMapper.writeValueAsString(
mapOf(
"password" to newPassword
)
)
)
).andExpectAll(
status().isUnauthorized(),
jsonPath("$.errors.length()").value(1),
jsonPath("$.errors[0].message").value("invalid token")
)
}
}

@Test
fun `should fail when using recovery token twice`() {
mockMvc.perform(
Expand Down Expand Up @@ -402,7 +509,7 @@ class AuthControllerTest @Autowired constructor(
).andExpectAll(
status().isUnauthorized(),
jsonPath("$.errors.length()").value(1),
jsonPath("$.errors[0].message").value(ErrorMessages.expiredRecoveryToken)
jsonPath("$.errors[0].message").value("invalid token")
).andDocumentCustomRequestSchemaErrorResponse(
documentation,
passwordRecoveryPayload,
Expand Down

0 comments on commit 9a9fa6c

Please sign in to comment.