Skip to content
This repository has been archived by the owner on Dec 4, 2023. It is now read-only.

Commit

Permalink
Merge pull request #584 from microsoft/trboehre/jwt-cert
Browse files Browse the repository at this point in the history
Merge pull request #567 from microsoft/trboehre/expiredcert
  • Loading branch information
tracyboehrer authored Jun 8, 2020
2 parents 934db11 + 2c5df81 commit 3aba19f
Show file tree
Hide file tree
Showing 10 changed files with 445 additions and 118 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.microsoft.bot.connector.authentication;

import com.auth0.jwk.Jwk;
import com.auth0.jwk.JwkException;
import com.auth0.jwk.SigningKeyNotFoundException;
import com.auth0.jwk.UrlJwkProvider;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;

import java.io.IOException;
import java.net.URL;
import java.security.interfaces.RSAPublicKey;
import java.time.Duration;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* Maintains a cache of OpenID metadata keys.
*/
class CachingOpenIdMetadata implements OpenIdMetadata {
private static final Logger LOGGER = LoggerFactory.getLogger(CachingOpenIdMetadata.class);
private static final int CACHE_DAYS = 5;

private String url;
private long lastUpdated;
private ObjectMapper mapper;
private Map<String, Jwk> keyCache = new HashMap<>();
private final Object sync = new Object();

/**
* Constructs a OpenIdMetaData cache for a url.
*
* @param withUrl The url.
*/
CachingOpenIdMetadata(String withUrl) {
url = withUrl;
mapper = new ObjectMapper().findAndRegisterModules();
}

/**
* Gets a openid key.
*
* <p>
* Note: This could trigger a cache refresh, which will incur network calls.
* </p>
*
* @param keyId The JWT key.
* @return The cached key.
*/
@Override
public OpenIdMetadataKey getKey(String keyId) {
synchronized (sync) {
// If keys are more than 5 days old, refresh them
if (lastUpdated < System.currentTimeMillis() - Duration.ofDays(CACHE_DAYS).toMillis()) {
refreshCache();
}

// Search the cache even if we failed to refresh
return findKey(keyId);
}
}

private void refreshCache() {
keyCache.clear();

try {
URL openIdUrl = new URL(this.url);
HashMap<String, String> openIdConf =
this.mapper.readValue(openIdUrl, new TypeReference<HashMap<String, Object>>() {
});
URL keysUrl = new URL(openIdConf.get("jwks_uri"));
lastUpdated = System.currentTimeMillis();
UrlJwkProvider provider = new UrlJwkProvider(keysUrl);
keyCache = provider.getAll().stream().collect(Collectors.toMap(Jwk::getId, jwk -> jwk));
} catch (IOException e) {
LOGGER.error(String.format("Failed to load openID config: %s", e.getMessage()));
lastUpdated = 0;
} catch (SigningKeyNotFoundException keyexception) {
LOGGER.error("refreshCache", keyexception);
lastUpdated = 0;
}
}

@SuppressWarnings("unchecked")
private OpenIdMetadataKey findKey(String keyId) {
if (!keyCache.containsKey(keyId)) {
LOGGER.warn("findKey: keyId " + keyId + " doesn't exist.");
return null;
}

try {
Jwk jwk = keyCache.get(keyId);
OpenIdMetadataKey key = new OpenIdMetadataKey();
key.key = (RSAPublicKey) jwk.getPublicKey();
key.endorsements = (List<String>) jwk.getAdditionalAttributes().get("endorsements");
key.certificateChain = jwk.getCertificateChain();
return key;
} catch (JwkException e) {
String errorDescription = String.format("Failed to load keys: %s", e.getMessage());
LOGGER.warn(errorDescription);
}
return null;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.microsoft.bot.connector.authentication;

import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

/**
* Maintains a cache of OpenIdMetadata objects.
*/
public class CachingOpenIdMetadataResolver implements OpenIdMetadataResolver {
private static final ConcurrentMap<String, CachingOpenIdMetadata> OPENID_METADATA_CACHE =
new ConcurrentHashMap<>();

/**
* Gets the OpenIdMetadata object for the specified key.
* @param metadataUrl The key
* @return The OpenIdMetadata object. If the key is not found, an new OpenIdMetadata
* object is created.
*/
@Override
public OpenIdMetadata get(String metadataUrl) {
return OPENID_METADATA_CACHE
.computeIfAbsent(metadataUrl, key -> new CachingOpenIdMetadata(metadataUrl));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ private EmulatorValidation() {
* TO BOT FROM EMULATOR: Token validation parameters when connecting to a
* channel.
*/
private static final TokenValidationParameters TOKENVALIDATIONPARAMETERS =
public static final TokenValidationParameters TOKENVALIDATIONPARAMETERS =
new TokenValidationParameters() {
{
this.validateIssuer = true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ public final class GovernmentChannelValidation {
* TO BOT FROM GOVERNMENT CHANNEL: Token validation parameters when connecting
* to a bot.
*/
private static final TokenValidationParameters TOKENVALIDATIONPARAMETERS =
public static final TokenValidationParameters TOKENVALIDATIONPARAMETERS =
new TokenValidationParameters() {
{
this.validateIssuer = true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,28 @@
import com.auth0.jwt.interfaces.DecodedJWT;
import com.auth0.jwt.interfaces.Verification;
import com.microsoft.bot.connector.ExecutorFactory;
import java.io.ByteArrayInputStream;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.util.Base64;
import java.util.Date;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

/**
* Extracts relevant data from JWT Tokens.
*/
public class JwtTokenExtractor {
private static final Logger LOGGER = LoggerFactory.getLogger(OpenIdMetadata.class);
private static final ConcurrentMap<String, OpenIdMetadata> OPENID_METADATA_CACHE =
new ConcurrentHashMap<>();
private static final Logger LOGGER = LoggerFactory.getLogger(CachingOpenIdMetadata.class);

private TokenValidationParameters tokenValidationParameters;
private List<String> allowedSigningAlgorithms;
private OpenIdMetadataResolver openIdMetadataResolver;
private OpenIdMetadata openIdMetadata;

/**
Expand All @@ -43,13 +45,18 @@ public JwtTokenExtractor(
String withMetadataUrl,
List<String> withAllowedSigningAlgorithms
) {

this.tokenValidationParameters =
new TokenValidationParameters(withTokenValidationParameters);
this.tokenValidationParameters.requireSignedTokens = true;
this.allowedSigningAlgorithms = withAllowedSigningAlgorithms;
this.openIdMetadata = OPENID_METADATA_CACHE
.computeIfAbsent(withMetadataUrl, key -> new OpenIdMetadata(withMetadataUrl));

if (tokenValidationParameters.issuerSigningKeyResolver == null) {
this.openIdMetadataResolver = new CachingOpenIdMetadataResolver();
} else {
this.openIdMetadataResolver = tokenValidationParameters.issuerSigningKeyResolver;
}

this.openIdMetadata = this.openIdMetadataResolver.get(withMetadataUrl);
}

/**
Expand Down Expand Up @@ -143,13 +150,27 @@ private CompletableFuture<ClaimsIdentity> validateToken(
try {
verification.build().verify(token);

// If specified, validate the signing certificate.
if (
tokenValidationParameters.validateIssuerSigningKey
&& key.certificateChain != null
&& key.certificateChain.size() > 0
) {
// Note that decodeCertificate will return null if the cert could not
// be decoded. This would likely be the case if it were in an unexpected
// encoding. Going to err on the side of ignoring this check.
// May want to reconsider this and throw on null cert.
X509Certificate cert = decodeCertificate(key.certificateChain.get(0));
if (cert != null && !isCertValid(cert)) {
throw new JWTVerificationException("Signing certificate is not valid");
}
}

// Note: On the Emulator Code Path, the endorsements collection is null so the
// validation code
// below won't run. This is normal.
// validation code below won't run. This is normal.
if (key.endorsements != null) {
// Validate Channel / Token Endorsements. For this, the channelID present on the
// Activity
// needs to be matched by an endorsement.
// Activity needs to be matched by an endorsement.
boolean isEndorsed =
EndorsementsValidator.validate(channelId, key.endorsements);
if (!isEndorsed) {
Expand All @@ -162,8 +183,7 @@ private CompletableFuture<ClaimsIdentity> validateToken(
}

// Verify that additional endorsements are satisfied. If no additional
// endorsements are expected,
// the requirement is satisfied as well
// endorsements are expected, the requirement is satisfied as well
boolean additionalEndorsementsSatisfied = requiredEndorsements.stream()
.allMatch(
(endorsement) -> EndorsementsValidator
Expand Down Expand Up @@ -195,4 +215,22 @@ private CompletableFuture<ClaimsIdentity> validateToken(
}
}, ExecutorFactory.getExecutor());
}

private X509Certificate decodeCertificate(String certStr) {
try {
byte[] decoded = Base64.getDecoder().decode(certStr);
return (X509Certificate) CertificateFactory
.getInstance("X.509").generateCertificate(new ByteArrayInputStream(decoded));
} catch (Throwable t) {
return null;
}
}

private boolean isCertValid(X509Certificate cert) {
long now = new Date().getTime();
long clockskew = tokenValidationParameters.clockSkew.toMillis();
long startValid = cert.getNotBefore().getTime() - clockskew;
long endValid = cert.getNotAfter().getTime() + clockskew;
return now >= startValid && now <= endValid;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,108 +3,15 @@

package com.microsoft.bot.connector.authentication;

import com.auth0.jwk.Jwk;
import com.auth0.jwk.JwkException;
import com.auth0.jwk.SigningKeyNotFoundException;
import com.auth0.jwk.UrlJwkProvider;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;

import java.io.IOException;
import java.net.URL;
import java.security.interfaces.RSAPublicKey;
import java.time.Duration;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* Maintains a cache of OpenID metadata keys.
* Fetches Jwk data.
*/
class OpenIdMetadata {
private static final Logger LOGGER = LoggerFactory.getLogger(OpenIdMetadata.class);
private static final int CACHE_DAYS = 5;

private String url;
private long lastUpdated;
private ObjectMapper mapper;
private Map<String, Jwk> keyCache = new HashMap<>();
private final Object sync = new Object();

/**
* Constructs a OpenIdMetaData cache for a url.
*
* @param withUrl The url.
*/
OpenIdMetadata(String withUrl) {
url = withUrl;
mapper = new ObjectMapper().findAndRegisterModules();
}
public interface OpenIdMetadata {

/**
* Gets a openid key.
*
* <p>
* Note: This could trigger a cache refresh, which will incur network calls.
* </p>
*
* @param keyId The JWT key.
* @return The cached key.
* Returns the partial Jwk data for a key.
* @param keyId The key id.
* @return The Jwk data.
*/
public OpenIdMetadataKey getKey(String keyId) {
synchronized (sync) {
// If keys are more than 5 days old, refresh them
if (lastUpdated < System.currentTimeMillis() - Duration.ofDays(CACHE_DAYS).toMillis()) {
refreshCache();
}

// Search the cache even if we failed to refresh
return findKey(keyId);
}
}

private void refreshCache() {
keyCache.clear();

try {
URL openIdUrl = new URL(this.url);
HashMap<String, String> openIdConf =
this.mapper.readValue(openIdUrl, new TypeReference<HashMap<String, Object>>() {
});
URL keysUrl = new URL(openIdConf.get("jwks_uri"));
lastUpdated = System.currentTimeMillis();
UrlJwkProvider provider = new UrlJwkProvider(keysUrl);
keyCache = provider.getAll().stream().collect(Collectors.toMap(Jwk::getId, jwk -> jwk));
} catch (IOException e) {
LOGGER.error(String.format("Failed to load openID config: %s", e.getMessage()));
lastUpdated = 0;
} catch (SigningKeyNotFoundException keyexception) {
LOGGER.error("refreshCache", keyexception);
lastUpdated = 0;
}
}

@SuppressWarnings("unchecked")
private OpenIdMetadataKey findKey(String keyId) {
if (!keyCache.containsKey(keyId)) {
LOGGER.warn("findKey: keyId " + keyId + " doesn't exist.");
return null;
}

try {
Jwk jwk = keyCache.get(keyId);
OpenIdMetadataKey key = new OpenIdMetadataKey();
key.key = (RSAPublicKey) jwk.getPublicKey();
key.endorsements = (List<String>) jwk.getAdditionalAttributes().get("endorsements");
return key;
} catch (JwkException e) {
String errorDescription = String.format("Failed to load keys: %s", e.getMessage());
LOGGER.warn(errorDescription);
}
return null;
}
OpenIdMetadataKey getKey(String keyId);
}
Loading

0 comments on commit 3aba19f

Please sign in to comment.