Skip to content

Commit

Permalink
Added debug information for expired jwt (#214)
Browse files Browse the repository at this point in the history
* Enhance JWT verification logs with detailed claim data

Added detailed claim data to logs on JWT expiration and not-before time checks. This improves the ability to diagnose issues by providing comprehensive context in error messages.

* Add Mockito dependencies and JWTClaimsSetVerifierWithLogsTest

This commit adds Mockito dependencies to the pom.xml file to facilitate mocking in unit tests. It also introduces the JWTClaimsSetVerifierWithLogsTest class to test JWT claim set verification, ensuring proper handling of expired and not-before JWT conditions.

* Add BadJOSEException handling for token processing

This commit introduces exception handling for BadJOSEException across the codebase, ensuring that invalid JWT tokens are appropriately handled. The changes include method signatures updates to propagate the exception and modifications in various services, controllers, and tests to handle the exception correctly and provide proper feedback in case of an error.

* Add MDC logging to JWTClaimsSetVerifierWithLogs

Introduce MDC logging for subject, issue time, and token ID in JWTClaimsSetVerifierWithLogs. This enhancement allows for better traceability and debugging by including these details in the log context.

* Enhance logging and error handling for token exchange and JWT

Added detailed trace ID and span ID headers in error responses across TokenExchangeController and JWTAuthenticationFilter. These changes improve debuggability by providing clearer error context and extended log information.

* Refactor SecretServerClient to support additional headers

Added support for passing additional headers to `getSecret` method in `SecretServerClient` and its implementations. Updated method signatures and internal logic to accommodate the new parameter, facilitating enhanced customization and control over secret retrieval.

* Add issued time to JWT claims in test cases

This ensures that the JWT claims set contains an issued time, which is necessary for some verifications. The additional issued time makes the test cases more comprehensive and accurate.

---------

Co-authored-by: marcelmeyer <[email protected]>
  • Loading branch information
Mme-adorsys and mme-flendly authored Sep 27, 2024
1 parent 6da662a commit 37c6aa5
Show file tree
Hide file tree
Showing 14 changed files with 104 additions and 67 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package de.adorsys.sts.secretserver;

import com.nimbusds.jose.proc.BadJOSEException;
import com.nimbusds.jwt.JWTClaimsSet;
import dasniko.testcontainers.keycloak.KeycloakContainer;
import de.adorsys.sts.keymanagement.service.DecryptionService;
Expand Down Expand Up @@ -99,7 +100,7 @@ static void afterAll() {
}

@Test
void shouldReturnTheSameSecretForSameUser() {
void shouldReturnTheSameSecretForSameUser() throws BadJOSEException {
String firstSecret = getDecryptedSecret(USERNAME_ONE, PASSWORD_ONE);
String secondSecret = getDecryptedSecret(USERNAME_ONE, PASSWORD_ONE);

Expand All @@ -115,7 +116,7 @@ void shouldReturnDifferentSecretsForDifferentUsers() throws Exception {
}

@Test
void shouldNotReturnTheSameTokenForSameUser() throws Exception {
void shouldNotReturnTheSameTokenForSameUser() {
TokenResponse firstTokenResponse = getSecretServerToken(USERNAME_ONE, PASSWORD_ONE);
assertThat(firstTokenResponse.getAccess_token(), is(notNullValue()));

Expand All @@ -126,7 +127,7 @@ void shouldNotReturnTheSameTokenForSameUser() throws Exception {
}

@Test
void shouldNotGetSecretForInvalidAccessToken() throws Exception {
void shouldNotGetSecretForInvalidAccessToken() {
final String invalidAccessToken = "eyJhbGciOiJSUzI1NiIsInR5cCIgOiAiSldUIiwia2lkIiA6ICJvVjU2Uk9namthbTVzUmVqdjF6b1JVNmY" +
"1R3YtUGRTdjN2b1ZfRVY5MmxnIn0.eyJqdGkiOiI5NWY2MzQ4NC04MTk2LTQ2NzYtYjI4Ni1lYjY4YTFmOTZmYTAiLCJleHAiOjE1N" +
"TUwNDg5MzIsIm5iZiI6MCwiaWF0IjoxNTU1MDQ4NjMyLCJpc3MiOiJodHRwOi8vbG9jYWxob3N0OjMyODU0L2F1dGgvcmVhbG1zL21" +
Expand All @@ -150,7 +151,7 @@ void shouldNotGetSecretForInvalidAccessToken() throws Exception {
}

@Test
void shouldNotGetSecretForFakeAccessToken() throws Exception {
void shouldNotGetSecretForFakeAccessToken() {
final String fakeAccessToken = "my fake access token";

catchException(() -> client.exchangeToken("/secret-server/token-exchange", MOPED_CLIENT_AUDIENCE, fakeAccessToken));
Expand All @@ -162,7 +163,7 @@ void shouldNotGetSecretForFakeAccessToken() throws Exception {
}

@Test
void shouldGetEmptySecretsForUnknownAudience() {
void shouldGetEmptySecretsForUnknownAudience() throws BadJOSEException {
Authentication.AuthenticationToken authToken = authentication.login(USERNAME_ONE, PASSWORD_ONE);

TokenResponse secretServerToken = client.exchangeToken("/secret-server/token-exchange", "unknown audience", authToken.getAccessToken());
Expand All @@ -171,17 +172,17 @@ void shouldGetEmptySecretsForUnknownAudience() {
assertThat(secrets.size(), is(equalTo(0)));
}

private String getDecryptedSecret(String username, String password) {
private String getDecryptedSecret(String username, String password) throws BadJOSEException {
TokenResponse secretServerToken = getSecretServerToken(username, password);
return extractSecretFromToken(secretServerToken.getAccess_token());
}

private String extractSecretFromToken(String secretServerAccessToken) {
private String extractSecretFromToken(String secretServerAccessToken) throws BadJOSEException {
Map<String, String> secrets = extractSecretsFromToken(secretServerAccessToken);
return decryptionService.decrypt(secrets.get(MOPED_CLIENT_AUDIENCE));
}

private Map<String, String> extractSecretsFromToken(String secretServerAccessToken) {
private Map<String, String> extractSecretsFromToken(String secretServerAccessToken) throws BadJOSEException {
BearerToken exchangedToken = bearerTokenValidator.extract(secretServerAccessToken);
JWTClaimsSet claims = exchangedToken.getClaims();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;

import java.util.Map;
import java.util.concurrent.TimeUnit;

public class CachingSecretServerClient implements SecretServerClient {
Expand All @@ -27,7 +28,7 @@ public String load(String token) throws Exception {
}

@Override
public String getSecret(String token) {
public String getSecret(String token, Map<String, String> headers) {
return secrets.getUnchecked(token);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Map;

public class LoggingSecretServerClient implements SecretServerClient {
private static final Logger logger = LoggerFactory.getLogger(LoggingSecretServerClient.class);

Expand All @@ -13,13 +15,14 @@ public LoggingSecretServerClient(SecretServerClient secretServerClient) {
}

@Override
public String getSecret(String token) {
public String getSecret(String token, Map<String, String> additionalHeaders) {
if(logger.isTraceEnabled()) logger.trace("get secret for token start...");

String secret = decoratedSecretServerClient.getSecret(token);
String secret = decoratedSecretServerClient.getSecret(token, additionalHeaders);

if(logger.isTraceEnabled()) logger.trace("get secret for token finished.");

return secret;
}

}
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
package de.adorsys.sts.secret;

import java.util.HashMap;
import java.util.Map;

public interface SecretServerClient {

/**
* Provides the decrypted BASE64 encoded secret for the user using the specified token.
*/
String getSecret(String token);
default String getSecret(String token) {
return getSecret(token, new HashMap<>());
}

/**
* Provides the decrypted BASE64 encoded secret for the user using the specified token.
*/
String getSecret(String token, Map<String, String> additionalHeaders);
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package de.adorsys.sts.filter;

import com.nimbusds.jose.proc.BadJOSEException;
import de.adorsys.sts.token.authentication.TokenAuthenticationService;
import jakarta.annotation.Nonnull;
import jakarta.servlet.FilterChain;
Expand Down Expand Up @@ -31,7 +32,14 @@ public void doFilterInternal(@Nonnull HttpServletRequest request, @Nonnull HttpS
if (logger.isDebugEnabled())
logger.debug("Authentication is null. Try to get authentication from request...");

authentication = tokenAuthenticationService.getAuthentication(request);
try {
authentication = tokenAuthenticationService.getAuthentication(request);
} catch (BadJOSEException e) {
response.setHeader("X-B3-TraceId", request.getHeader("X-B3-TraceId"));
response.setHeader("X-B3-SpanId", request.getHeader("X-B3-SpanId"));
response.sendError(HttpServletResponse.SC_FORBIDDEN, "Invalid token - Token expired");
}

SecurityContextHolder.getContext().setAuthentication(authentication);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package de.adorsys.sts.token.authentication;


import com.nimbusds.jose.proc.BadJOSEException;
import com.nimbusds.jwt.JWTClaimsSet;
import de.adorsys.sts.tokenauth.BearerToken;
import de.adorsys.sts.tokenauth.BearerTokenValidator;
Expand All @@ -14,6 +14,7 @@
import org.springframework.stereotype.Service;

import jakarta.servlet.http.HttpServletRequest;

import java.util.ArrayList;
import java.util.List;

Expand All @@ -31,16 +32,18 @@ public TokenAuthenticationService(BearerTokenValidator bearerTokenValidator) {
this.bearerTokenValidator = bearerTokenValidator;
}

public Authentication getAuthentication(HttpServletRequest request) {
public Authentication getAuthentication(HttpServletRequest request) throws BadJOSEException {
String headerValue = request.getHeader(HEADER_KEY);
if(StringUtils.isBlank(headerValue)) {
if(logger.isDebugEnabled()) logger.debug("Header value '{}' is blank.", HEADER_KEY);
if (StringUtils.isBlank(headerValue)) {
if (logger.isDebugEnabled())
logger.debug("Header value '{}' is blank.", HEADER_KEY);
return null;
}

// Accepts only Bearer token
if(!StringUtils.startsWithIgnoreCase(headerValue, TOKEN_PREFIX)) {
if(logger.isDebugEnabled()) logger.debug("Header value does not start with '{}'.", TOKEN_PREFIX);
if (!StringUtils.startsWithIgnoreCase(headerValue, TOKEN_PREFIX)) {
if (logger.isDebugEnabled())
logger.debug("Header value does not start with '{}'.", TOKEN_PREFIX);
return null;
}

Expand All @@ -50,7 +53,8 @@ public Authentication getAuthentication(HttpServletRequest request) {
BearerToken bearerToken = bearerTokenValidator.extract(strippedToken);

if (!bearerToken.isValid()) {
if(logger.isDebugEnabled()) logger.debug("Token is not valid.");
if (logger.isDebugEnabled())
logger.debug("Token is not valid.");
return null;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package de.adorsys.sts.token.tokenexchange.server;

import com.nimbusds.jose.proc.BadJOSEException;
import de.adorsys.sts.ResponseUtils;
import de.adorsys.sts.token.InvalidParameterException;
import de.adorsys.sts.token.MissingParameterException;
Expand All @@ -17,6 +18,7 @@
import lombok.AccessLevel;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.ModelAttribute;
Expand All @@ -29,13 +31,14 @@ public class TokenExchangeController {

private final TokenExchangeService tokenExchangeService;

@PostMapping(consumes = {MediaType.APPLICATION_FORM_URLENCODED_VALUE}, produces = {MediaType.APPLICATION_JSON_VALUE})
@PostMapping(consumes = { MediaType.APPLICATION_FORM_URLENCODED_VALUE }, produces = { MediaType.APPLICATION_JSON_VALUE })
@Operation(summary = "Exchange Token", description = "Create an access or refresh token given a valide subject token.", responses = {
@ApiResponse(responseCode = "200", description = "Ok", content = @Content(mediaType = "application/json", schema = @Schema(implementation = TokenResponse.class))),
@ApiResponse(responseCode = "400", description = "Bad request", headers = @Header(name = "error", description = "invalid request"))
})
public ResponseEntity<Object> tokenExchange(@RequestBody @ModelAttribute TokenRequestForm tokenRequestForm, HttpServletRequest servletRequest) {
if (log.isTraceEnabled()) log.trace("POST tokenExchange started...");
@ApiResponse(responseCode = "400", description = "Bad request", headers = @Header(name = "error", description = "invalid request")) })
public ResponseEntity<Object> tokenExchange(@RequestBody @ModelAttribute TokenRequestForm tokenRequestForm,
HttpServletRequest servletRequest) {
if (log.isTraceEnabled())
log.trace("POST tokenExchange started...");

TokenExchangeRequest tokenExchange = getTokenExchangeRequest(tokenRequestForm, servletRequest);

Expand All @@ -53,23 +56,21 @@ public ResponseEntity<Object> tokenExchange(@RequestBody @ModelAttribute TokenRe
errorMessage = e.getMessage();
ResponseEntity<Object> errorData = ResponseUtils.invalidParam(e.getMessage());
return ResponseEntity.badRequest().body(errorData);
} catch (BadJOSEException e) {
return ResponseEntity.status(HttpStatus.FORBIDDEN).header("source", "sts")
.header("X-B3-TraceId", servletRequest.getHeader("X-B3-TraceId"))
.header("X-B3-SpanId", servletRequest.getHeader("X-B3-SpanId")).body(e.getMessage());
} finally {
if (log.isTraceEnabled()) log.trace("POST tokenExchange finished: {}", errorMessage);
if (log.isTraceEnabled())
log.trace("POST tokenExchange finished: {}", errorMessage);
}
}

private static TokenExchangeRequest getTokenExchangeRequest(TokenRequestForm tokenRequestForm, HttpServletRequest servletRequest) {
return TokenExchangeRequest.builder()
.grantType(tokenRequestForm.getGrantType())
.resources(tokenRequestForm.getResources())
.subjectToken(tokenRequestForm.getSubjectToken())
.subjectTokenType(tokenRequestForm.getSubjectTokenType())
.actorToken(tokenRequestForm.getActorToken())
.actorTokenType(tokenRequestForm.getActorTokenType())
.issuer(ResponseUtils.getIssuer(servletRequest))
.scope(tokenRequestForm.getScope())
.requestedTokenType(tokenRequestForm.getRequestedTokenType())
.audiences(tokenRequestForm.getAudiences())
.build();
return TokenExchangeRequest.builder().grantType(tokenRequestForm.getGrantType()).resources(tokenRequestForm.getResources())
.subjectToken(tokenRequestForm.getSubjectToken()).subjectTokenType(tokenRequestForm.getSubjectTokenType())
.actorToken(tokenRequestForm.getActorToken()).actorTokenType(tokenRequestForm.getActorTokenType())
.issuer(ResponseUtils.getIssuer(servletRequest)).scope(tokenRequestForm.getScope())
.requestedTokenType(tokenRequestForm.getRequestedTokenType()).audiences(tokenRequestForm.getAudiences()).build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public BearerTokenValidator(AuthServersProvider authServersProvider, Clock clock
this.clock = clock;
}

public BearerToken extract(String token) {
public BearerToken extract(String token) throws BadJOSEException {
Optional<JWTClaimsSet> jwtClaimsSet = extractClaims(token);
if(jwtClaimsSet.isPresent()) {
List<String> roles = extractRoles(jwtClaimsSet.get());
Expand Down Expand Up @@ -84,7 +84,7 @@ protected void onErrorWhileExtractClaims(String token, Throwable e) {
logger.error("token parse exception");
}

private Optional<JWTClaimsSet> extractClaims(String token) {
private Optional<JWTClaimsSet> extractClaims(String token) throws BadJOSEException {
Optional<JWTClaimsSet> jwtClaimsSet = Optional.empty();

if(token == null) {
Expand Down Expand Up @@ -124,7 +124,7 @@ private Optional<JWTClaimsSet> extractClaims(String token) {
JWTClaimsSet jwtClaims = jwtProcessor.process(signedJWT, context);

jwtClaimsSet = Optional.of(jwtClaims);
} catch (ParseException | BadJOSEException | JOSEException e) {
} catch (ParseException | JOSEException e) {
onErrorWhileExtractClaims(token, e);
return jwtClaimsSet;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
import lombok.RequiredArgsConstructor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.MDC;

import java.time.Clock;
import java.util.Date;
import java.util.HashMap;
import java.util.Map;

@RequiredArgsConstructor
Expand All @@ -28,22 +30,22 @@ public void verify(JWTClaimsSet claimsSet, SecurityContext context) throws BadJW
final Date now = Date.from(clock.instant());

final Date exp = claimsSet.getExpirationTime();

Map<String, Object> claimSet = claimsSet.toPayload().toJSONObject();
MDC.put("sub", claimsSet.getSubject());
MDC.put("iss-at", claimsSet.getIssueTime().toInstant().toString());
MDC.put("token-id", claimsSet.getJWTID());

if (exp != null && !DateUtils.isAfter(exp, now, DEFAULT_MAX_CLOCK_SKEW_SECONDS)) {
String msg = "Expired JWT";
String msg = "Expired JWT - expiration time claim (exp) is not after the current time";
logger.error("{}: expiration time: {} now: {}", msg, exp, now);
logger.error("JWT claims: {}", claimSet);

throw new BadJWTException(msg);
}

final Date nbf = claimsSet.getNotBeforeTime();

if (nbf != null && !DateUtils.isBefore(nbf, now, DEFAULT_MAX_CLOCK_SKEW_SECONDS)) {
String msg = "JWT before use time";
String msg = "JWT before use time- not before claim (nbf) is after the current time";
logger.error("{}: not before time: {} now: {}", msg, nbf, now);
logger.error("JWT claims: {}", claimSet);
throw new BadJWTException(msg);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ void verify() {
@Test
public void testVerify_throwsBadJWTException_whenJWTIsExpired() {
Date exp = new Date(System.currentTimeMillis() - 60000);
JWTClaimsSet claimsSet = new JWTClaimsSet.Builder().expirationTime(exp).build();
JWTClaimsSet claimsSet = new JWTClaimsSet.Builder().expirationTime(exp).issueTime(new Date()).build();
when(clock.instant()).thenReturn(Instant.now());
assertThrows(BadJWTException.class, () -> {
underTest.verify(claimsSet, null);
Expand All @@ -47,7 +47,7 @@ public void testVerify_throwsBadJWTException_whenJWTIsExpired() {
@Test
public void testVerify_throwsBadJWTException_whenJWTIsNotBeforeNow() {
Date nbf = new Date(System.currentTimeMillis() + 61000);
JWTClaimsSet claimsSet = new JWTClaimsSet.Builder().notBeforeTime(nbf).build();
JWTClaimsSet claimsSet = new JWTClaimsSet.Builder().notBeforeTime(nbf).issueTime(new Date()).build();
when(clock.instant()).thenReturn(Instant.now());
assertThrows(BadJWTException.class, () -> {
underTest.verify(claimsSet, null);
Expand Down
Loading

0 comments on commit 37c6aa5

Please sign in to comment.