From 05406ec06975fd6547a9fb637c6fdd17e85ececa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erick=20Dur=C3=A1n?= <23424469+erickduran@users.noreply.github.com> Date: Fri, 16 Dec 2022 17:44:20 -0800 Subject: [PATCH] Add private key cache and remove unused code (#353) --- confidant/services/jwkmanager.py | 54 ++++++------------- .../confidant/services/jwkmanager_test.py | 27 ---------- 2 files changed, 15 insertions(+), 66 deletions(-) diff --git a/confidant/services/jwkmanager.py b/confidant/services/jwkmanager.py index e21309db..9265336b 100644 --- a/confidant/services/jwkmanager.py +++ b/confidant/services/jwkmanager.py @@ -1,10 +1,7 @@ import jwt from jwcrypto import jwk -from typing import Any, Dict, Optional -from hashlib import sha1 -from cryptography.x509 import load_pem_x509_certificate -from cryptography.hazmat.primitives import serialization +from typing import Dict, Optional from confidant.settings import CERTIFICATE_AUTHORITIES, \ DEFAULT_JWT_EXPIRATION_SECONDS, JWT_CACHING_ENABLED from confidant.utils import stats @@ -14,9 +11,8 @@ class JWKManager: def __init__(self) -> None: self._keys = jwk.JWKSet() - self._public_keys = {} self._token_cache = {} - self._payload_cache = {} + self._pem_cache = {} self._load_certificate_authorities() @@ -38,10 +34,21 @@ def set_key(self, kid: str, private_key: str, self._keys.add(key) return kid + def _get_key(self, kid: str): + if kid not in self._pem_cache: + # setting either way to avoid further lookups when response is None + self._pem_cache[kid] = self._keys.get_key(kid) + if self._pem_cache[kid]: + self._pem_cache[kid] = self._pem_cache[kid].export_to_pem( + private_key=True, + password=None + ) + return self._pem_cache[kid] + def get_jwt(self, kid: str, payload: dict, expiration_seconds: int = DEFAULT_JWT_EXPIRATION_SECONDS, algorithm: str = 'RS256') -> str: - key = self._keys.get_key(kid) + key = self._get_key(kid) if not key: raise ValueError('This private key is not stored!') @@ -70,11 +77,10 @@ def get_jwt(self, kid: str, payload: dict, }) with stats.timer('get_jwt.encode'): - # XXX: TODO: cache export_to_pem token = jwt.encode( payload=payload, headers={'kid': kid}, - key=key.export_to_pem(private_key=True, password=None), + key=key, algorithm=algorithm, ) @@ -85,36 +91,6 @@ def get_jwt(self, kid: str, payload: dict, stats.incr('get_jwt.create') return token - def _get_public_key(self, alias: str, certificate: str, - encoding: str = 'utf-8') -> bytes: - if alias not in self._public_keys: - imported_cert = load_pem_x509_certificate( - certificate.encode(encoding) - ) - public_key = imported_cert.public_key().public_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PublicFormat.SubjectPublicKeyInfo - ) - self._public_keys[alias] = public_key - return self._public_keys[alias] - - def get_payload(self, certificate: str, token: str, - encoding: str = 'utf-8') -> Dict[str, Any]: - certificate_hash = sha1(certificate.encode('utf-8')).hexdigest() - public_key = self._get_public_key(certificate_hash, certificate, - encoding=encoding) - token_hash = sha1(token.encode('utf-8')).hexdigest() - - if certificate_hash not in self._payload_cache: - self._payload_cache[certificate_hash] = {} - - if token_hash not in self._payload_cache[certificate_hash]: - headers = jwt.get_unverified_header(token) - self._payload_cache[certificate_hash][token_hash] = \ - jwt.decode(token, public_key, algorithms=headers['alg']) - - return self._payload_cache[certificate_hash][token_hash] - def get_jwks(self, key_id: str, algorithm: str = 'RS256') -> Dict[str, str]: key = self._keys.get_key(key_id) if key: diff --git a/tests/unit/confidant/services/jwkmanager_test.py b/tests/unit/confidant/services/jwkmanager_test.py index 9000bf17..2c802543 100644 --- a/tests/unit/confidant/services/jwkmanager_test.py +++ b/tests/unit/confidant/services/jwkmanager_test.py @@ -6,7 +6,6 @@ from unittest.mock import patch, Mock from confidant.services.jwkmanager import jwk_manager -from calendar import timegm def test_set_key(test_key_pair): @@ -126,32 +125,6 @@ def test_get_jwt_raises_no_key_id(test_key_pair, test_jwk_payload): jwk_manager.get_jwt('non-existent', test_jwk_payload) -@patch('jwt.api_jwt.PyJWT._validate_exp', return_value=True) -def test_get_payload(mock_validate, test_key_pair, test_jwk_payload, test_jwt, - test_certificate): - test_private_key = test_key_pair.export_to_pem(private_key=True, - password=None) - jwk_manager.set_key('test-key', test_private_key.decode('utf-8')) - result = jwk_manager.get_payload(test_certificate.decode('utf-8'), - test_jwt) - mocked_date = datetime.datetime( - year=2020, - month=10, - day=10, - hour=0, - minute=0, - second=0, - microsecond=0 - ) - test_jwk_payload.update({ - 'iat': timegm(mocked_date.utctimetuple()), - 'nbf': timegm(mocked_date.utctimetuple()), - 'exp': timegm((mocked_date + - datetime.timedelta(seconds=3600)).utctimetuple()), - }) - assert result == test_jwk_payload - - def test_get_jwks(test_key_pair, test_jwk_payload, test_jwt, test_jwks): test_private_key = test_key_pair.export_to_pem(private_key=True,