Skip to content

Commit 10780a1

Browse files
committed
Better way to handle complex object, i hope (refs #11)
1 parent 26e79df commit 10780a1

File tree

4 files changed

+37
-21
lines changed

4 files changed

+37
-21
lines changed

flask_jwt_extended/jwt_manager.py

+17-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@
88
class JWTManager:
99
def __init__(self, app=None):
1010
# Function that will be called to add custom user claims to a JWT.
11-
self.user_claims_callback = lambda _: {}
11+
self._user_claims_callback = lambda _: {}
12+
13+
# Function that will be called to return an identity from an object
14+
self._user_identity_callback = lambda i: i
1215

1316
# Function that will be called when an expired token is received
1417
self._expired_token_callback = lambda: (
@@ -90,7 +93,19 @@ def user_claims_loader(self, callback):
9093
Callback must be a function that takes only one argument, which is the
9194
identity of the JWT being created.
9295
"""
93-
self.user_claims_callback = callback
96+
self._user_claims_callback = callback
97+
return callback
98+
99+
def user_identity_loader(self, callback):
100+
"""
101+
This sets the callback method for adding custom user claims to a JWT.
102+
103+
By default, no extra user claims will be added to the JWT.
104+
105+
Callback must be a function that takes only one argument, which is the
106+
identity of the JWT being created.
107+
"""
108+
self._user_identity_callback = callback
94109
return callback
95110

96111
def expired_token_loader(self, callback):

flask_jwt_extended/utils.py

+9-11
Original file line numberDiff line numberDiff line change
@@ -312,30 +312,28 @@ def create_refresh_token(identity):
312312
return refresh_token
313313

314314

315-
def create_access_token(identity, fresh=False, identity_lookup=None):
315+
def create_access_token(identity, fresh=False):
316316
"""
317317
Creates a new access token
318318
319319
:param identity: The identity of this token. This can be any data that is
320320
json serializable. It can also be an object, in which case
321-
you can pass a function to identity_lookup which tells us
322-
how to get the identity out of this object. This is useful
323-
so you don't need to query disk twice, once for initially
324-
finding the identity in your login endpoint, and once for
325-
setting addition data in the JWT via the user_claims_loader
321+
you can use the user_identity_loader to define a function
322+
that will be called to pull a json serializable identity
323+
out of this object. This is useful so you don't need to
324+
query disk twice, once for initially finding the identity
325+
in your login endpoint, and once for setting addition data
326+
in the JWT via the user_claims_loader
326327
:param fresh: If this token should me markded as fresh, and can thus access
327328
fresh_jwt_required protected endpoints. Defaults to False
328-
:param identity_lookup: Function to generate a json serilizable identity
329-
from the identity object
330329
:return: A newly encoded JWT access token
331330
"""
332331
# Token options
333332
secret = _get_secret_key()
334333
access_expire_delta = get_access_expires()
335334
algorithm = get_algorithm()
336-
user_claims = current_app.jwt_manager.user_claims_callback(identity)
337-
if identity_lookup:
338-
identity = identity_lookup(identity)
335+
user_claims = current_app.jwt_manager._user_claims_callback(identity)
336+
identity = current_app.jwt_manager._user_identity_callback(identity)
339337

340338
access_token = _encode_access_token(identity, secret, algorithm, access_expire_delta,
341339
fresh=fresh, user_claims=user_claims)

tests/test_jwt_encode_decode.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@ def test_create_access_token_with_object(self):
336336
# Complex object to test building a JWT from. Normally if you are using
337337
# this functionality, this is something that would be retrieved from
338338
# disk somewhere (think sqlalchemy)
339-
class TestObject:
339+
class TestUser:
340340
def __init__(self, username, roles):
341341
self.username = username
342342
self.roles = roles
@@ -348,16 +348,19 @@ def __init__(self, username, roles):
348348
jwt = JWTManager(app)
349349

350350
@jwt.user_claims_loader
351-
def custom_claims(object):
351+
def custom_claims(user):
352352
return {
353-
'roles': object.roles
353+
'roles': user.roles
354354
}
355355

356+
@jwt.user_identity_loader
357+
def user_identity_lookup(user):
358+
return user.username
359+
356360
# Create the token using the complex object
357361
with app.test_request_context():
358-
user = TestObject(username='foo', roles=['bar', 'baz'])
359-
token = create_access_token(identity=user,
360-
identity_lookup=lambda obj: obj.username)
362+
user = TestUser(username='foo', roles=['bar', 'baz'])
363+
token = create_access_token(identity=user)
361364

362365
# Decode the token and make sure the values are set properly
363366
token_data = _decode_jwt(token, app.secret_key, app.config['JWT_ALGORITHM'])

tests/test_jwt_manager.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def test_class_init(self):
3232
def test_default_user_claims_callback(self):
3333
identity = 'foobar'
3434
m = JWTManager(self.app)
35-
assert m.user_claims_callback(identity) == {}
35+
assert m._user_claims_callback(identity) == {}
3636

3737
def test_default_expired_token_callback(self):
3838
with self.app.test_request_context():
@@ -88,7 +88,7 @@ def test_custom_user_claims_callback(self):
8888
def custom_user_claims(identity):
8989
return {'foo': 'bar'}
9090

91-
assert m.user_claims_callback(identity) == {'foo': 'bar'}
91+
assert m._user_claims_callback(identity) == {'foo': 'bar'}
9292

9393
def test_custom_expired_token_callback(self):
9494
with self.app.test_request_context():

0 commit comments

Comments
 (0)