Skip to content

Commit d6a8e71

Browse files
authored
Merge pull request #56 from vimalloc/user_loader
Add user_loader feature
2 parents 3f90d1f + efa38a1 commit d6a8e71

14 files changed

+408
-22
lines changed

docs/changing_default_behavior.rst

+6
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,12 @@ Possible loader functions are:
3737
* - **revoked_token_loader**
3838
- Function to call when a revoked token accesses a protected endpoint
3939
- None
40+
* - **user_loader_callback_loader**
41+
- Function to call to load a user object from a token
42+
- Takes one argument - The identity of the token to load a user from
43+
* - **user_loader_error_loader**
44+
- Function that is called when the user_loader callback function returns **None**
45+
- Takes one argument - The identity of the user who failed to load
4046

4147
Dynamic token expires time
4248
~~~~~~~~~~~~~~~~~~~~~~~~~~

docs/complex_objects_from_token.rst

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
Complex Objects from Tokens
2+
===========================
3+
4+
We can also do the inverse of creating tokens from complex objects like we did
5+
in the last section. In this case, we can take a token and every time a
6+
protected endpoint is accessed automatically use the token to load a complex
7+
object, for example a SQLAlchemy user object. Here's an example of how it
8+
might look:
9+
10+
.. literalinclude:: ../examples/complex_objects_from_tokens.py
11+
12+
If you do not provide a user_loader_callback in your application, and attempt
13+
to access the **current_user** LocalProxy, it will simply be None.
14+
15+
One thing to note with this is that you will now call the **user_loader_callback**
16+
on all of your protected endpoints, which will probably incur the cost of a
17+
database lookup. In most cases this likely isn't a big deal for your application,
18+
but do be aware that it could slow things down if your frontend is doing several
19+
calls to endpoints in rapid succession.

docs/index.rst

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ documentation is coming soon!
1616
basic_usage
1717
add_custom_data_claims
1818
tokens_from_complex_object
19+
complex_objects_from_token
1920
refresh_tokens
2021
token_freshness
2122
changing_default_behavior
+74
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
from flask import Flask, jsonify, request
2+
from flask_jwt_extended import (
3+
JWTManager, jwt_required, create_access_token, current_user
4+
)
5+
6+
app = Flask(__name__)
7+
app.secret_key = 'super-secret' # Change this!
8+
jwt = JWTManager(app)
9+
10+
11+
# A user object that we will load our tokens
12+
class UserObject:
13+
def __init__(self, username, roles):
14+
self.username = username
15+
self.roles = roles
16+
17+
# An example store of users. In production, this would likely
18+
# be a sqlalchemy instance or something similiar
19+
users_to_roles = {
20+
'foo': ['admin'],
21+
'bar': ['peasant'],
22+
'baz': ['peasant']
23+
}
24+
25+
26+
# This function is called whenever a protected endpoint is accessed.
27+
# This should return a complex object based on the token identity.
28+
# This is called after the token is verified, so you can use
29+
# get_jwt_claims() in here if desired. Note that this needs to
30+
# return None if the user could not be loaded for any reason,
31+
# such as not being found in the underlying data store
32+
@jwt.user_loader_callback_loader
33+
def user_loader_callback(identity):
34+
if identity not in users_to_roles:
35+
return None
36+
37+
return UserObject(
38+
username=identity,
39+
roles=users_to_roles[identity]
40+
)
41+
42+
43+
# You can override the error returned to the user if the
44+
# user_loader_callback returns None. By default, if you don't
45+
# override this, it will return a 401 status code with the json:
46+
# {'msg': "Error loading the user <identity>"}. You can use
47+
# get_jwt_claims() here too if desired
48+
@jwt.user_loader_error_loader
49+
def custom_user_loader_error(identity):
50+
return jsonify({"msg": "User not found"}), 404
51+
52+
53+
# Create a token for any user, so this can be tested out
54+
@app.route('/login', methods=['POST'])
55+
def login():
56+
username = request.json.get('username', None)
57+
access_token = create_access_token(identity=username)
58+
ret = {'access_token': access_token}
59+
return jsonify(ret), 200
60+
61+
62+
# If the user_loader_callback returns None, this method will
63+
# not get hit, even if the access token is valid. You can
64+
# access the loaded user via the ``current_user``` LocalProxy,
65+
# or with the ```get_current_user()``` method
66+
@app.route('/admin-only', methods=['GET'])
67+
@jwt_required
68+
def protected():
69+
if 'admin' not in current_user.roles:
70+
return jsonify({"msg": "Forbidden"}), 403
71+
return jsonify({"secret_msg": "don't forget to drink your ovaltine"})
72+
73+
if __name__ == '__main__':
74+
app.run()

flask_jwt_extended/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from .utils import (
77
create_refresh_token, create_access_token, get_jwt_identity,
88
get_jwt_claims, set_access_cookies, set_refresh_cookies,
9-
unset_jwt_cookies, get_raw_jwt
9+
unset_jwt_cookies, get_raw_jwt, get_current_user, current_user
1010
)
1111
from .blacklist import (
1212
revoke_token, unrevoke_token, get_stored_tokens, get_all_stored_tokens,

flask_jwt_extended/default_callbacks.py

+9
Original file line numberDiff line numberDiff line change
@@ -74,3 +74,12 @@ def default_revoked_token_callback():
7474
return a general error message with a 401 status code
7575
"""
7676
return jsonify({'msg': 'Token has been revoked'}), 401
77+
78+
79+
def default_user_loader_error_callback(identity):
80+
"""
81+
By default, if a user_loader callback is defined and the callback
82+
function returns None, we return a general error message with a 401
83+
status code
84+
"""
85+
return jsonify({'msg': "Error loading the user {}".format(identity)}), 401

flask_jwt_extended/exceptions.py

+8
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,11 @@ class FreshTokenRequired(JWTExtendedException):
5454
protected by fresh_jwt_required
5555
"""
5656
pass
57+
58+
59+
class UserLoadError(JWTExtendedException):
60+
"""
61+
Error raised when a user_loader callback function returns None, indicating
62+
that it cannot or will not load a user for the given identity.
63+
"""
64+
pass

flask_jwt_extended/jwt_manager.py

+60-6
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,18 @@
66
from flask_jwt_extended.config import config
77
from flask_jwt_extended.exceptions import (
88
JWTDecodeError, NoAuthorizationError, InvalidHeaderError, WrongTokenError,
9-
RevokedTokenError, FreshTokenRequired, CSRFError
9+
RevokedTokenError, FreshTokenRequired, CSRFError, UserLoadError
1010
)
1111
from flask_jwt_extended.default_callbacks import (
1212
default_expired_token_callback, default_user_claims_callback,
1313
default_user_identity_callback, default_invalid_token_callback,
14-
default_unauthorized_callback,
15-
default_needs_fresh_token_callback,
16-
default_revoked_token_callback
14+
default_unauthorized_callback, default_needs_fresh_token_callback,
15+
default_revoked_token_callback, default_user_loader_error_callback
1716
)
1817
from flask_jwt_extended.tokens import (
19-
encode_refresh_token, decode_jwt,
20-
encode_access_token
18+
encode_refresh_token, decode_jwt, encode_access_token
2119
)
20+
from flask_jwt_extended.utils import get_jwt_identity
2221

2322

2423
class JWTManager(object):
@@ -39,6 +38,8 @@ def __init__(self, app=None):
3938
self._unauthorized_callback = default_unauthorized_callback
4039
self._needs_fresh_token_callback = default_needs_fresh_token_callback
4140
self._revoked_token_callback = default_revoked_token_callback
41+
self._user_loader_callback = None
42+
self._user_loader_error_callback = default_user_loader_error_callback
4243

4344
# Register this extension with the flask app now (if it is provided)
4445
if app is not None:
@@ -101,6 +102,14 @@ def handle_revoked_token_error(e):
101102
def handle_fresh_token_required(e):
102103
return self._needs_fresh_token_callback()
103104

105+
@app.errorhandler(UserLoadError)
106+
def handler_user_load_error(e):
107+
# The identity is already saved before this exception was raised,
108+
# otherwise a different exception would be raised, which is why we
109+
# can safely call get_jwt_identity() here
110+
identity = get_jwt_identity()
111+
return self._user_loader_error_callback(identity)
112+
104113
@staticmethod
105114
def _set_default_configuration_options(app):
106115
"""
@@ -244,6 +253,50 @@ def revoked_token_loader(self, callback):
244253
self._revoked_token_callback = callback
245254
return callback
246255

256+
def user_loader_callback_loader(self, callback):
257+
"""
258+
Sets the callback method to be called to load a user on a protected
259+
endpoint.
260+
261+
By default this is not is not used.
262+
263+
If a callback method is passed in here, it must take one argument,
264+
which is the identity of the user to load. It must return the user
265+
object, or None in the case of an error (which will cause the TODO
266+
error handler to be hit)
267+
"""
268+
self._user_loader_callback = callback
269+
return callback
270+
271+
def user_loader_error_loader(self, callback):
272+
"""
273+
Sets the callback method to be called if a user fails or is refused
274+
to load when calling the _user_loader_callback function (indicated by
275+
that function returning None)
276+
277+
The default implementation will return json:
278+
'{"msg": "Error loading the user <identity>"}' with a 400 status code.
279+
280+
Callback must be a function that takes one argument, the identity of the
281+
user who failed to load.
282+
"""
283+
self._user_loader_error_callback = callback
284+
return callback
285+
286+
def has_user_loader(self):
287+
"""
288+
Returns True if a user_loader_callback has been defined in this
289+
application, False otherwise
290+
"""
291+
return self._user_loader_callback is not None
292+
293+
def user_loader(self, identity):
294+
"""
295+
Calls the _user_loader_callback function (if it is defined) and returns
296+
the resulting user from this callback.
297+
"""
298+
return self._user_loader_callback(identity)
299+
247300
def create_refresh_token(self, identity, expires_delta=None):
248301
"""
249302
Creates a new refresh token
@@ -315,3 +368,4 @@ def create_access_token(self, identity, fresh=False, expires_delta=None):
315368
config.algorithm, csrf=config.csrf_protect)
316369
store_token(decoded_token, revoked=False)
317370
return access_token
371+

flask_jwt_extended/utils.py

+25
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from flask import current_app
2+
from werkzeug.local import LocalProxy
3+
24
try:
35
from flask import _app_ctx_stack as ctx_stack
46
except ImportError: # pragma: no cover
@@ -8,6 +10,10 @@
810
from flask_jwt_extended.tokens import decode_jwt
911

1012

13+
# Proxy to access the current user
14+
current_user = LocalProxy(lambda: get_current_user())
15+
16+
1117
def get_raw_jwt():
1218
"""
1319
Returns the python dictionary which has all of the data in this JWT. If no
@@ -32,6 +38,15 @@ def get_jwt_claims():
3238
return get_raw_jwt().get('user_claims', {})
3339

3440

41+
def get_current_user():
42+
"""
43+
Returns the loaded user from a user_loader callback in a protected endpoint.
44+
If no user was loaded, or if no user_loader callback was defined, this will
45+
return None
46+
"""
47+
return getattr(ctx_stack.top, 'jwt_user', None)
48+
49+
3550
def get_jti(encoded_token):
3651
"""
3752
Returns the JTI given the JWT encoded token
@@ -60,6 +75,16 @@ def create_refresh_token(*args, **kwargs):
6075
return jwt_manager.create_refresh_token(*args, **kwargs)
6176

6277

78+
def user_loader(*args, **kwargs):
79+
jwt_manager = _get_jwt_manager()
80+
return jwt_manager.user_loader(*args, **kwargs)
81+
82+
83+
def has_user_loader(*args, **kwargs):
84+
jwt_manager = _get_jwt_manager()
85+
return jwt_manager.has_user_loader(*args, **kwargs)
86+
87+
6388
def get_csrf_token(encoded_token):
6489
token = decode_jwt(encoded_token, config.decode_key, config.algorithm, csrf=True)
6590
return token['csrf']

flask_jwt_extended/view_decorators.py

+15-12
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@
1111
from flask_jwt_extended.config import config
1212
from flask_jwt_extended.exceptions import (
1313
InvalidHeaderError, NoAuthorizationError, WrongTokenError,
14-
FreshTokenRequired, CSRFError
14+
FreshTokenRequired, CSRFError, UserLoadError
1515
)
1616
from flask_jwt_extended.tokens import decode_jwt
17+
from flask_jwt_extended.utils import has_user_loader, user_loader
1718

1819

1920
def jwt_required(fn):
@@ -28,10 +29,9 @@ def jwt_required(fn):
2829
"""
2930
@wraps(fn)
3031
def wrapper(*args, **kwargs):
31-
# Save the jwt in the context so that it can be accessed later by
32-
# the various endpoints that is using this decorator
3332
jwt_data = _decode_jwt_from_request(request_type='access')
3433
ctx_stack.top.jwt = jwt_data
34+
_load_user(jwt_data['identity'])
3535
return fn(*args, **kwargs)
3636
return wrapper
3737

@@ -49,15 +49,11 @@ def jwt_optional(fn):
4949
@wraps(fn)
5050
def wrapper(*args, **kwargs):
5151
try:
52-
# If an acceptable JWT is found in the request, put it into
53-
# the application context
5452
jwt_data = _decode_jwt_from_request(request_type='access')
5553
ctx_stack.top.jwt = jwt_data
54+
_load_user(jwt_data['identity'])
5655
except NoAuthorizationError:
57-
# Allow request to proceed if no authorization header is present
58-
# in the request, but don't modify application context
5956
pass
60-
# Return the decorated function in either case
6157
return fn(*args, **kwargs)
6258
return wrapper
6359

@@ -78,9 +74,8 @@ def wrapper(*args, **kwargs):
7874
if not jwt_data['fresh']:
7975
raise FreshTokenRequired('Fresh token required')
8076

81-
# Save the jwt in the context so that it can be accessed later by
82-
# the various endpoints that is using this decorator
8377
ctx_stack.top.jwt = jwt_data
78+
_load_user(jwt_data['identity'])
8479
return fn(*args, **kwargs)
8580
return wrapper
8681

@@ -93,14 +88,22 @@ def jwt_refresh_token_required(fn):
9388
"""
9489
@wraps(fn)
9590
def wrapper(*args, **kwargs):
96-
# Save the jwt in the context so that it can be accessed later by
97-
# the various endpoints that is using this decorator
9891
jwt_data = _decode_jwt_from_request(request_type='refresh')
9992
ctx_stack.top.jwt = jwt_data
93+
_load_user(jwt_data['identity'])
10094
return fn(*args, **kwargs)
10195
return wrapper
10296

10397

98+
def _load_user(identity):
99+
if has_user_loader():
100+
user = user_loader(identity)
101+
if user is None:
102+
raise UserLoadError("user_loader returned None for {}".format(identity))
103+
else:
104+
ctx_stack.top.jwt_user = user
105+
106+
104107
def _decode_jwt_from_headers():
105108
header_name = config.header_name
106109
header_type = config.header_type

tests/test_blacklist.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def setUp(self):
2929

3030
@self.app.route('/auth/login', methods=['POST'])
3131
def login():
32-
username = request.json['username']
32+
username = request.get_json()['username']
3333
ret = {
3434
'access_token': create_access_token(username, fresh=True),
3535
'refresh_token': create_refresh_token(username)

0 commit comments

Comments
 (0)