From 9a9fa6c0ea13d2bf8b2f962004d7ffa3dbf5403f Mon Sep 17 00:00:00 2001 From: Joaquim Cunha Date: Sat, 12 Aug 2023 07:23:53 +0100 Subject: [PATCH] improved jwt token error handling --- .../backend/controller/ErrorController.kt | 19 +++ .../ni/website/backend/service/AuthService.kt | 25 +--- .../website/backend/service/ErrorMessages.kt | 8 +- .../backend/controller/AuthControllerTest.kt | 115 +++++++++++++++++- 4 files changed, 136 insertions(+), 31 deletions(-) diff --git a/src/main/kotlin/pt/up/fe/ni/website/backend/controller/ErrorController.kt b/src/main/kotlin/pt/up/fe/ni/website/backend/controller/ErrorController.kt index 1546a9d1..8d11256e 100644 --- a/src/main/kotlin/pt/up/fe/ni/website/backend/controller/ErrorController.kt +++ b/src/main/kotlin/pt/up/fe/ni/website/backend/controller/ErrorController.kt @@ -10,6 +10,8 @@ import org.springframework.http.HttpStatus import org.springframework.http.converter.HttpMessageNotReadableException import org.springframework.security.access.AccessDeniedException import org.springframework.security.core.AuthenticationException +import org.springframework.security.oauth2.jwt.BadJwtException +import org.springframework.security.oauth2.jwt.JwtValidationException import org.springframework.web.bind.MethodArgumentNotValidException import org.springframework.web.bind.annotation.ExceptionHandler import org.springframework.web.bind.annotation.RequestMapping @@ -18,6 +20,7 @@ import org.springframework.web.bind.annotation.RestController import org.springframework.web.bind.annotation.RestControllerAdvice import org.springframework.web.multipart.MaxUploadSizeExceededException import pt.up.fe.ni.website.backend.config.Logging +import pt.up.fe.ni.website.backend.service.ErrorMessages data class SimpleError( val message: String, @@ -134,6 +137,22 @@ class ErrorController(private val objectMapper: ObjectMapper) : ErrorController, return wrapSimpleError(e.message ?: "invalid authentication") } + @ExceptionHandler(JwtValidationException::class) + @ResponseStatus(HttpStatus.UNAUTHORIZED) + fun invalidAuthentication(e: JwtValidationException): CustomError { + return if (e.message?.contains("expired") == true) { + wrapSimpleError(ErrorMessages.expiredToken) + } else { + wrapSimpleError(ErrorMessages.invalidToken) + } + } + + @ExceptionHandler(BadJwtException::class) + @ResponseStatus(HttpStatus.UNAUTHORIZED) + fun invalidAuthentication(e: BadJwtException): CustomError { + return wrapSimpleError(ErrorMessages.invalidToken) + } + fun wrapSimpleError(msg: String, param: String? = null, value: Any? = null) = CustomError( mutableListOf(SimpleError(msg, param, value)) ) diff --git a/src/main/kotlin/pt/up/fe/ni/website/backend/service/AuthService.kt b/src/main/kotlin/pt/up/fe/ni/website/backend/service/AuthService.kt index 7aa002d8..dfef1623 100644 --- a/src/main/kotlin/pt/up/fe/ni/website/backend/service/AuthService.kt +++ b/src/main/kotlin/pt/up/fe/ni/website/backend/service/AuthService.kt @@ -45,15 +45,7 @@ class AuthService( } fun refreshAccessToken(refreshToken: String): String { - val jwt = - try { - jwtDecoder.decode(refreshToken) - } catch (e: Exception) { - throw InvalidBearerTokenException(ErrorMessages.invalidRefreshToken) - } - if (jwt.expiresAt?.isBefore(Instant.now()) != false) { - throw InvalidBearerTokenException(ErrorMessages.expiredRefreshToken) - } + val jwt = jwtDecoder.decode(refreshToken) val account = accountService.getAccountByEmail(jwt.subject) return generateAccessToken(account) } @@ -72,23 +64,14 @@ class AuthService( } fun confirmRecoveryToken(recoveryToken: String, dto: PasswordRecoveryConfirmDto): Account { - val jwt = - try { - jwtDecoder.decode(recoveryToken) - } catch (e: Exception) { - throw InvalidBearerTokenException(ErrorMessages.invalidRecoveryToken) - } - - if (jwt.expiresAt?.isBefore(Instant.now()) != false) { - throw InvalidBearerTokenException(ErrorMessages.expiredRecoveryToken) - } + val jwt = jwtDecoder.decode(recoveryToken) val account = accountService.getAccountByEmail(jwt.subject) val tokenPasswordHash = jwt.getClaim("passwordHash") - ?: throw InvalidBearerTokenException(ErrorMessages.invalidRecoveryToken) + ?: throw InvalidBearerTokenException(ErrorMessages.invalidToken) if (account.password != tokenPasswordHash) { - throw InvalidBearerTokenException(ErrorMessages.expiredRecoveryToken) + throw InvalidBearerTokenException(ErrorMessages.invalidToken) } account.password = passwordEncoder.encode(dto.password) diff --git a/src/main/kotlin/pt/up/fe/ni/website/backend/service/ErrorMessages.kt b/src/main/kotlin/pt/up/fe/ni/website/backend/service/ErrorMessages.kt index cd55189d..ea20ceb8 100644 --- a/src/main/kotlin/pt/up/fe/ni/website/backend/service/ErrorMessages.kt +++ b/src/main/kotlin/pt/up/fe/ni/website/backend/service/ErrorMessages.kt @@ -7,13 +7,9 @@ object ErrorMessages { const val invalidCredentials = "invalid credentials" - const val invalidRefreshToken = "invalid refresh token" + const val invalidToken = "invalid token" - const val expiredRefreshToken = "refresh token has expired" - - const val invalidRecoveryToken = "invalid password recovery token" - - const val expiredRecoveryToken = "password recovery token has expired" + const val expiredToken = "token has expired" const val noGenerations = "no generations created yet" diff --git a/src/test/kotlin/pt/up/fe/ni/website/backend/controller/AuthControllerTest.kt b/src/test/kotlin/pt/up/fe/ni/website/backend/controller/AuthControllerTest.kt index 16eb4bc7..8dabdbcd 100644 --- a/src/test/kotlin/pt/up/fe/ni/website/backend/controller/AuthControllerTest.kt +++ b/src/test/kotlin/pt/up/fe/ni/website/backend/controller/AuthControllerTest.kt @@ -4,6 +4,8 @@ import com.epages.restdocs.apispec.HeaderDescriptorWithType import com.epages.restdocs.apispec.ResourceDocumentation import com.epages.restdocs.apispec.ResourceDocumentation.headerWithName import com.fasterxml.jackson.databind.ObjectMapper +import java.time.Instant +import java.time.temporal.ChronoUnit import java.util.Calendar import org.hamcrest.Matchers.startsWith import org.junit.jupiter.api.BeforeAll @@ -18,6 +20,10 @@ import org.springframework.restdocs.mockmvc.RestDocumentationRequestBuilders.get import org.springframework.restdocs.mockmvc.RestDocumentationRequestBuilders.post import org.springframework.restdocs.payload.JsonFieldType import org.springframework.security.crypto.password.PasswordEncoder +import org.springframework.security.oauth2.jwt.JwtClaimsSet +import org.springframework.security.oauth2.jwt.JwtDecoder +import org.springframework.security.oauth2.jwt.JwtEncoder +import org.springframework.security.oauth2.jwt.JwtEncoderParameters import org.springframework.test.web.servlet.MockMvc import org.springframework.test.web.servlet.post import org.springframework.test.web.servlet.result.MockMvcResultMatchers.content @@ -29,7 +35,6 @@ import pt.up.fe.ni.website.backend.model.Account import pt.up.fe.ni.website.backend.model.CustomWebsite import pt.up.fe.ni.website.backend.model.constants.AccountConstants import pt.up.fe.ni.website.backend.repository.AccountRepository -import pt.up.fe.ni.website.backend.service.ErrorMessages import pt.up.fe.ni.website.backend.utils.TestUtils import pt.up.fe.ni.website.backend.utils.ValidationTester import pt.up.fe.ni.website.backend.utils.annotations.ControllerTest @@ -52,6 +57,8 @@ class AuthControllerTest @Autowired constructor( val repository: AccountRepository, val mockMvc: MockMvc, val objectMapper: ObjectMapper, + val jwtEncoder: JwtEncoder, + val jwtDecoder: JwtDecoder, passwordEncoder: PasswordEncoder ) { final val testPassword = "testPassword" @@ -162,7 +169,7 @@ class AuthControllerTest @Autowired constructor( ) .andExpectAll( status().isUnauthorized, - jsonPath("$.errors[0].message").value("invalid refresh token") + jsonPath("$.errors[0].message").value("invalid token") ) .andDocumentErrorResponse(documentation, hasRequestPayload = true) } @@ -336,7 +343,7 @@ class AuthControllerTest @Autowired constructor( ).andExpectAll( status().isUnauthorized(), jsonPath("$.errors.length()").value(1), - jsonPath("$.errors[0].message").value("invalid password recovery token") + jsonPath("$.errors[0].message").value("invalid token") ).andDocumentCustomRequestSchemaErrorResponse( documentation, passwordRecoveryPayload, @@ -357,6 +364,106 @@ class AuthControllerTest @Autowired constructor( }.andExpect { status { isUnauthorized() } } } + @Test + fun `should fail when token is expired`() { + mockMvc.perform( + post("/auth/password/recovery") + .contentType(MediaType.APPLICATION_JSON) + .content( + objectMapper.writeValueAsString( + mapOf( + "email" to testAccount.email + ) + ) + ) + ) + .andReturn().response.let { authResponse -> + val token = objectMapper.readTree(authResponse.contentAsString)["recovery_url"].asText() + .removePrefix("$recoverPasswordPage/") + .removeSuffix("/confirm") + + val decoded = jwtDecoder.decode(token) + val newClaims = mutableMapOf() + newClaims.putAll(decoded.claims) + + val claimsBuilder = JwtClaimsSet + .builder() + .issuer("self") + .issuedAt(Instant.now().minus(2, ChronoUnit.DAYS)) + .expiresAt(Instant.now().minus(1, ChronoUnit.DAYS)) + .subject(decoded.subject) + .claim("scope", decoded.claims["scope"]) + + val newToken = jwtEncoder.encode(JwtEncoderParameters.from(claimsBuilder.build())).tokenValue + + mockMvc.perform( + post("/auth/password/recovery/{token}/confirm", newToken) + .contentType(MediaType.APPLICATION_JSON) + .content( + objectMapper.writeValueAsString( + mapOf( + "password" to newPassword + ) + ) + ) + ).andExpectAll( + status().isUnauthorized(), + jsonPath("$.errors.length()").value(1), + jsonPath("$.errors[0].message").value("token has expired") + ) + } + } + + @Test + fun `should fail when password hash claim is missing`() { + mockMvc.perform( + post("/auth/password/recovery") + .contentType(MediaType.APPLICATION_JSON) + .content( + objectMapper.writeValueAsString( + mapOf( + "email" to testAccount.email + ) + ) + ) + ) + .andReturn().response.let { authResponse -> + val token = objectMapper.readTree(authResponse.contentAsString)["recovery_url"].asText() + .removePrefix("$recoverPasswordPage/") + .removeSuffix("/confirm") + + val decoded = jwtDecoder.decode(token) + val newClaims = mutableMapOf() + newClaims.putAll(decoded.claims) + + val claimsBuilder = JwtClaimsSet + .builder() + .issuer("self") + .issuedAt(decoded.issuedAt) + .expiresAt(decoded.expiresAt) + .subject(decoded.subject) + .claim("scope", decoded.claims["scope"]) + + val newToken = jwtEncoder.encode(JwtEncoderParameters.from(claimsBuilder.build())).tokenValue + + mockMvc.perform( + post("/auth/password/recovery/{token}/confirm", newToken) + .contentType(MediaType.APPLICATION_JSON) + .content( + objectMapper.writeValueAsString( + mapOf( + "password" to newPassword + ) + ) + ) + ).andExpectAll( + status().isUnauthorized(), + jsonPath("$.errors.length()").value(1), + jsonPath("$.errors[0].message").value("invalid token") + ) + } + } + @Test fun `should fail when using recovery token twice`() { mockMvc.perform( @@ -402,7 +509,7 @@ class AuthControllerTest @Autowired constructor( ).andExpectAll( status().isUnauthorized(), jsonPath("$.errors.length()").value(1), - jsonPath("$.errors[0].message").value(ErrorMessages.expiredRecoveryToken) + jsonPath("$.errors[0].message").value("invalid token") ).andDocumentCustomRequestSchemaErrorResponse( documentation, passwordRecoveryPayload,