Skip to content

Commit

Permalink
Add private key cache and remove unused code (#353)
Browse files Browse the repository at this point in the history
  • Loading branch information
erickduran authored Dec 17, 2022
1 parent d2b51e8 commit 05406ec
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 66 deletions.
54 changes: 15 additions & 39 deletions confidant/services/jwkmanager.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import jwt

from jwcrypto import jwk
from typing import Any, Dict, Optional
from hashlib import sha1
from cryptography.x509 import load_pem_x509_certificate
from cryptography.hazmat.primitives import serialization
from typing import Dict, Optional
from confidant.settings import CERTIFICATE_AUTHORITIES, \
DEFAULT_JWT_EXPIRATION_SECONDS, JWT_CACHING_ENABLED
from confidant.utils import stats
Expand All @@ -14,9 +11,8 @@
class JWKManager:
def __init__(self) -> None:
self._keys = jwk.JWKSet()
self._public_keys = {}
self._token_cache = {}
self._payload_cache = {}
self._pem_cache = {}

self._load_certificate_authorities()

Expand All @@ -38,10 +34,21 @@ def set_key(self, kid: str, private_key: str,
self._keys.add(key)
return kid

def _get_key(self, kid: str):
if kid not in self._pem_cache:
# setting either way to avoid further lookups when response is None
self._pem_cache[kid] = self._keys.get_key(kid)
if self._pem_cache[kid]:
self._pem_cache[kid] = self._pem_cache[kid].export_to_pem(
private_key=True,
password=None
)
return self._pem_cache[kid]

def get_jwt(self, kid: str, payload: dict,
expiration_seconds: int = DEFAULT_JWT_EXPIRATION_SECONDS,
algorithm: str = 'RS256') -> str:
key = self._keys.get_key(kid)
key = self._get_key(kid)
if not key:
raise ValueError('This private key is not stored!')

Expand Down Expand Up @@ -70,11 +77,10 @@ def get_jwt(self, kid: str, payload: dict,
})

with stats.timer('get_jwt.encode'):
# XXX: TODO: cache export_to_pem
token = jwt.encode(
payload=payload,
headers={'kid': kid},
key=key.export_to_pem(private_key=True, password=None),
key=key,
algorithm=algorithm,
)

Expand All @@ -85,36 +91,6 @@ def get_jwt(self, kid: str, payload: dict,
stats.incr('get_jwt.create')
return token

def _get_public_key(self, alias: str, certificate: str,
encoding: str = 'utf-8') -> bytes:
if alias not in self._public_keys:
imported_cert = load_pem_x509_certificate(
certificate.encode(encoding)
)
public_key = imported_cert.public_key().public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo
)
self._public_keys[alias] = public_key
return self._public_keys[alias]

def get_payload(self, certificate: str, token: str,
encoding: str = 'utf-8') -> Dict[str, Any]:
certificate_hash = sha1(certificate.encode('utf-8')).hexdigest()
public_key = self._get_public_key(certificate_hash, certificate,
encoding=encoding)
token_hash = sha1(token.encode('utf-8')).hexdigest()

if certificate_hash not in self._payload_cache:
self._payload_cache[certificate_hash] = {}

if token_hash not in self._payload_cache[certificate_hash]:
headers = jwt.get_unverified_header(token)
self._payload_cache[certificate_hash][token_hash] = \
jwt.decode(token, public_key, algorithms=headers['alg'])

return self._payload_cache[certificate_hash][token_hash]

def get_jwks(self, key_id: str, algorithm: str = 'RS256') -> Dict[str, str]:
key = self._keys.get_key(key_id)
if key:
Expand Down
27 changes: 0 additions & 27 deletions tests/unit/confidant/services/jwkmanager_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from unittest.mock import patch, Mock

from confidant.services.jwkmanager import jwk_manager
from calendar import timegm


def test_set_key(test_key_pair):
Expand Down Expand Up @@ -126,32 +125,6 @@ def test_get_jwt_raises_no_key_id(test_key_pair, test_jwk_payload):
jwk_manager.get_jwt('non-existent', test_jwk_payload)


@patch('jwt.api_jwt.PyJWT._validate_exp', return_value=True)
def test_get_payload(mock_validate, test_key_pair, test_jwk_payload, test_jwt,
test_certificate):
test_private_key = test_key_pair.export_to_pem(private_key=True,
password=None)
jwk_manager.set_key('test-key', test_private_key.decode('utf-8'))
result = jwk_manager.get_payload(test_certificate.decode('utf-8'),
test_jwt)
mocked_date = datetime.datetime(
year=2020,
month=10,
day=10,
hour=0,
minute=0,
second=0,
microsecond=0
)
test_jwk_payload.update({
'iat': timegm(mocked_date.utctimetuple()),
'nbf': timegm(mocked_date.utctimetuple()),
'exp': timegm((mocked_date +
datetime.timedelta(seconds=3600)).utctimetuple()),
})
assert result == test_jwk_payload


def test_get_jwks(test_key_pair, test_jwk_payload, test_jwt,
test_jwks):
test_private_key = test_key_pair.export_to_pem(private_key=True,
Expand Down

0 comments on commit 05406ec

Please sign in to comment.