From 00cd759d86aae24176ead7bdbed273a07532443e Mon Sep 17 00:00:00 2001 From: Stephen Rosen Date: Wed, 2 Nov 2022 06:01:52 -0500 Subject: [PATCH] Add `Algorithm.compute_hash_digest` and use it to implement at_hash validation example (#775) * Add compute_hash_digest to Algorithm objects `Algorithm.compute_hash_digest` is defined as a method which inspects the object to see that it has the requisite attributes, `hash_alg`. If `hash_alg` is not set, then the method raises a NotImplementedError. This applies to classes like NoneAlgorithm. If `hash_alg` is set, then it is checked for ``` has_crypto # is cryptography available? and isinstance(hash_alg, type) and issubclass(hash_alg, hashes.HashAlgorithm) ``` to see which API for computing a digest is appropriate -- `hashlib` vs `cryptography.hazmat.primitives.hashes`. These checks could be avoided at runtime if it were necessary to optimize further (e.g. attach compute_hash_digest methods to classes with a class decorator) but this is not clearly a worthwhile optimization. Such perf tuning is intentionally omitted for now. * Add doc example of OIDC login flow The goal of this doc example is to demonstrate usage of `get_algorithm_by_name` and `compute_hash_digest` for the purpose of `at_hash` validation. It is not meant to be a "guaranteed correct" and spec-compliant example. closes #314 --- CHANGELOG.rst | 4 +++ docs/usage.rst | 71 ++++++++++++++++++++++++++++++++++++++++ jwt/algorithms.py | 23 +++++++++++++ tests/test_algorithms.py | 23 +++++++++++++ 4 files changed, 121 insertions(+) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 8ec7ede8..4d562070 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -18,6 +18,10 @@ Fixed Added ~~~~~ +- Add ``compute_hash_digest`` as a method of ``Algorithm`` objects, which uses + the underlying hash algorithm to compute a digest. If there is no appropriate + hash algorithm, a ``NotImplementedError`` will be raised + `v2.6.0 `__ ----------------------------------------------------------------------- diff --git a/docs/usage.rst b/docs/usage.rst index 91d96791..a85fa188 100644 --- a/docs/usage.rst +++ b/docs/usage.rst @@ -297,3 +297,74 @@ Retrieve RSA signing keys from a JWKS endpoint ... ) >>> print(data) {'iss': 'https://dev-87evx9ru.auth0.com/', 'sub': 'aW4Cca79xReLWUz0aE2H6kD0O3cXBVtC@clients', 'aud': 'https://expenses-api', 'iat': 1572006954, 'exp': 1572006964, 'azp': 'aW4Cca79xReLWUz0aE2H6kD0O3cXBVtC', 'gty': 'client-credentials'} + +OIDC Login Flow +--------------- + +The following usage demonstrates an OIDC login flow using pyjwt. Further +reading about the OIDC spec is recommended for implementers. + +In particular, this demonstrates validation of the ``at_hash`` claim. +This claim relies on data from outside of the the JWT for validation. Methods +are provided which support computation and validation of this claim, but it +is not built into pyjwt. + +.. code-block:: python + + import base64 + import jwt + import requests + + + # Part 1: setup + # get the OIDC config and JWKs to use + + # in OIDC, you must know your client_id (this is the OAuth 2.0 client_id) + client_id = ... + + # example of fetching data from your OIDC server + # see: https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderConfig + oidc_server = ... + oidc_config = requests.get( + f"https://{oidc_server}/.well-known/openid-configuration" + ).json() + signing_algos = oidc_config["id_token_signing_alg_values_supported"] + + # setup a PyJWKClient to get the appropriate signing key + jwks_client = jwt.PyJWKClient(oidc_config["jwks_uri"]) + + + # Part 2: login / authorization + # when a user completes an OIDC login flow, there will be a well-formed + # response object to parse/handle + + # data from the login flow + # see: https://openid.net/specs/openid-connect-core-1_0.html#TokenResponse + token_response = ... + id_token = token_response["id_token"] + access_token = token_response["access_token"] + + + # Part 3: decode and validate at_hash + # after the login is complete, the id_token needs to be decoded + # this is the stage at which an OIDC client must verify the at_hash + + # get signing_key from id_token + signing_key = jwks_client.get_signing_key_from_jwt(id_token) + + # now, decode_complete to get payload + header + data = jwt.decode_complete( + id_token, + key=signing_key.key, + algorithms=signing_algos, + audience=client_id, + ) + payload, header = data["payload"], data["header"] + + # get the pyjwt algorithm object + alg_obj = jwt.get_algorithm_by_name(header["alg"]) + + # compute at_hash, then validate / assert + digest = alg_obj.compute_hash_digest(access_token) + at_hash = base64.urlsafe_b64encode(digest[: (len(digest) // 2)]).rstrip("=") + assert at_hash == payload["at_hash"] diff --git a/jwt/algorithms.py b/jwt/algorithms.py index 93fadf4c..4fae441f 100644 --- a/jwt/algorithms.py +++ b/jwt/algorithms.py @@ -18,6 +18,7 @@ try: import cryptography.exceptions from cryptography.exceptions import InvalidSignature + from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.asymmetric import ec, padding from cryptography.hazmat.primitives.asymmetric.ec import ( @@ -111,6 +112,28 @@ class Algorithm: The interface for an algorithm used to sign and verify tokens. """ + def compute_hash_digest(self, bytestr: bytes) -> bytes: + """ + Compute a hash digest using the specified algorithm's hash algorithm. + + If there is no hash algorithm, raises a NotImplementedError. + """ + # lookup self.hash_alg if defined in a way that mypy can understand + hash_alg = getattr(self, "hash_alg", None) + if hash_alg is None: + raise NotImplementedError + + if ( + has_crypto + and isinstance(hash_alg, type) + and issubclass(hash_alg, hashes.HashAlgorithm) + ): + digest = hashes.Hash(hash_alg(), backend=default_backend()) + digest.update(bytestr) + return digest.finalize() + else: + return hash_alg(bytestr).digest() + def prepare_key(self, key): """ Performs necessary validation and conversions on the key and returns diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py index 538078af..894ce282 100644 --- a/tests/test_algorithms.py +++ b/tests/test_algorithms.py @@ -45,6 +45,12 @@ def test_algorithm_should_throw_exception_if_from_jwk_not_impl(self): with pytest.raises(NotImplementedError): algo.to_jwk("value") + def test_algorithm_should_throw_exception_if_compute_hash_digest_not_impl(self): + algo = Algorithm() + + with pytest.raises(NotImplementedError): + algo.compute_hash_digest(b"value") + def test_none_algorithm_should_throw_exception_if_key_is_not_none(self): algo = NoneAlgorithm() @@ -1054,3 +1060,20 @@ def test_okp_ed448_to_jwk_works_with_from_jwk(self): signature_2 = algo.sign(b"Hello World!", priv_key_2) assert algo.verify(b"Hello World!", pub_key_2, signature_1) assert algo.verify(b"Hello World!", pub_key_2, signature_2) + + @crypto_required + def test_rsa_can_compute_digest(self): + # this is the well-known sha256 hash of "foo" + foo_hash = base64.b64decode(b"LCa0a2j/xo/5m0U8HTBBNBNCLXBkg7+g+YpeiGJm564=") + + algo = RSAAlgorithm(RSAAlgorithm.SHA256) + computed_hash = algo.compute_hash_digest(b"foo") + assert computed_hash == foo_hash + + def test_hmac_can_compute_digest(self): + # this is the well-known sha256 hash of "foo" + foo_hash = base64.b64decode(b"LCa0a2j/xo/5m0U8HTBBNBNCLXBkg7+g+YpeiGJm564=") + + algo = HMACAlgorithm(HMACAlgorithm.SHA256) + computed_hash = algo.compute_hash_digest(b"foo") + assert computed_hash == foo_hash