Skip to content

Commit 05a802a

Browse files
committed
Add JWT_DECODE_ISSUER option
Closes #259
1 parent f39a679 commit 05a802a

File tree

6 files changed

+62
-6
lines changed

6 files changed

+62
-6
lines changed

flask_jwt_extended/config.py

+4
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,10 @@ def json_encoder(self):
316316
def audience(self):
317317
return current_app.config['JWT_DECODE_AUDIENCE']
318318

319+
@property
320+
def issuer(self):
321+
return current_app.config['JWT_DECODE_ISSUER']
322+
319323
@property
320324
def leeway(self):
321325
return current_app.config['JWT_DECODE_LEEWAY']

flask_jwt_extended/jwt_manager.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import datetime
22
from warnings import warn
33

4-
from jwt import ExpiredSignatureError, InvalidTokenError, InvalidAudienceError
4+
from jwt import (
5+
ExpiredSignatureError, InvalidTokenError, InvalidAudienceError,
6+
InvalidIssuerError
7+
)
58
try:
69
from flask import _app_ctx_stack as ctx_stack
710
except ImportError: # pragma: no cover
@@ -126,6 +129,10 @@ def handle_wrong_token_error(e):
126129
def handle_invalid_audience_error(e):
127130
return self._invalid_token_callback(str(e))
128131

132+
@app.errorhandler(InvalidIssuerError)
133+
def handle_invalid_issuer_error(e):
134+
return self._invalid_token_callback(str(e))
135+
129136
@app.errorhandler(RevokedTokenError)
130137
def handle_revoked_token_error(e):
131138
return self._revoked_token_callback()
@@ -214,6 +221,7 @@ def _set_default_configuration_options(app):
214221
app.config.setdefault('JWT_IDENTITY_CLAIM', 'identity')
215222
app.config.setdefault('JWT_USER_CLAIMS', 'user_claims')
216223
app.config.setdefault('JWT_DECODE_AUDIENCE', None)
224+
app.config.setdefault('JWT_DECODE_ISSUER', None)
217225
app.config.setdefault('JWT_DECODE_LEEWAY', 0)
218226

219227
app.config.setdefault('JWT_CLAIMS_IN_REFRESH_TOKEN', False)

flask_jwt_extended/tokens.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def encode_refresh_token(identity, secret, algorithm, expires_delta, user_claims
114114

115115
def decode_jwt(encoded_token, secret, algorithms, identity_claim_key,
116116
user_claims_key, csrf_value=None, audience=None,
117-
leeway=0, allow_expired=False):
117+
leeway=0, allow_expired=False, issuer=None):
118118
"""
119119
Decodes an encoded JWT
120120
@@ -125,6 +125,7 @@ def decode_jwt(encoded_token, secret, algorithms, identity_claim_key,
125125
:param user_claims_key: expected key that contains the user claims
126126
:param csrf_value: Expected double submit csrf value
127127
:param audience: expected audience in the JWT
128+
:param issuer: expected issuer in the JWT
128129
:param leeway: optional leeway to add some margin around expiration times
129130
:param allow_expired: Options to ignore exp claim validation in token
130131
:return: Dictionary containing contents of the JWT
@@ -135,7 +136,7 @@ def decode_jwt(encoded_token, secret, algorithms, identity_claim_key,
135136

136137
# This call verifies the ext, iat, nbf, and aud claims
137138
data = jwt.decode(encoded_token, secret, algorithms=algorithms, audience=audience,
138-
leeway=leeway, options=options)
139+
leeway=leeway, options=options, issuer=issuer)
139140

140141
# Make sure that any custom claims we expect in the token are present
141142
if 'jti' not in data:

flask_jwt_extended/utils.py

+2
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ def decode_token(encoded_token, csrf_value=None, allow_expired=False):
103103
user_claims_key=config.user_claims_key,
104104
csrf_value=csrf_value,
105105
audience=config.audience,
106+
issuer=config.issuer,
106107
leeway=config.leeway,
107108
allow_expired=allow_expired
108109
)
@@ -115,6 +116,7 @@ def decode_token(encoded_token, csrf_value=None, allow_expired=False):
115116
user_claims_key=config.user_claims_key,
116117
csrf_value=csrf_value,
117118
audience=config.audience,
119+
issuer=config.issuer,
118120
leeway=config.leeway,
119121
allow_expired=True
120122
)

tests/test_decode_tokens.py

+21-3
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from jwt import (
1010
ExpiredSignatureError, InvalidSignatureError, InvalidAudienceError,
11-
ImmatureSignatureError
11+
ImmatureSignatureError, InvalidIssuerError
1212
)
1313

1414
from flask_jwt_extended import (
@@ -246,9 +246,9 @@ def test_valid_aud(app, default_access_token, token_aud):
246246
app.config['JWT_DECODE_AUDIENCE'] = ['foo', 'bar']
247247

248248
default_access_token['aud'] = token_aud
249-
invalid_token = encode_token(app, default_access_token)
249+
valid_token = encode_token(app, default_access_token)
250250
with app.test_request_context():
251-
decoded = decode_token(invalid_token)
251+
decoded = decode_token(valid_token)
252252
assert decoded['aud'] == token_aud
253253

254254

@@ -261,3 +261,21 @@ def test_invalid_aud(app, default_access_token, token_aud):
261261
with pytest.raises(InvalidAudienceError):
262262
with app.test_request_context():
263263
decode_token(invalid_token)
264+
265+
def test_valid_iss(app, default_access_token):
266+
app.config['JWT_DECODE_ISSUER'] = 'foobar'
267+
268+
default_access_token['iss'] = 'foobar'
269+
valid_token = encode_token(app, default_access_token)
270+
with app.test_request_context():
271+
decoded = decode_token(valid_token)
272+
assert decoded['iss'] == 'foobar'
273+
274+
def test_invalid_iss(app, default_access_token):
275+
app.config['JWT_DECODE_ISSUER'] = 'baz'
276+
277+
default_access_token['iss'] = 'foobar'
278+
invalid_token = encode_token(app, default_access_token)
279+
with pytest.raises(InvalidIssuerError):
280+
with app.test_request_context():
281+
decode_token(invalid_token)

tests/test_view_decorators.py

+23
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,29 @@ def test_jwt_invalid_audience(app):
237237
assert response.status_code == 422
238238
assert response.get_json() == {'msg': 'Invalid audience'}
239239

240+
def test_jwt_invalid_issuer(app):
241+
url = '/protected'
242+
jwtM = get_jwt_manager(app)
243+
test_client = app.test_client()
244+
245+
# No issuer claim expected or provided - OK
246+
access_token = encode_token(app, {'identity': 'me'})
247+
response = test_client.get(url, headers=make_headers(access_token))
248+
assert response.status_code == 200
249+
250+
# Issuer claim expected and not provided - not OK
251+
app.config['JWT_DECODE_ISSUER'] = 'my_issuer'
252+
access_token = encode_token(app, {'identity': 'me'})
253+
response = test_client.get(url, headers=make_headers(access_token))
254+
assert response.status_code == 422
255+
assert response.get_json() == {'msg': 'Token is missing the "iss" claim'}
256+
257+
# Issuer claim still expected and wrong one provided - not OK
258+
access_token = encode_token(app, {'iss': 'different_issuer', 'identity': 'me'})
259+
response = test_client.get(url, headers=make_headers(access_token))
260+
assert response.status_code == 422
261+
assert response.get_json() == {'msg': 'Invalid issuer'}
262+
240263

241264
@pytest.mark.parametrize("delta_func", [timedelta, relativedelta])
242265
def test_expired_token(app, delta_func):

0 commit comments

Comments
 (0)