diff --git a/ninja/operation.py b/ninja/operation.py index 0e4976d6e..6b5a15360 100644 --- a/ninja/operation.py +++ b/ninja/operation.py @@ -123,13 +123,17 @@ def run(self, request: HttpRequest, **kw: Any) -> HttpResponseBase: temporal_response = self.api.create_temporal_response(request) values = self._get_values(request, kw, temporal_response) result = self.view_func(request, **values) + + if isinstance(result, HttpResponseBase): + return result + return self._result_to_response(request, result, temporal_response) except Exception as e: if isinstance(e, TypeError) and "required positional argument" in str(e): msg = "Did you fail to use functools.wraps() in a decorator?" msg = f"{e.args[0]}: {msg}" if e.args else msg e.args = (msg,) + e.args[1:] - return self.api.on_exception(request, e) + return self._on_exception(request, e) def set_api_instance(self, api: "NinjaAPI", router: "Router") -> None: self.api = api @@ -158,6 +162,12 @@ def set_api_instance(self, api: "NinjaAPI", router: "Router") -> None: if router.tags is not None: self.tags = router.tags + def _on_exception(self, request: HttpRequest, exc: Exception) -> HttpResponse: + temporal_response = self.api.create_temporal_response(request) + result = self.api.on_exception(request, exc) + + return self._result_to_response(request, result, temporal_response) + def _set_auth( self, auth: Optional[Union[Sequence[Callable], Callable, object]] ) -> None: @@ -196,12 +206,12 @@ def _run_authentication(self, request: HttpRequest) -> Optional[HttpResponse]: else: result = callback(request) except Exception as exc: - return self.api.on_exception(request, exc) + return self._on_exception(request, exc) if result: request.auth = result # type: ignore return None - return self.api.on_exception(request, AuthenticationError()) + return self._on_exception(request, AuthenticationError()) def _check_throttles(self, request: HttpRequest) -> Optional[HttpResponse]: throttle_durations = [] @@ -216,19 +226,19 @@ def _check_throttles(self, request: HttpRequest) -> Optional[HttpResponse]: ] duration = max(durations, default=None) - return self.api.on_exception(request, Throttled(wait=duration)) # type: ignore + return self._on_exception(request, Throttled(wait=duration)) # type: ignore return None def _result_to_response( self, request: HttpRequest, result: Any, temporal_response: HttpResponse - ) -> HttpResponseBase: + ) -> HttpResponse: """ The protocol for results - if HttpResponse - returns as is - if tuple with 2 elements - means http_code + body - otherwise it's a body """ - if isinstance(result, HttpResponseBase): + if isinstance(result, HttpResponse): return result status: int = 200 @@ -338,7 +348,7 @@ async def run(self, request: HttpRequest, **kw: Any) -> HttpResponseBase: # typ result = await self.view_func(request, **values) return self._result_to_response(request, result, temporal_response) except Exception as e: - return self.api.on_exception(request, e) + return self._on_exception(request, e) async def _run_checks(self, request: HttpRequest) -> Optional[HttpResponse]: # type: ignore "Runs security checks for each operation" @@ -376,12 +386,12 @@ async def _run_authentication(self, request: HttpRequest) -> Optional[HttpRespon else: result = callback(request) except Exception as exc: - return self.api.on_exception(request, exc) + return self._on_exception(request, exc) if result: request.auth = result # type: ignore return None - return self.api.on_exception(request, AuthenticationError()) + return self._on_exception(request, AuthenticationError()) class PathView: diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index e3e314fa1..1d27f25e0 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -1,7 +1,9 @@ import pytest from django.http import Http404 +from pydantic import ValidationError from ninja import NinjaAPI, Schema +from ninja.errors import ConfigError from ninja.testing import TestAsyncClient, TestClient api = NinjaAPI() @@ -97,3 +99,33 @@ def thrower(request): with pytest.raises(RuntimeError): client.get("/error") + + +def test_improper_response_body_from_exception_handler(): + @api.exception_handler(RuntimeError) + def on_runtime_error(request, exc): + return 418, {"payload": "non-proper"} + + @api.get("/error", response={418: Payload}) + def thrower(request): + raise RuntimeError + + client = TestClient(api) + + with pytest.raises(ValidationError): + client.get("/error") + + +def test_non_configured_status_code_from_exception_handler(): + @api.exception_handler(RuntimeError) + def on_runtime_error(request, exc): + return 410, Payload(test=1234) + + @api.get("/error", response={418: Payload}) + def thrower(request): + raise RuntimeError + + client = TestClient(api) + + with pytest.raises(ConfigError): + client.get("/error")