Skip to content

Commit e388257

Browse files
Fix exception overhandling in middleware (GH-44)
Co-authored-by: David García Garzón <[email protected]>
2 parents c7ca1ce + ad31cba commit e388257

File tree

5 files changed

+88
-15
lines changed

5 files changed

+88
-15
lines changed

docs/integration/integration.md

+7-2
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,24 @@ section covers its integration into a FastAPI app.
1111

1212
The `OAuth2Middleware` is an authentication middleware which means that its usage makes the `user` and `auth` attributes
1313
available in the [request](https://www.starlette.io/requests/) context. It has a mandatory argument `config` of
14-
[`OAuth2Config`](/integration/configuration#oauth2config) instance that has been discussed at the previous section and
15-
an optional argument `callback` which is a callable that is called when the authentication succeeds.
14+
[`OAuth2Config`](/integration/configuration#oauth2config) instance that has been discussed in the previous section and
15+
optional arguments `callback` and `on_error` that accept callables as values and are called when the authentication
16+
succeeds and fails correspondingly.
1617

1718
```python
1819
app: FastAPI
1920

2021
def on_auth_success(auth: Auth, user: User):
2122
"""This could be async function as well."""
2223

24+
def on_auth_error(conn: HTTPConnection, exc: Exception) -> Response:
25+
return JSONResponse({"detail": str(exc)}, status_code=400)
26+
2327
app.add_middleware(
2428
OAuth2Middleware,
2529
config=OAuth2Config(...),
2630
callback=on_auth_success,
31+
on_error=on_auth_error,
2732
)
2833
```
2934

docs/references/tutorials.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ async def error_handler(request: Request, exc: OAuth2AuthenticationError):
115115
return RedirectResponse(url="/login", status_code=303)
116116
```
117117

118-
The complete list of exceptions is the following.
118+
The complete list of exceptions raised by the middleware is the following.
119119

120120
- `OAuth2Error` - Base exception for all errors raised by the FastAPI OAuth2 library.
121121
- `OAuth2AuthenticationError` - An exception is raised when the authentication fails.

src/fastapi_oauth2/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "1.0.0"
1+
__version__ = "1.1.0"

src/fastapi_oauth2/middleware.py

+12-11
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@
1616
from jose.jwt import encode as jwt_encode
1717
from starlette.authentication import AuthCredentials
1818
from starlette.authentication import AuthenticationBackend
19+
from starlette.authentication import AuthenticationError
1920
from starlette.authentication import BaseUser
2021
from starlette.middleware.authentication import AuthenticationMiddleware
22+
from starlette.requests import HTTPConnection
2123
from starlette.requests import Request
22-
from starlette.responses import PlainTextResponse
24+
from starlette.responses import Response
2325
from starlette.types import ASGIApp
2426
from starlette.types import Receive
2527
from starlette.types import Scope
@@ -28,7 +30,6 @@
2830
from .claims import Claims
2931
from .config import OAuth2Config
3032
from .core import OAuth2Core
31-
from .exceptions import OAuth2AuthenticationError
3233

3334

3435
class Auth(AuthCredentials):
@@ -108,9 +109,12 @@ async def authenticate(self, request: Request) -> Optional[Tuple[Auth, User]]:
108109
if not scheme or not param:
109110
return Auth(), User()
110111

111-
token_data = Auth.jwt_decode(param)
112+
try:
113+
token_data = Auth.jwt_decode(param)
114+
except JOSEError as e:
115+
raise AuthenticationError(str(e))
112116
if token_data["exp"] and token_data["exp"] < int(datetime.now(timezone.utc).timestamp()):
113-
raise OAuth2AuthenticationError(401, "Token expired")
117+
raise AuthenticationError("Token expired")
114118

115119
user = User(token_data)
116120
auth = Auth(user.pop("scope", []))
@@ -135,7 +139,7 @@ def __init__(
135139
app: ASGIApp,
136140
config: Union[OAuth2Config, dict],
137141
callback: Callable[[Auth, User], Union[Awaitable[None], None]] = None,
138-
**kwargs, # AuthenticationMiddleware kwargs
142+
on_error: Optional[Callable[[HTTPConnection, AuthenticationError], Response]] = None,
139143
) -> None:
140144
"""Initiates the middleware with the given configuration.
141145
@@ -148,13 +152,10 @@ def __init__(
148152
elif not isinstance(config, OAuth2Config):
149153
raise TypeError("config is not a valid type")
150154
self.default_application_middleware = app
151-
self.auth_middleware = AuthenticationMiddleware(app, backend=OAuth2Backend(config, callback), **kwargs)
155+
on_error = on_error or AuthenticationMiddleware.default_on_error
156+
self.auth_middleware = AuthenticationMiddleware(app, backend=OAuth2Backend(config, callback), on_error=on_error)
152157

153158
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
154159
if scope["type"] == "http":
155-
try:
156-
return await self.auth_middleware(scope, receive, send)
157-
except (JOSEError, Exception) as e:
158-
middleware = PlainTextResponse(str(e), status_code=401)
159-
return await middleware(scope, receive, send)
160+
return await self.auth_middleware(scope, receive, send)
160161
await self.default_application_middleware(scope, receive, send)

tests/test_middleware.py

+67
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import pytest
2+
from fastapi.responses import JSONResponse
23
from httpx import AsyncClient
4+
from jose import jwt
35

46

57
@pytest.mark.anyio
@@ -26,3 +28,68 @@ async def test_middleware_on_logout(get_app):
2628

2729
response = await client.get("/user")
2830
assert response.status_code == 403 # Forbidden
31+
32+
33+
@pytest.mark.anyio
34+
async def test_middleware_do_not_interfere_user_errors(get_app):
35+
app = get_app()
36+
37+
@app.get("/unexpected_error")
38+
def my_entry_point():
39+
raise NameError # Intended code error
40+
41+
async with AsyncClient(app=app, base_url="http://test") as client:
42+
with pytest.raises(NameError):
43+
await client.get("/unexpected_error")
44+
45+
46+
@pytest.mark.anyio
47+
async def test_middleware_ignores_custom_exceptions(get_app):
48+
class MyCustomException(Exception):
49+
pass
50+
51+
app = get_app()
52+
53+
@app.get("/custom_exception")
54+
def my_entry_point():
55+
raise MyCustomException()
56+
57+
async with AsyncClient(app=app, base_url="http://test") as client:
58+
with pytest.raises(MyCustomException):
59+
await client.get("/custom_exception")
60+
61+
62+
@pytest.mark.anyio
63+
async def test_middleware_ignores_handled_custom_exceptions(get_app):
64+
class MyHandledException(Exception):
65+
pass
66+
67+
app = get_app()
68+
69+
@app.exception_handler(MyHandledException)
70+
async def unicorn_exception_handler(request, exc):
71+
return JSONResponse(
72+
status_code=418,
73+
content={"details": "I am a custom Teapot!"},
74+
)
75+
76+
@app.get("/handled_exception")
77+
def my_entry_point():
78+
raise MyHandledException()
79+
80+
async with AsyncClient(app=app, base_url="http://test") as client:
81+
response = await client.get("/handled_exception")
82+
assert response.status_code == 418 # I am a teapot!
83+
assert response.json() == {"details": "I am a custom Teapot!"}
84+
85+
86+
@pytest.mark.anyio
87+
async def test_middleware_reports_invalid_jwt(get_app):
88+
async with AsyncClient(app=get_app(with_ssr=False), base_url="http://test") as client:
89+
# Insert a bad token instead
90+
badtoken = jwt.encode({"bad": "token"}, "badsecret", "HS256")
91+
client.cookies.update(dict(Authorization=f"Bearer: {badtoken}"))
92+
93+
response = await client.get("/user")
94+
assert response.status_code == 400
95+
assert response.text == "Signature verification failed."

0 commit comments

Comments
 (0)