Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate responses from exception handlers to avoid inconsistent OpenAPI specification #1407

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 19 additions & 9 deletions ninja/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,17 @@ def run(self, request: HttpRequest, **kw: Any) -> HttpResponseBase:
temporal_response = self.api.create_temporal_response(request)
values = self._get_values(request, kw, temporal_response)
result = self.view_func(request, **values)

if isinstance(result, HttpResponseBase):
return result

return self._result_to_response(request, result, temporal_response)
except Exception as e:
if isinstance(e, TypeError) and "required positional argument" in str(e):
msg = "Did you fail to use functools.wraps() in a decorator?"
msg = f"{e.args[0]}: {msg}" if e.args else msg
e.args = (msg,) + e.args[1:]
return self.api.on_exception(request, e)
return self._on_exception(request, e)

def set_api_instance(self, api: "NinjaAPI", router: "Router") -> None:
self.api = api
Expand Down Expand Up @@ -158,6 +162,12 @@ def set_api_instance(self, api: "NinjaAPI", router: "Router") -> None:
if router.tags is not None:
self.tags = router.tags

def _on_exception(self, request: HttpRequest, exc: Exception) -> HttpResponse:
temporal_response = self.api.create_temporal_response(request)
result = self.api.on_exception(request, exc)

return self._result_to_response(request, result, temporal_response)

def _set_auth(
self, auth: Optional[Union[Sequence[Callable], Callable, object]]
) -> None:
Expand Down Expand Up @@ -196,12 +206,12 @@ def _run_authentication(self, request: HttpRequest) -> Optional[HttpResponse]:
else:
result = callback(request)
except Exception as exc:
return self.api.on_exception(request, exc)
return self._on_exception(request, exc)

if result:
request.auth = result # type: ignore
return None
return self.api.on_exception(request, AuthenticationError())
return self._on_exception(request, AuthenticationError())

def _check_throttles(self, request: HttpRequest) -> Optional[HttpResponse]:
throttle_durations = []
Expand All @@ -216,19 +226,19 @@ def _check_throttles(self, request: HttpRequest) -> Optional[HttpResponse]:
]

duration = max(durations, default=None)
return self.api.on_exception(request, Throttled(wait=duration)) # type: ignore
return self._on_exception(request, Throttled(wait=duration)) # type: ignore
return None

def _result_to_response(
self, request: HttpRequest, result: Any, temporal_response: HttpResponse
) -> HttpResponseBase:
) -> HttpResponse:
"""
The protocol for results
- if HttpResponse - returns as is
- if tuple with 2 elements - means http_code + body
- otherwise it's a body
"""
if isinstance(result, HttpResponseBase):
if isinstance(result, HttpResponse):
return result

status: int = 200
Expand Down Expand Up @@ -338,7 +348,7 @@ async def run(self, request: HttpRequest, **kw: Any) -> HttpResponseBase: # typ
result = await self.view_func(request, **values)
return self._result_to_response(request, result, temporal_response)
except Exception as e:
return self.api.on_exception(request, e)
return self._on_exception(request, e)

async def _run_checks(self, request: HttpRequest) -> Optional[HttpResponse]: # type: ignore
"Runs security checks for each operation"
Expand Down Expand Up @@ -376,12 +386,12 @@ async def _run_authentication(self, request: HttpRequest) -> Optional[HttpRespon
else:
result = callback(request)
except Exception as exc:
return self.api.on_exception(request, exc)
return self._on_exception(request, exc)

if result:
request.auth = result # type: ignore
return None
return self.api.on_exception(request, AuthenticationError())
return self._on_exception(request, AuthenticationError())


class PathView:
Expand Down
32 changes: 32 additions & 0 deletions tests/test_exceptions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import pytest
from django.http import Http404
from pydantic import ValidationError

from ninja import NinjaAPI, Schema
from ninja.errors import ConfigError
from ninja.testing import TestAsyncClient, TestClient

api = NinjaAPI()
Expand Down Expand Up @@ -97,3 +99,33 @@ def thrower(request):

with pytest.raises(RuntimeError):
client.get("/error")


def test_improper_response_body_from_exception_handler():
@api.exception_handler(RuntimeError)
def on_runtime_error(request, exc):
return 418, {"payload": "non-proper"}

@api.get("/error", response={418: Payload})
def thrower(request):
raise RuntimeError

client = TestClient(api)

with pytest.raises(ValidationError):
client.get("/error")


def test_non_configured_status_code_from_exception_handler():
@api.exception_handler(RuntimeError)
def on_runtime_error(request, exc):
return 410, Payload(test=1234)

@api.get("/error", response={418: Payload})
def thrower(request):
raise RuntimeError

client = TestClient(api)

with pytest.raises(ConfigError):
client.get("/error")