16
16
from jose .jwt import encode as jwt_encode
17
17
from starlette .authentication import AuthCredentials
18
18
from starlette .authentication import AuthenticationBackend
19
+ from starlette .authentication import AuthenticationError
19
20
from starlette .authentication import BaseUser
20
21
from starlette .middleware .authentication import AuthenticationMiddleware
22
+ from starlette .requests import HTTPConnection
21
23
from starlette .requests import Request
22
- from starlette .responses import PlainTextResponse
24
+ from starlette .responses import Response
23
25
from starlette .types import ASGIApp
24
26
from starlette .types import Receive
25
27
from starlette .types import Scope
28
30
from .claims import Claims
29
31
from .config import OAuth2Config
30
32
from .core import OAuth2Core
31
- from .exceptions import OAuth2AuthenticationError
32
33
33
34
34
35
class Auth (AuthCredentials ):
@@ -108,9 +109,12 @@ async def authenticate(self, request: Request) -> Optional[Tuple[Auth, User]]:
108
109
if not scheme or not param :
109
110
return Auth (), User ()
110
111
111
- token_data = Auth .jwt_decode (param )
112
+ try :
113
+ token_data = Auth .jwt_decode (param )
114
+ except JOSEError as e :
115
+ raise AuthenticationError (str (e ))
112
116
if token_data ["exp" ] and token_data ["exp" ] < int (datetime .now (timezone .utc ).timestamp ()):
113
- raise OAuth2AuthenticationError ( 401 , "Token expired" )
117
+ raise AuthenticationError ( "Token expired" )
114
118
115
119
user = User (token_data )
116
120
auth = Auth (user .pop ("scope" , []))
@@ -135,7 +139,7 @@ def __init__(
135
139
app : ASGIApp ,
136
140
config : Union [OAuth2Config , dict ],
137
141
callback : Callable [[Auth , User ], Union [Awaitable [None ], None ]] = None ,
138
- ** kwargs , # AuthenticationMiddleware kwargs
142
+ on_error : Optional [ Callable [[ HTTPConnection , AuthenticationError ], Response ]] = None ,
139
143
) -> None :
140
144
"""Initiates the middleware with the given configuration.
141
145
@@ -148,13 +152,10 @@ def __init__(
148
152
elif not isinstance (config , OAuth2Config ):
149
153
raise TypeError ("config is not a valid type" )
150
154
self .default_application_middleware = app
151
- self .auth_middleware = AuthenticationMiddleware (app , backend = OAuth2Backend (config , callback ), ** kwargs )
155
+ on_error = on_error or AuthenticationMiddleware .default_on_error
156
+ self .auth_middleware = AuthenticationMiddleware (app , backend = OAuth2Backend (config , callback ), on_error = on_error )
152
157
153
158
async def __call__ (self , scope : Scope , receive : Receive , send : Send ) -> None :
154
159
if scope ["type" ] == "http" :
155
- try :
156
- return await self .auth_middleware (scope , receive , send )
157
- except (JOSEError , Exception ) as e :
158
- middleware = PlainTextResponse (str (e ), status_code = 401 )
159
- return await middleware (scope , receive , send )
160
+ return await self .auth_middleware (scope , receive , send )
160
161
await self .default_application_middleware (scope , receive , send )
0 commit comments