Skip to content

Commit 4616541

Browse files
committed
Adds callback methods for verifying the user claims of an access token
refs #64 #70
1 parent 7b5016f commit 4616541

9 files changed

+186
-5
lines changed

docs/changing_default_behavior.rst

+6
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,12 @@ Possible loader functions are:
4343
* - **user_loader_error_loader**
4444
- Function that is called when the user_loader callback function returns **None**
4545
- Takes one argument - The identity of the user who failed to load
46+
* - **claims_verification_loader**
47+
- Function that is called to verify the custom **user_claims** data. Must return True or False
48+
- Takes one argument - The custom user_claims dict in an access token
49+
* - **claims_verification_failed_loader**
50+
- Function that is called when the user claims verification callback returns False
51+
- None
4652

4753
Dynamic token expires time
4854
~~~~~~~~~~~~~~~~~~~~~~~~~~

flask_jwt_extended/default_callbacks.py

+15
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,18 @@ def default_user_loader_error_callback(identity):
8383
status code
8484
"""
8585
return jsonify({'msg': "Error loading the user {}".format(identity)}), 401
86+
87+
88+
def default_claims_verification_callback(user_claims):
89+
"""
90+
By default, we do not do any verification of the user claims.
91+
"""
92+
return True
93+
94+
95+
def default_claims_verification_failed_callback():
96+
"""
97+
By default, if the user claims verification failed, we return a generic
98+
error message with a 400 status code
99+
"""
100+
return jsonify({'msg': 'User claims verification failed'}), 400

flask_jwt_extended/exceptions.py

+8
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,11 @@ class UserLoadError(JWTExtendedException):
6262
that it cannot or will not load a user for the given identity.
6363
"""
6464
pass
65+
66+
67+
class UserClaimsVerificationError(JWTExtendedException):
68+
"""
69+
Error raised when the claims_verification_callback function returns False,
70+
indicating that the expected user claims are invalid
71+
"""
72+
pass

flask_jwt_extended/jwt_manager.py

+36-2
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,16 @@
55
from flask_jwt_extended.config import config
66
from flask_jwt_extended.exceptions import (
77
JWTDecodeError, NoAuthorizationError, InvalidHeaderError, WrongTokenError,
8-
RevokedTokenError, FreshTokenRequired, CSRFError, UserLoadError
8+
RevokedTokenError, FreshTokenRequired, CSRFError, UserLoadError,
9+
UserClaimsVerificationError
910
)
1011
from flask_jwt_extended.default_callbacks import (
1112
default_expired_token_callback, default_user_claims_callback,
1213
default_user_identity_callback, default_invalid_token_callback,
1314
default_unauthorized_callback, default_needs_fresh_token_callback,
14-
default_revoked_token_callback, default_user_loader_error_callback
15+
default_revoked_token_callback, default_user_loader_error_callback,
16+
default_claims_verification_callback,
17+
default_claims_verification_failed_callback
1518
)
1619
from flask_jwt_extended.tokens import (
1720
encode_refresh_token, encode_access_token
@@ -40,6 +43,8 @@ def __init__(self, app=None):
4043
self._user_loader_callback = None
4144
self._user_loader_error_callback = default_user_loader_error_callback
4245
self._token_in_blacklist_callback = None
46+
self._claims_verification_callback = default_claims_verification_callback
47+
self._claims_verification_failed_callback = default_claims_verification_failed_callback
4348

4449
# Register this extension with the flask app now (if it is provided)
4550
if app is not None:
@@ -110,6 +115,10 @@ def handler_user_load_error(e):
110115
identity = get_jwt_identity()
111116
return self._user_loader_error_callback(identity)
112117

118+
@app.errorhandler(UserClaimsVerificationError)
119+
def handle_failed_user_claims_verification(e):
120+
return self._claims_verification_failed_callback()
121+
113122
@staticmethod
114123
def _set_default_configuration_options(app):
115124
"""
@@ -296,6 +305,31 @@ def token_in_blacklist_loader(self, callback):
296305
self._token_in_blacklist_callback = callback
297306
return callback
298307

308+
def claims_verification_loader(self, callback):
309+
"""
310+
Sets the callback function for checking if the custom user claims are
311+
valid for this access token.
312+
313+
This callback function must take one parameter, which is the custom
314+
user claims present in the access token. This callback function should
315+
return True if the user claims are valid, False otherwise.
316+
"""
317+
self._claims_verification_callback = callback
318+
return callback
319+
320+
def claims_verification_failed_loader(self, callback):
321+
"""
322+
Sets the callback method to be called if the user claims verification
323+
method returns False, indicating that the user claims are not valid.
324+
325+
The default implementation will return the json:
326+
'{"msg": "User claims verification failed"})' with a 400 status code
327+
328+
Callback must be a function that takes no arguments.
329+
"""
330+
self._claims_verification_failed_callback = callback
331+
return callback
332+
299333
def create_refresh_token(self, identity, expires_delta=None):
300334
"""
301335
Creates a new refresh token

flask_jwt_extended/utils.py

+5
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,11 @@ def token_in_blacklist(*args, **kwargs):
106106
return jwt_manager._token_in_blacklist_callback(*args, **kwargs)
107107

108108

109+
def verify_token_claims(*args, **kwargs):
110+
jwt_manager = _get_jwt_manager()
111+
return jwt_manager._claims_verification_callback(*args, **kwargs)
112+
113+
109114
def get_csrf_token(encoded_token):
110115
token = decode_jwt(
111116
encoded_token,

flask_jwt_extended/view_decorators.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,13 @@
1010
from flask_jwt_extended.config import config
1111
from flask_jwt_extended.exceptions import (
1212
InvalidHeaderError, NoAuthorizationError, WrongTokenError,
13-
FreshTokenRequired, CSRFError, UserLoadError, RevokedTokenError
13+
FreshTokenRequired, CSRFError, UserLoadError, RevokedTokenError,
14+
UserClaimsVerificationError
1415
)
1516
from flask_jwt_extended.tokens import decode_jwt
1617
from flask_jwt_extended.utils import (
1718
has_user_loader, user_loader, token_in_blacklist,
18-
has_token_in_blacklist_callback
19+
has_token_in_blacklist_callback, verify_token_claims
1920
)
2021

2122

@@ -207,6 +208,11 @@ def _decode_jwt_from_request(request_type):
207208
if decoded_token['type'] != request_type:
208209
raise WrongTokenError('Only {} tokens can access this endpoint'.format(request_type))
209210

211+
# Check if the custom claims in access tokens are valid
212+
if request_type == 'access':
213+
if not verify_token_claims(decoded_token['user_claims']):
214+
raise UserClaimsVerificationError('user_claims verification failed')
215+
210216
# If blacklisting is enabled, see if this token has been revoked
211217
if _token_blacklisted(decoded_token, request_type):
212218
raise RevokedTokenError('Token has been revoked')

tests/test_jwt_manager.py

+21
Original file line numberDiff line numberDiff line change
@@ -214,3 +214,24 @@ def custom_user_loader_error(identity):
214214

215215
self.assertEqual(status_code, 404)
216216
self.assertEqual(data, {'msg': 'Not found'})
217+
218+
def test_claims_verification(self):
219+
with self.app.test_request_context():
220+
m = JWTManager(self.app)
221+
222+
@m.claims_verification_loader
223+
def user_claims_verification(claims):
224+
return 'foo' in claims
225+
226+
@m.claims_verification_failed_loader
227+
def user_claims_verification_failed():
228+
return jsonify({'msg': 'Test'}), 404
229+
230+
result = m._claims_verification_callback({'bar': 'baz'})
231+
self.assertEqual(result, False)
232+
233+
result = m._claims_verification_failed_callback()
234+
status_code, data = self._parse_callback_result(result)
235+
236+
self.assertEqual(status_code, 404)
237+
self.assertEqual(data, {'msg': 'Test'})

tests/test_protected_endpoints.py

-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ def setUp(self):
2222
self.app.config['JWT_ALGORITHM'] = 'HS256'
2323
self.app.config['JWT_ACCESS_TOKEN_EXPIRES'] = timedelta(seconds=1)
2424
self.app.config['JWT_REFRESH_TOKEN_EXPIRES'] = timedelta(seconds=1)
25-
self.app.config['JWT_IDENTITY_CLAIM'] = 'sub'
2625
self.jwt_manager = JWTManager(self.app)
2726
self.client = self.app.test_client()
2827

+87
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import unittest
2+
3+
from flask import Flask, jsonify, json
4+
5+
from flask_jwt_extended import JWTManager, create_access_token, jwt_required
6+
7+
8+
class TestUserClaimsVerification(unittest.TestCase):
9+
10+
def setUp(self):
11+
self.app = Flask(__name__)
12+
self.app.secret_key = 'super=secret'
13+
self.jwt_manager = JWTManager(self.app)
14+
self.client = self.app.test_client()
15+
16+
@self.jwt_manager.claims_verification_loader
17+
def claims_verification(user_claims):
18+
expected_keys = ['foo', 'bar']
19+
for key in expected_keys:
20+
if key not in user_claims:
21+
return False
22+
return True
23+
24+
@self.app.route('/auth/login', methods=['POST'])
25+
def login():
26+
ret = {'access_token': create_access_token('test')}
27+
return jsonify(ret), 200
28+
29+
@self.app.route('/protected')
30+
@jwt_required
31+
def protected():
32+
return jsonify({'msg': "hello world"})
33+
34+
def _jwt_get(self, url, jwt, header_name='Authorization', header_type='Bearer'):
35+
header_type = '{} {}'.format(header_type, jwt).strip()
36+
response = self.client.get(url, headers={header_name: header_type})
37+
status_code = response.status_code
38+
data = json.loads(response.get_data(as_text=True))
39+
return status_code, data
40+
41+
def test_valid_user_claims(self):
42+
@self.jwt_manager.user_claims_loader
43+
def user_claims_callback(identity):
44+
return {'foo': 'baz', 'bar': 'boom'}
45+
46+
response = self.client.post('/auth/login')
47+
data = json.loads(response.get_data(as_text=True))
48+
access_token = data['access_token']
49+
50+
status, data = self._jwt_get('/protected', access_token)
51+
self.assertEqual(data, {'msg': 'hello world'})
52+
self.assertEqual(status, 200)
53+
54+
def test_empty_claims_verification_error(self):
55+
response = self.client.post('/auth/login')
56+
data = json.loads(response.get_data(as_text=True))
57+
access_token = data['access_token']
58+
59+
status, data = self._jwt_get('/protected', access_token)
60+
self.assertEqual(data, {'msg': 'User claims verification failed'})
61+
self.assertEqual(status, 400)
62+
63+
def test_bad_claims_verification_error(self):
64+
@self.jwt_manager.user_claims_loader
65+
def user_claims_callback(identity):
66+
return {'super': 'banana'}
67+
68+
response = self.client.post('/auth/login')
69+
data = json.loads(response.get_data(as_text=True))
70+
access_token = data['access_token']
71+
72+
status, data = self._jwt_get('/protected', access_token)
73+
self.assertEqual(data, {'msg': 'User claims verification failed'})
74+
self.assertEqual(status, 400)
75+
76+
def test_bad_claims_custom_error_callback(self):
77+
@self.jwt_manager.claims_verification_failed_loader
78+
def user_claims_callback():
79+
return jsonify({'foo': 'bar'}), 404
80+
81+
response = self.client.post('/auth/login')
82+
data = json.loads(response.get_data(as_text=True))
83+
access_token = data['access_token']
84+
85+
status, data = self._jwt_get('/protected', access_token)
86+
self.assertEqual(data, {'foo': 'bar'})
87+
self.assertEqual(status, 404)

0 commit comments

Comments
 (0)