Skip to content

Commit ea98fcc

Browse files
authored
Pass expired token to expired_token_callback
Refs #220
1 parent 8ba49aa commit ea98fcc

File tree

8 files changed

+88
-30
lines changed

8 files changed

+88
-30
lines changed

examples/loaders.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,12 @@
1313
# this function whenever an expired but otherwise valid access
1414
# token attempts to access an endpoint
1515
@jwt.expired_token_loader
16-
def my_expired_token_callback():
16+
def my_expired_token_callback(expired_token):
17+
token_type = expired_token['type']
1718
return jsonify({
1819
'status': 401,
1920
'sub_status': 42,
20-
'msg': 'The token has expired'
21+
'msg': 'The {} token has expired'.format(token_type)
2122
}), 401
2223

2324

flask_jwt_extended/default_callbacks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def default_user_identity_callback(userdata):
3434
return userdata
3535

3636

37-
def default_expired_token_callback():
37+
def default_expired_token_callback(expired_token):
3838
"""
3939
By default, if an expired token attempts to access a protected endpoint,
4040
we return a generic error message with a 401 status

flask_jwt_extended/jwt_manager.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
import datetime
2+
from warnings import warn
23

34
from jwt import ExpiredSignatureError, InvalidTokenError, InvalidAudienceError
5+
try:
6+
from flask import _app_ctx_stack as ctx_stack
7+
except ImportError: # pragma: no cover
8+
from flask import _request_ctx_stack as ctx_stack
49

510
from flask_jwt_extended.config import config
611
from flask_jwt_extended.exceptions import (
@@ -90,7 +95,16 @@ def handle_csrf_error(e):
9095

9196
@app.errorhandler(ExpiredSignatureError)
9297
def handle_expired_error(e):
93-
return self._expired_token_callback()
98+
try:
99+
token = ctx_stack.top.expired_jwt
100+
return self._expired_token_callback(token)
101+
except TypeError:
102+
msg = (
103+
"jwt.expired_token_loader callback now takes the expired token "
104+
"as an additional paramter. Example: expired_callback(token)"
105+
)
106+
warn(msg, DeprecationWarning)
107+
return self._expired_token_callback()
94108

95109
@app.errorhandler(InvalidHeaderError)
96110
def handle_invalid_header_error(e):
@@ -244,8 +258,9 @@ def expired_token_loader(self, callback):
244258
245259
{"msg": "Token has expired"}
246260
247-
*HINT*: The callback must be a function that takes **zero** arguments, and returns
248-
a *Flask response*.
261+
*HINT*: The callback must be a function that takes **one** argument,
262+
which is a dictionary containing the data for the expired token, and
263+
and returns a *Flask response*.
249264
"""
250265
self._expired_token_callback = callback
251266
return callback

flask_jwt_extended/tokens.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def encode_refresh_token(identity, secret, algorithm, expires_delta, user_claims
114114

115115
def decode_jwt(encoded_token, secret, algorithm, identity_claim_key,
116116
user_claims_key, csrf_value=None, audience=None,
117-
leeway=0):
117+
leeway=0, allow_expired=False):
118118
"""
119119
Decodes an encoded JWT
120120
@@ -126,12 +126,16 @@ def decode_jwt(encoded_token, secret, algorithm, identity_claim_key,
126126
:param csrf_value: Expected double submit csrf value
127127
:param audience: expected audience in the JWT
128128
:param leeway: optional leeway to add some margin around expiration times
129+
:param allow_expired: Options to ignore exp claim validation in token
129130
:return: Dictionary containing contents of the JWT
130131
"""
132+
options = {}
133+
if allow_expired:
134+
options['verify_exp'] = False
131135

132136
# This call verifies the ext, iat, nbf, and aud claims
133137
data = jwt.decode(encoded_token, secret, algorithms=[algorithm], audience=audience,
134-
leeway=leeway)
138+
leeway=leeway, options=options)
135139

136140
# Make sure that any custom claims we expect in the token are present
137141
if 'jti' not in data:

flask_jwt_extended/utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,15 @@ def get_jti(encoded_token):
6565
return decode_token(encoded_token).get('jti')
6666

6767

68-
def decode_token(encoded_token, csrf_value=None):
68+
def decode_token(encoded_token, csrf_value=None, allow_expired=False):
6969
"""
7070
Returns the decoded token (python dict) from an encoded JWT. This does all
7171
the checks to insure that the decoded token is valid before returning it.
7272
7373
:param encoded_token: The encoded JWT to decode into a python dict.
7474
:param csrf_value: Expected CSRF double submit value (optional)
75+
:param allow_expired: Options to ignore exp claim validation in token
76+
:return: Dictionary containing contents of the JWT
7577
"""
7678
jwt_manager = _get_jwt_manager()
7779
unverified_claims = jwt.decode(
@@ -90,6 +92,7 @@ def decode_token(encoded_token, csrf_value=None):
9092
)
9193
warn(msg, DeprecationWarning)
9294
secret = jwt_manager._decode_key_callback(unverified_claims)
95+
9396
return decode_jwt(
9497
encoded_token=encoded_token,
9598
secret=secret,
@@ -98,7 +101,8 @@ def decode_token(encoded_token, csrf_value=None):
98101
user_claims_key=config.user_claims_key,
99102
csrf_value=csrf_value,
100103
audience=config.audience,
101-
leeway=config.leeway
104+
leeway=config.leeway,
105+
allow_expired=allow_expired
102106
)
103107

104108

flask_jwt_extended/view_decorators.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from calendar import timegm
44

55
from werkzeug.exceptions import BadRequest
6+
from jwt import ExpiredSignatureError
67

78
from flask import request
89
try:
@@ -191,7 +192,7 @@ def _decode_jwt_from_headers():
191192
raise InvalidHeaderError(msg)
192193
encoded_token = parts[1]
193194

194-
return decode_token(encoded_token)
195+
return encoded_token, None
195196

196197

197198
def _decode_jwt_from_cookies(request_type):
@@ -213,7 +214,7 @@ def _decode_jwt_from_cookies(request_type):
213214
else:
214215
csrf_value = None
215216

216-
return decode_token(encoded_token, csrf_value=csrf_value)
217+
return encoded_token, csrf_value
217218

218219

219220
def _decode_jwt_from_query_string():
@@ -222,7 +223,7 @@ def _decode_jwt_from_query_string():
222223
if not encoded_token:
223224
raise NoAuthorizationError('Missing "{}" query paramater'.format(query_param))
224225

225-
return decode_token(encoded_token)
226+
return encoded_token, None
226227

227228

228229
def _decode_jwt_from_json(request_type):
@@ -241,29 +242,35 @@ def _decode_jwt_from_json(request_type):
241242
except BadRequest:
242243
raise NoAuthorizationError('Missing "{}" key in json data.'.format(token_key))
243244

244-
return decode_token(encoded_token)
245+
return encoded_token, None
245246

246247

247248
def _decode_jwt_from_request(request_type):
248249
# All the places we can get a JWT from in this request
249-
decode_functions = []
250+
get_encoded_token_functions = []
250251
if config.jwt_in_cookies:
251-
decode_functions.append(lambda: _decode_jwt_from_cookies(request_type))
252+
get_encoded_token_functions.append(lambda: _decode_jwt_from_cookies(request_type))
252253
if config.jwt_in_query_string:
253-
decode_functions.append(_decode_jwt_from_query_string)
254+
get_encoded_token_functions.append(_decode_jwt_from_query_string)
254255
if config.jwt_in_headers:
255-
decode_functions.append(_decode_jwt_from_headers)
256+
get_encoded_token_functions.append(_decode_jwt_from_headers)
256257
if config.jwt_in_json:
257-
decode_functions.append(lambda: _decode_jwt_from_json(request_type))
258+
get_encoded_token_functions.append(lambda: _decode_jwt_from_json(request_type))
258259

259260
# Try to find the token from one of these locations. It only needs to exist
260261
# in one place to be valid (not every location).
261262
errors = []
262263
decoded_token = None
263-
for decode_function in decode_functions:
264+
for get_encoded_token_function in get_encoded_token_functions:
264265
try:
265-
decoded_token = decode_function()
266+
encoded_token, csrf_token = get_encoded_token_function()
267+
decoded_token = decode_token(encoded_token, csrf_token)
266268
break
269+
except ExpiredSignatureError:
270+
# Save the expired token so we can access it in a callback later
271+
expired_data = decode_token(encoded_token, csrf_token, allow_expired=True)
272+
ctx_stack.top.expired_jwt = expired_data
273+
raise
267274
except NoAuthorizationError as e:
268275
errors.append(str(e))
269276

tests/test_decode_tokens.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,12 @@ def default_access_token(app):
4343

4444
@pytest.fixture(scope='function')
4545
def patch_datetime_now(monkeypatch):
46-
47-
DATE_IN_FUTURE = datetime.utcnow() + timedelta(seconds=30)
46+
date_in_future = datetime.utcnow() + timedelta(seconds=30)
4847

4948
class mydatetime(datetime):
5049
@classmethod
5150
def utcnow(cls):
52-
return DATE_IN_FUTURE
53-
51+
return date_in_future
5452
monkeypatch.setattr(__name__ + ".datetime", mydatetime)
5553
monkeypatch.setattr("datetime.datetime", mydatetime)
5654

@@ -116,6 +114,17 @@ def test_expired_token(app):
116114
decode_token(refresh_token)
117115

118116

117+
def test_allow_expired_token(app):
118+
with app.test_request_context():
119+
delta = timedelta(minutes=-5)
120+
access_token = create_access_token('username', expires_delta=delta)
121+
refresh_token = create_refresh_token('username', expires_delta=delta)
122+
for token in (access_token, refresh_token):
123+
decoded = decode_token(token, allow_expired=True)
124+
assert decoded['identity'] == 'username'
125+
assert 'exp' in decoded
126+
127+
119128
def test_never_expire_token(app):
120129
with app.test_request_context():
121130
access_token = create_access_token('username', expires_delta=False)

tests/test_view_decorators.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import pytest
2+
import warnings
23
from datetime import timedelta
34
from flask import Flask, jsonify
45

@@ -246,14 +247,31 @@ def test_expired_token(app):
246247
assert response.status_code == 401
247248
assert response.get_json() == {'msg': 'Token has expired'}
248249

249-
# Test custom response
250+
# Test depreciated custom response
250251
@jwtM.expired_token_loader
251-
def custom_response():
252+
def depreciated_custom_response():
252253
return jsonify(msg='foobar'), 201
253254

254-
response = test_client.get(url, headers=make_headers(token))
255-
assert response.status_code == 201
256-
assert response.get_json() == {'msg': 'foobar'}
255+
warnings.simplefilter("always")
256+
with warnings.catch_warnings(record=True) as w:
257+
response = test_client.get(url, headers=make_headers(token))
258+
assert response.status_code == 201
259+
assert response.get_json() == {'msg': 'foobar'}
260+
assert w[0].category == DeprecationWarning
261+
262+
# Test new custom response
263+
@jwtM.expired_token_loader
264+
def custom_response(token):
265+
assert token['identity'] == 'username'
266+
assert token['type'] == 'access'
267+
return jsonify(msg='foobar'), 201
268+
269+
warnings.simplefilter("always")
270+
with warnings.catch_warnings(record=True) as w:
271+
response = test_client.get(url, headers=make_headers(token))
272+
assert response.status_code == 201
273+
assert response.get_json() == {'msg': 'foobar'}
274+
assert len(w) == 0
257275

258276

259277
def test_no_token(app):

0 commit comments

Comments
 (0)