From 3708dd3bb981612e119473038447262d68f4ed72 Mon Sep 17 00:00:00 2001 From: Jean-Edouard BOULANGER Date: Fri, 4 Aug 2023 10:57:46 +0200 Subject: [PATCH 1/3] Fully support list responses --- Makefile | 2 +- spectree/_types.py | 4 +- spectree/plugins/falcon_plugin.py | 19 +++++-- spectree/plugins/flask_plugin.py | 11 +++- spectree/plugins/quart_plugin.py | 9 +++ spectree/plugins/starlette_plugin.py | 11 +++- spectree/response.py | 25 +++++++-- .../test_plugin_spec[falcon][full_spec].json | 56 +++++++++++++++++++ .../test_plugin_spec[flask][full_spec].json | 56 +++++++++++++++++++ ...ugin_spec[flask_blueprint][full_spec].json | 56 +++++++++++++++++++ ...st_plugin_spec[flask_view][full_spec].json | 56 +++++++++++++++++++ ...est_plugin_spec[starlette][full_spec].json | 56 +++++++++++++++++++ tests/flask_imports/__init__.py | 2 + tests/flask_imports/dry_plugin_flask.py | 10 ++++ tests/quart_imports/dry_plugin_quart.py | 12 ++++ tests/test_plugin_falcon.py | 24 ++++++++ tests/test_plugin_falcon_asgi.py | 20 +++++++ tests/test_plugin_flask.py | 9 +++ tests/test_plugin_flask_blueprint.py | 9 +++ tests/test_plugin_flask_view.py | 15 +++++ tests/test_plugin_quart.py | 9 +++ tests/test_plugin_starlette.py | 21 +++++++ tests/test_response.py | 2 + 23 files changed, 479 insertions(+), 15 deletions(-) diff --git a/Makefile b/Makefile index a12db61b..ac5f5f09 100644 --- a/Makefile +++ b/Makefile @@ -45,7 +45,7 @@ format: lint: isort --check --diff --project=spectree ${SOURCE_FILES} black --check --diff ${SOURCE_FILES} - flake8 ${SOURCE_FILES} --count --show-source --statistics --ignore=D203,E203,W503 --max-line-length=88 --max-complexity=15 + flake8 ${SOURCE_FILES} --count --show-source --statistics --ignore=D203,E203,W503 --max-line-length=88 --max-complexity=16 mypy --install-types --non-interactive ${MYPY_SOURCE_FILES} .PHONY: test doc diff --git a/spectree/_types.py b/spectree/_types.py index b572fad5..7aae8809 100644 --- a/spectree/_types.py +++ b/spectree/_types.py @@ -8,6 +8,7 @@ Optional, Sequence, Type, + TypeVar, Union, ) @@ -15,7 +16,8 @@ from ._pydantic import BaseModel -ModelType = Type[BaseModel] +BaseModelSubclassType = TypeVar("BaseModelSubclassType", bound=BaseModel) +ModelType = Type[BaseModelSubclassType] OptionalModelType = Optional[ModelType] NamingStrategy = Callable[[ModelType], str] NestedNamingStrategy = Callable[[str, str], str] diff --git a/spectree/plugins/falcon_plugin.py b/spectree/plugins/falcon_plugin.py index 07b5d13f..804219c7 100644 --- a/spectree/plugins/falcon_plugin.py +++ b/spectree/plugins/falcon_plugin.py @@ -6,7 +6,7 @@ from falcon import HTTP_400, HTTP_415, HTTPError from falcon.routing.compiled import _FIELD_PATTERN as FALCON_FIELD_PATTERN -from .._pydantic import ValidationError +from .._pydantic import BaseModel, ValidationError from .._types import ModelType from ..response import Response from .base import BasePlugin @@ -227,17 +227,24 @@ def validate( func(*args, **kwargs) if resp and resp.has_model(): - model = resp.find_model(_resp.status[:3]) - if model and isinstance(_resp.media, model): - _resp.media = _resp.media.dict() + model = _resp.media + status = int(_resp.status[:3]) + expect_model = resp.find_model(status) + if resp.expect_list_result(status) and isinstance(model, list): + _resp.media = [ + (entry.dict() if isinstance(entry, BaseModel) else entry) + for entry in model + ] + elif expect_model and isinstance(_resp.media, expect_model): + _resp.media = model.dict() skip_validation = True if self._data_set_manually(_resp): skip_validation = True - if model and not skip_validation: + if expect_model and not skip_validation: try: - model.parse_obj(_resp.media) + expect_model.parse_obj(_resp.media) except ValidationError as err: resp_validation_error = err _resp.status = HTTP_500 diff --git a/spectree/plugins/flask_plugin.py b/spectree/plugins/flask_plugin.py index 0604a5ab..8f3b5937 100644 --- a/spectree/plugins/flask_plugin.py +++ b/spectree/plugins/flask_plugin.py @@ -210,7 +210,16 @@ def validate( if resp: expect_model = resp.find_model(status) - if expect_model and isinstance(model, expect_model): + if resp.expect_list_result(status) and isinstance(model, list): + result = ( + [ + (entry.dict() if isinstance(entry, BaseModel) else entry) + for entry in model + ], + status, + *rest, + ) + elif expect_model and isinstance(model, expect_model): skip_validation = True result = (model.dict(), status, *rest) diff --git a/spectree/plugins/quart_plugin.py b/spectree/plugins/quart_plugin.py index de891c2e..5dc5d990 100644 --- a/spectree/plugins/quart_plugin.py +++ b/spectree/plugins/quart_plugin.py @@ -222,6 +222,15 @@ async def validate( if resp: expect_model = resp.find_model(status) + if resp.expect_list_result(status) and isinstance(model, list): + result = ( + [ + (entry.dict() if isinstance(entry, BaseModel) else entry) + for entry in model + ], + status, + *rest, + ) if expect_model and isinstance(model, expect_model): skip_validation = True result = (model.dict(), status, *rest) diff --git a/spectree/plugins/starlette_plugin.py b/spectree/plugins/starlette_plugin.py index 938bad75..d8fc115a 100644 --- a/spectree/plugins/starlette_plugin.py +++ b/spectree/plugins/starlette_plugin.py @@ -9,7 +9,7 @@ from starlette.responses import HTMLResponse, JSONResponse from starlette.routing import compile_path -from .._pydantic import ValidationError +from .._pydantic import BaseModel, ValidationError from .._types import ModelType from ..response import Response from .base import BasePlugin, Context @@ -22,7 +22,14 @@ def PydanticResponse(content): class _PydanticResponse(JSONResponse): def render(self, content) -> bytes: self._model_class = content.__class__ - return super().render(content.dict()) + return super().render( + [ + (entry.dict() if isinstance(entry, BaseModel) else entry) + for entry in content + ] + if isinstance(content, list) + else content.dict() + ) return _PydanticResponse(content) diff --git a/spectree/response.py b/spectree/response.py index 6b8e69e5..71412831 100644 --- a/spectree/response.py +++ b/spectree/response.py @@ -1,8 +1,8 @@ from http import HTTPStatus -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union from ._pydantic import BaseModel -from ._types import ModelType, NamingStrategy, OptionalModelType +from ._types import BaseModelSubclassType, ModelType, NamingStrategy, OptionalModelType from .utils import gen_list_model, get_model_key, parse_code # according to https://tools.ietf.org/html/rfc2616#section-10 @@ -30,22 +30,30 @@ class Response: examples: + >>> from typing import List >>> from spectree.response import Response >>> from pydantic import BaseModel ... >>> class User(BaseModel): ... id: int ... - >>> response = Response(HTTP_200) + >>> response = Response("HTTP_200") >>> response = Response(HTTP_200=None) >>> response = Response(HTTP_200=User) >>> response = Response(HTTP_200=(User, "status code description")) + >>> response = Response(HTTP_200=List[User]) + >>> response = Response(HTTP_200=(List[User], "status code description")) """ def __init__( self, *codes: str, - **code_models: Union[OptionalModelType, Tuple[OptionalModelType, str]], + **code_models: Union[ + OptionalModelType, + Tuple[OptionalModelType, str], + Type[List[BaseModelSubclassType]], + Tuple[Type[List[BaseModelSubclassType]], str], + ], ) -> None: self.codes: List[str] = [] @@ -55,6 +63,7 @@ def __init__( self.code_models: Dict[str, ModelType] = {} self.code_descriptions: Dict[str, Optional[str]] = {} + self.codes_expecting_list_result: Set[str] = set() for code, model_and_description in code_models.items(): assert code in DEFAULT_CODE_DESC, "invalid HTTP status code" description: Optional[str] = None @@ -73,6 +82,7 @@ def __init__( if origin_type is list or origin_type is List: # type is List[BaseModel] model = gen_list_model(getattr(model, "__args__")[0]) + self.codes_expecting_list_result.add(code) assert issubclass(model, BaseModel), "invalid `pydantic.BaseModel`" assert description is None or isinstance( description, str @@ -119,6 +129,13 @@ def find_model(self, code: int) -> OptionalModelType: """ return self.code_models.get(f"HTTP_{code}") + def expect_list_result(self, code: int) -> bool: + """Check whether a specific HTTP code expects a list result. + + :param code: Status code string, format('HTTP_[0-9]_{3}'), 'HTTP_200'. + """ + return f"HTTP_{code}" in self.codes_expecting_list_result + def get_code_description(self, code: str) -> str: """Get the description of the given status code. diff --git a/tests/__snapshots__/test_plugin/test_plugin_spec[falcon][full_spec].json b/tests/__snapshots__/test_plugin/test_plugin_spec[falcon][full_spec].json index 965ce574..9dacbf8f 100644 --- a/tests/__snapshots__/test_plugin/test_plugin_spec[falcon][full_spec].json +++ b/tests/__snapshots__/test_plugin/test_plugin_spec[falcon][full_spec].json @@ -67,6 +67,31 @@ "title": "JSON", "type": "object" }, + "JSONList.a9993e3": { + "items": { + "$ref": "#/components/schemas/JSONList.a9993e3.JSON" + }, + "title": "JSONList", + "type": "array" + }, + "JSONList.a9993e3.JSON": { + "properties": { + "limit": { + "title": "Limit", + "type": "integer" + }, + "name": { + "title": "Name", + "type": "string" + } + }, + "required": [ + "name", + "limit" + ], + "title": "JSON", + "type": "object" + }, "ListJSON.7068f62": { "items": { "$ref": "#/components/schemas/ListJSON.7068f62.JSON" @@ -327,6 +352,37 @@ "tags": [] } }, + "/api/return_list": { + "get": { + "description": "", + "operationId": "get__api_return_list", + "parameters": [], + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/JSONList.a9993e3" + } + } + }, + "description": "OK" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ValidationError.6a07bef" + } + } + }, + "description": "Unprocessable Entity" + } + }, + "summary": "on_get ", + "tags": [] + } + }, "/api/user/{name}": { "get": { "description": "", diff --git a/tests/__snapshots__/test_plugin/test_plugin_spec[flask][full_spec].json b/tests/__snapshots__/test_plugin/test_plugin_spec[flask][full_spec].json index 99c7b293..0559e6ba 100644 --- a/tests/__snapshots__/test_plugin/test_plugin_spec[flask][full_spec].json +++ b/tests/__snapshots__/test_plugin/test_plugin_spec[flask][full_spec].json @@ -85,6 +85,31 @@ "title": "JSON", "type": "object" }, + "JSONList.a9993e3": { + "items": { + "$ref": "#/components/schemas/JSONList.a9993e3.JSON" + }, + "title": "JSONList", + "type": "array" + }, + "JSONList.a9993e3.JSON": { + "properties": { + "limit": { + "title": "Limit", + "type": "integer" + }, + "name": { + "title": "Name", + "type": "string" + } + }, + "required": [ + "name", + "limit" + ], + "title": "JSON", + "type": "object" + }, "ListJSON.7068f62": { "items": { "$ref": "#/components/schemas/ListJSON.7068f62.JSON" @@ -280,6 +305,37 @@ "tags": [] } }, + "/api/return_list": { + "get": { + "description": "", + "operationId": "get__api_return_list", + "parameters": [], + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/JSONList.a9993e3" + } + } + }, + "description": "OK" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ValidationError.6a07bef" + } + } + }, + "description": "Unprocessable Entity" + } + }, + "summary": "return_list ", + "tags": [] + } + }, "/api/user/{name}": { "post": { "description": "", diff --git a/tests/__snapshots__/test_plugin/test_plugin_spec[flask_blueprint][full_spec].json b/tests/__snapshots__/test_plugin/test_plugin_spec[flask_blueprint][full_spec].json index 314b5841..bea03587 100644 --- a/tests/__snapshots__/test_plugin/test_plugin_spec[flask_blueprint][full_spec].json +++ b/tests/__snapshots__/test_plugin/test_plugin_spec[flask_blueprint][full_spec].json @@ -85,6 +85,31 @@ "title": "JSON", "type": "object" }, + "JSONList.a9993e3": { + "items": { + "$ref": "#/components/schemas/JSONList.a9993e3.JSON" + }, + "title": "JSONList", + "type": "array" + }, + "JSONList.a9993e3.JSON": { + "properties": { + "limit": { + "title": "Limit", + "type": "integer" + }, + "name": { + "title": "Name", + "type": "string" + } + }, + "required": [ + "name", + "limit" + ], + "title": "JSON", + "type": "object" + }, "ListJSON.7068f62": { "items": { "$ref": "#/components/schemas/ListJSON.7068f62.JSON" @@ -280,6 +305,37 @@ "tags": [] } }, + "/api/return_list": { + "get": { + "description": "", + "operationId": "get__api_return_list", + "parameters": [], + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/JSONList.a9993e3" + } + } + }, + "description": "OK" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ValidationError.6a07bef" + } + } + }, + "description": "Unprocessable Entity" + } + }, + "summary": "return_list ", + "tags": [] + } + }, "/api/user/{name}": { "post": { "description": "", diff --git a/tests/__snapshots__/test_plugin/test_plugin_spec[flask_view][full_spec].json b/tests/__snapshots__/test_plugin/test_plugin_spec[flask_view][full_spec].json index c230237a..edac4522 100644 --- a/tests/__snapshots__/test_plugin/test_plugin_spec[flask_view][full_spec].json +++ b/tests/__snapshots__/test_plugin/test_plugin_spec[flask_view][full_spec].json @@ -85,6 +85,31 @@ "title": "JSON", "type": "object" }, + "JSONList.a9993e3": { + "items": { + "$ref": "#/components/schemas/JSONList.a9993e3.JSON" + }, + "title": "JSONList", + "type": "array" + }, + "JSONList.a9993e3.JSON": { + "properties": { + "limit": { + "title": "Limit", + "type": "integer" + }, + "name": { + "title": "Name", + "type": "string" + } + }, + "required": [ + "name", + "limit" + ], + "title": "JSON", + "type": "object" + }, "ListJSON.7068f62": { "items": { "$ref": "#/components/schemas/ListJSON.7068f62.JSON" @@ -285,6 +310,37 @@ "tags": [] } }, + "/api/return_list": { + "get": { + "description": "", + "operationId": "get__api_return_list", + "parameters": [], + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/JSONList.a9993e3" + } + } + }, + "description": "OK" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ValidationError.6a07bef" + } + } + }, + "description": "Unprocessable Entity" + } + }, + "summary": "get ", + "tags": [] + } + }, "/api/user/{name}": { "post": { "description": "", diff --git a/tests/__snapshots__/test_plugin/test_plugin_spec[starlette][full_spec].json b/tests/__snapshots__/test_plugin/test_plugin_spec[starlette][full_spec].json index 8dbf6649..332a8dc6 100644 --- a/tests/__snapshots__/test_plugin/test_plugin_spec[starlette][full_spec].json +++ b/tests/__snapshots__/test_plugin/test_plugin_spec[starlette][full_spec].json @@ -67,6 +67,31 @@ "title": "JSON", "type": "object" }, + "JSONList.a9993e3": { + "items": { + "$ref": "#/components/schemas/JSONList.a9993e3.JSON" + }, + "title": "JSONList", + "type": "array" + }, + "JSONList.a9993e3.JSON": { + "properties": { + "limit": { + "title": "Limit", + "type": "integer" + }, + "name": { + "title": "Name", + "type": "string" + } + }, + "required": [ + "name", + "limit" + ], + "title": "JSON", + "type": "object" + }, "ListJSON.7068f62": { "items": { "$ref": "#/components/schemas/ListJSON.7068f62.JSON" @@ -290,6 +315,37 @@ "tags": [] } }, + "/api/return_list": { + "get": { + "description": "", + "operationId": "get__api_return_list", + "parameters": [], + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/JSONList.a9993e3" + } + } + }, + "description": "OK" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ValidationError.6a07bef" + } + } + }, + "description": "Unprocessable Entity" + } + }, + "summary": "return_list ", + "tags": [] + } + }, "/api/user/{name}": { "post": { "description": "", diff --git a/tests/flask_imports/__init__.py b/tests/flask_imports/__init__.py index 2a05aa54..e98d635f 100644 --- a/tests/flask_imports/__init__.py +++ b/tests/flask_imports/__init__.py @@ -2,6 +2,7 @@ test_flask_doc, test_flask_list_json_request, test_flask_no_response, + test_flask_return_list_request, test_flask_return_model, test_flask_skip_validation, test_flask_upload_file, @@ -18,4 +19,5 @@ "test_flask_no_response", "test_flask_upload_file", "test_flask_list_json_request", + "test_flask_return_list_request", ] diff --git a/tests/flask_imports/dry_plugin_flask.py b/tests/flask_imports/dry_plugin_flask.py index 3144a4fc..7146c3f2 100644 --- a/tests/flask_imports/dry_plugin_flask.py +++ b/tests/flask_imports/dry_plugin_flask.py @@ -161,6 +161,16 @@ def test_flask_list_json_request(client): assert resp.status_code == 200, resp.data +@pytest.mark.parametrize("pre_serialize", [False, True]) +def test_flask_return_list_request(client, pre_serialize: bool): + resp = client.get(f"/api/return_list?pre_serialize={int(pre_serialize)}") + assert resp.status_code == 200 + assert resp.json == [ + {"name": "user1", "limit": 1}, + {"name": "user2", "limit": 2}, + ] + + def test_flask_upload_file(client): file_content = "abcdef" data = {"file": (io.BytesIO(file_content.encode("utf-8")), "test.txt")} diff --git a/tests/quart_imports/dry_plugin_quart.py b/tests/quart_imports/dry_plugin_quart.py index 2ddca24a..57098d16 100644 --- a/tests/quart_imports/dry_plugin_quart.py +++ b/tests/quart_imports/dry_plugin_quart.py @@ -162,3 +162,15 @@ def test_quart_list_json_request(client): client.post("/api/list_json", json=[{"name": "foo", "limit": 1}]) ) assert resp.status_code == 200 + + +@pytest.mark.parametrize("pre_serialize", [False, True]) +def test_quart_return_list_request(client, pre_serialize: bool): + resp = asyncio.run( + client.get(f"/api/return_list?pre_serialize={int(pre_serialize)}") + ) + assert resp.status_code == 200 + assert resp.json == [ + {"name": "user1", "limit": 1}, + {"name": "user2", "limit": 2}, + ] diff --git a/tests/test_plugin_falcon.py b/tests/test_plugin_falcon.py index e91a9324..9f2890fc 100644 --- a/tests/test_plugin_falcon.py +++ b/tests/test_plugin_falcon.py @@ -1,4 +1,5 @@ from random import randint +from typing import List import pytest from falcon import App, testing @@ -196,6 +197,16 @@ def on_post(self, req, resp, json: ListJSON): pass +class ReturnListView: + name = "return list request view" + + @api.validate(resp=Response(HTTP_200=List[JSON])) + def on_get(self, req, resp): + pre_serialize = bool(int(req.params.get("pre_serialize", 0))) + data = [JSON(name="user1", limit=1), JSON(name="user2", limit=2)] + resp.media = [entry.dict() if pre_serialize else entry for entry in data] + + class ViewWithCustomSerializer: name = "view with custom serializer" @@ -222,6 +233,7 @@ def on_post(self, req, resp): app.add_route("/api/no_response", NoResponseView()) app.add_route("/api/file_upload", FileUploadView()) app.add_route("/api/list_json", ListJsonView()) +app.add_route("/api/return_list", ReturnListView()) app.add_route("/api/custom_serializer", ViewWithCustomSerializer()) api.register(app) @@ -324,6 +336,18 @@ def test_falcon_list_json_request_sync(client): assert resp.status_code == 200 +@pytest.mark.parametrize("pre_serialize", [False, True]) +def test_falcon_return_list_request_sync(client, pre_serialize: bool): + resp = client.simulate_request( + "GET", f"/api/return_list?pre_serialize={int(pre_serialize)}" + ) + assert resp.status_code == 200 + assert resp.json == [ + {"name": "user1", "limit": 1}, + {"name": "user2", "limit": 2}, + ] + + @pytest.fixture def test_client_and_api(request): api_args = ["falcon"] diff --git a/tests/test_plugin_falcon_asgi.py b/tests/test_plugin_falcon_asgi.py index 056555b5..7a51713a 100644 --- a/tests/test_plugin_falcon_asgi.py +++ b/tests/test_plugin_falcon_asgi.py @@ -1,4 +1,5 @@ from random import randint +from typing import List import pytest from falcon import testing @@ -121,6 +122,15 @@ async def on_post(self, req, resp, json: ListJSON): pass +class ReturnListView: + name = "return list request view" + + @api.validate(resp=Response(HTTP_200=List[JSON])) + async def on_get(self, req, resp): + data = [JSON(name="user1", limit=1), JSON(name="user2", limit=2)] + resp.media = [entry.dict() for entry in data] + + class FileUploadView: name = "file upload view" @@ -156,6 +166,7 @@ async def on_post(self, req, resp): app.add_route("/api/no_response", NoResponseView()) app.add_route("/api/file_upload", FileUploadView()) app.add_route("/api/list_json", ListJsonView()) +app.add_route("/api/return_list", ReturnListView()) app.add_route("/api/custom_serializer", ViewWithCustomSerializer()) api.register(app) @@ -189,6 +200,15 @@ def test_falcon_list_json_request_async(client): assert resp.status_code == 200 +def test_falcon_return_list_request_async(client): + resp = client.simulate_request("GET", "/api/return_list") + assert resp.status_code == 200 + assert resp.json == [ + {"name": "user1", "limit": 1}, + {"name": "user2", "limit": 2}, + ] + + def test_falcon_validate(client): resp = client.simulate_request( "GET", "/ping", headers={"Content-Type": "text/plain"} diff --git a/tests/test_plugin_flask.py b/tests/test_plugin_flask.py index a5df4bee..5d2fa974 100644 --- a/tests/test_plugin_flask.py +++ b/tests/test_plugin_flask.py @@ -1,4 +1,5 @@ from random import randint +from typing import List import pytest from flask import Flask, jsonify, request @@ -168,6 +169,14 @@ def json_list(): return {} +@app.route("/api/return_list", methods=["GET"]) +@api.validate(resp=Response(HTTP_200=List[JSON])) +def return_list(): + pre_serialize = bool(int(request.args.get("pre_serialize", default=0))) + data = [JSON(name="user1", limit=1), JSON(name="user2", limit=2)] + return [entry.dict() if pre_serialize else entry for entry in data] + + # INFO: ensures that spec is calculated and cached _after_ registering # view functions for validations. This enables tests to access `api.spec` # without app_context. diff --git a/tests/test_plugin_flask_blueprint.py b/tests/test_plugin_flask_blueprint.py index 619b567a..33c896c6 100644 --- a/tests/test_plugin_flask_blueprint.py +++ b/tests/test_plugin_flask_blueprint.py @@ -1,4 +1,5 @@ from random import randint +from typing import List import pytest from flask import Blueprint, Flask, jsonify, request @@ -157,6 +158,14 @@ def list_json(): return {} +@app.route("/api/return_list", methods=["GET"]) +@api.validate(resp=Response(HTTP_200=List[JSON])) +def return_list(): + pre_serialize = bool(int(request.args.get("pre_serialize", default=0))) + data = [JSON(name="user1", limit=1), JSON(name="user2", limit=2)] + return [entry.dict() if pre_serialize else entry for entry in data] + + api.register(app) flask_app = Flask(__name__) diff --git a/tests/test_plugin_flask_view.py b/tests/test_plugin_flask_view.py index efb2e911..03e645ca 100644 --- a/tests/test_plugin_flask_view.py +++ b/tests/test_plugin_flask_view.py @@ -1,4 +1,5 @@ from random import randint +from typing import List import pytest from flask import Flask, jsonify, request @@ -168,6 +169,16 @@ def post(self): return {} +class ReturnListView(MethodView): + @api.validate( + resp=Response(HTTP_200=List[JSON]), + ) + def get(self): + pre_serialize = bool(int(request.args.get("pre_serialize", default=0))) + data = [JSON(name="user1", limit=1), JSON(name="user2", limit=2)] + return [entry.dict() if pre_serialize else entry for entry in data] + + app.add_url_rule("/ping", view_func=Ping.as_view("ping")) app.add_url_rule("/api/user/", view_func=User.as_view("user"), methods=["POST"]) app.add_url_rule( @@ -202,6 +213,10 @@ def post(self): "/api/list_json", view_func=ListJsonView.as_view("list_json_view"), ) +app.add_url_rule( + "/api/return_list", + view_func=ReturnListView.as_view("return_list_view"), +) # INFO: ensures that spec is calculated and cached _after_ registering # view functions for validations. This enables tests to access `api.spec` diff --git a/tests/test_plugin_quart.py b/tests/test_plugin_quart.py index 4b57b0cc..40b51783 100644 --- a/tests/test_plugin_quart.py +++ b/tests/test_plugin_quart.py @@ -1,4 +1,5 @@ from random import randint +from typing import List import pytest from quart import Quart, jsonify, request @@ -150,6 +151,14 @@ async def list_json(): return {} +@app.route("/api/return_list") +@api.validate(resp=Response(HTTP_200=List[JSON])) +def return_list(): + pre_serialize = bool(int(request.args.get("pre_serialize", default=0))) + data = [JSON(name="user1", limit=1), JSON(name="user2", limit=2)] + return [entry.dict() if pre_serialize else entry for entry in data] + + # INFO: ensures that spec is calculated and cached _after_ registering # view functions for validations. This enables tests to access `api.spec` # without app_context. diff --git a/tests/test_plugin_starlette.py b/tests/test_plugin_starlette.py index ad71313c..9fd20fd0 100644 --- a/tests/test_plugin_starlette.py +++ b/tests/test_plugin_starlette.py @@ -1,5 +1,6 @@ import io from random import randint +from typing import List import pytest from starlette.applications import Starlette @@ -142,6 +143,15 @@ async def list_json(request): return JSONResponse({}) +@api.validate(resp=Response(HTTP_200=List[JSON])) +async def return_list(request): + pre_serialize = bool(int(request.query_params.get("pre_serialize", 0))) + data = [JSON(name="user1", limit=1), JSON(name="user2", limit=2)] + return PydanticResponse( + [entry.dict() if pre_serialize else entry for entry in data] + ) + + app = Starlette( routes=[ Route("/ping", Ping), @@ -175,6 +185,7 @@ async def list_json(request): Route("/no_response", no_response, methods=["POST", "GET"]), Route("/file_upload", file_upload, methods=["POST"]), Route("/list_json", list_json, methods=["POST"]), + Route("/return_list", return_list, methods=["GET"]), ], ), Mount("/static", app=StaticFiles(directory="docs"), name="static"), @@ -374,6 +385,16 @@ def test_json_list_request(client): assert resp.status_code == 200, resp.text +@pytest.mark.parametrize("pre_serialize", [False, True]) +def test_return_list_request(client, pre_serialize: bool): + resp = client.get(f"/api/return_list?pre_serialize={int(pre_serialize)}") + assert resp.status_code == 200 + assert resp.json() == [ + {"name": "user1", "limit": 1}, + {"name": "user2", "limit": 2}, + ] + + def test_starlette_upload_file(client): file_content = "abcdef" file_io = io.BytesIO(file_content.encode("utf-8")) diff --git a/tests/test_response.py b/tests/test_response.py index daa24cda..1eca99ea 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -111,6 +111,8 @@ def test_list_model(): resp = Response(HTTP_200=List[JSON]) model = resp.find_model(200) expect_model = gen_list_model(JSON) + assert resp.expect_list_result(200) + assert not resp.expect_list_result(500) assert get_type_hints(model) == get_type_hints(expect_model) assert type(model) is type(expect_model) assert issubclass(model, BaseModel) From bae8ff9efcab6689cbc0c33209fbaf98dfa1a4ed Mon Sep 17 00:00:00 2001 From: Jean-Edouard BOULANGER Date: Wed, 9 Aug 2023 08:40:12 +0200 Subject: [PATCH 2/3] Skip validation if all list entries have the expected type --- Makefile | 2 +- spectree/plugins/falcon_plugin.py | 3 +++ spectree/plugins/flask_plugin.py | 3 +++ spectree/plugins/quart_plugin.py | 3 +++ spectree/response.py | 20 ++++++++++++++------ 5 files changed, 24 insertions(+), 7 deletions(-) diff --git a/Makefile b/Makefile index ac5f5f09..ab74608c 100644 --- a/Makefile +++ b/Makefile @@ -45,7 +45,7 @@ format: lint: isort --check --diff --project=spectree ${SOURCE_FILES} black --check --diff ${SOURCE_FILES} - flake8 ${SOURCE_FILES} --count --show-source --statistics --ignore=D203,E203,W503 --max-line-length=88 --max-complexity=16 + flake8 ${SOURCE_FILES} --count --show-source --statistics --ignore=D203,E203,W503 --max-line-length=88 --max-complexity=17 mypy --install-types --non-interactive ${MYPY_SOURCE_FILES} .PHONY: test doc diff --git a/spectree/plugins/falcon_plugin.py b/spectree/plugins/falcon_plugin.py index 804219c7..cb13a432 100644 --- a/spectree/plugins/falcon_plugin.py +++ b/spectree/plugins/falcon_plugin.py @@ -231,6 +231,9 @@ def validate( status = int(_resp.status[:3]) expect_model = resp.find_model(status) if resp.expect_list_result(status) and isinstance(model, list): + expected_list_item_type = resp.get_expected_list_item_type(status) + if all(isinstance(entry, expected_list_item_type) for entry in model): + skip_validation = True _resp.media = [ (entry.dict() if isinstance(entry, BaseModel) else entry) for entry in model diff --git a/spectree/plugins/flask_plugin.py b/spectree/plugins/flask_plugin.py index 8f3b5937..5da90289 100644 --- a/spectree/plugins/flask_plugin.py +++ b/spectree/plugins/flask_plugin.py @@ -211,6 +211,9 @@ def validate( if resp: expect_model = resp.find_model(status) if resp.expect_list_result(status) and isinstance(model, list): + expected_list_item_type = resp.get_expected_list_item_type(status) + if all(isinstance(entry, expected_list_item_type) for entry in model): + skip_validation = True result = ( [ (entry.dict() if isinstance(entry, BaseModel) else entry) diff --git a/spectree/plugins/quart_plugin.py b/spectree/plugins/quart_plugin.py index 5dc5d990..0235d3e1 100644 --- a/spectree/plugins/quart_plugin.py +++ b/spectree/plugins/quart_plugin.py @@ -223,6 +223,9 @@ async def validate( if resp: expect_model = resp.find_model(status) if resp.expect_list_result(status) and isinstance(model, list): + expected_list_item_type = resp.get_expected_list_item_type(status) + if all(isinstance(entry, expected_list_item_type) for entry in model): + skip_validation = True result = ( [ (entry.dict() if isinstance(entry, BaseModel) else entry) diff --git a/spectree/response.py b/spectree/response.py index 71412831..022ef3e6 100644 --- a/spectree/response.py +++ b/spectree/response.py @@ -1,5 +1,5 @@ from http import HTTPStatus -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union +from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union from ._pydantic import BaseModel from ._types import BaseModelSubclassType, ModelType, NamingStrategy, OptionalModelType @@ -63,7 +63,7 @@ def __init__( self.code_models: Dict[str, ModelType] = {} self.code_descriptions: Dict[str, Optional[str]] = {} - self.codes_expecting_list_result: Set[str] = set() + self.code_list_item_types: Dict[str, ModelType] = {} for code, model_and_description in code_models.items(): assert code in DEFAULT_CODE_DESC, "invalid HTTP status code" description: Optional[str] = None @@ -81,8 +81,9 @@ def __init__( origin_type = getattr(model, "__origin__", None) if origin_type is list or origin_type is List: # type is List[BaseModel] - model = gen_list_model(getattr(model, "__args__")[0]) - self.codes_expecting_list_result.add(code) + list_item_type = getattr(model, "__args__")[0] + model = gen_list_model(list_item_type) + self.code_list_item_types[code] = list_item_type assert issubclass(model, BaseModel), "invalid `pydantic.BaseModel`" assert description is None or isinstance( description, str @@ -132,9 +133,16 @@ def find_model(self, code: int) -> OptionalModelType: def expect_list_result(self, code: int) -> bool: """Check whether a specific HTTP code expects a list result. - :param code: Status code string, format('HTTP_[0-9]_{3}'), 'HTTP_200'. + :param code: Status code (example: 200) + """ + return f"HTTP_{code}" in self.code_list_item_types + + def get_expected_list_item_type(self, code: int) -> ModelType: + """Get the expected list result item type. + + :param code: Status code (example: 200) """ - return f"HTTP_{code}" in self.codes_expecting_list_result + return self.code_list_item_types[f"HTTP_{code}"] def get_code_description(self, code: str) -> str: """Get the description of the given status code. From a66ff11ed930f860055727e38fea0c0060198a5e Mon Sep 17 00:00:00 2001 From: Jean-Edouard BOULANGER Date: Wed, 16 Aug 2023 09:07:06 +0200 Subject: [PATCH 3/3] Implement common solution for Falcon sync/async response validation --- spectree/plugins/falcon_plugin.py | 92 ++++++++++++++++--------------- spectree/plugins/quart_plugin.py | 2 +- tests/test_plugin_falcon_asgi.py | 10 +++- 3 files changed, 57 insertions(+), 47 deletions(-) diff --git a/spectree/plugins/falcon_plugin.py b/spectree/plugins/falcon_plugin.py index cb13a432..c6163508 100644 --- a/spectree/plugins/falcon_plugin.py +++ b/spectree/plugins/falcon_plugin.py @@ -4,6 +4,7 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, get_type_hints from falcon import HTTP_400, HTTP_415, HTTPError +from falcon import Response as FalconResponse from falcon.routing.compiled import _FIELD_PATTERN as FALCON_FIELD_PATTERN from .._pydantic import BaseModel, ValidationError @@ -188,6 +189,34 @@ def request_validation(self, req, query, json, form, headers, cookies): req_form = {x.name: x.stream.read() for x in req.get_media()} req.context.form = form.parse_obj(req_form) + def response_validation( + self, + response_spec: Optional[Response], + falcon_response: FalconResponse, + skip_validation: bool, + ) -> None: + if response_spec and response_spec.has_model(): + model = falcon_response.media + status = int(falcon_response.status[:3]) + expect_model = response_spec.find_model(status) + if response_spec.expect_list_result(status) and isinstance(model, list): + expected_list_item_type = response_spec.get_expected_list_item_type( + status + ) + if all(isinstance(entry, expected_list_item_type) for entry in model): + skip_validation = True + falcon_response.media = [ + (entry.dict() if isinstance(entry, BaseModel) else entry) + for entry in model + ] + elif expect_model and isinstance(falcon_response.media, expect_model): + falcon_response.media = model.dict() + skip_validation = True + if self._data_set_manually(falcon_response): + skip_validation = True + if expect_model and not skip_validation: + expect_model.parse_obj(falcon_response.media) + def validate( self, func: Callable, @@ -226,32 +255,16 @@ def validate( func(*args, **kwargs) - if resp and resp.has_model(): - model = _resp.media - status = int(_resp.status[:3]) - expect_model = resp.find_model(status) - if resp.expect_list_result(status) and isinstance(model, list): - expected_list_item_type = resp.get_expected_list_item_type(status) - if all(isinstance(entry, expected_list_item_type) for entry in model): - skip_validation = True - _resp.media = [ - (entry.dict() if isinstance(entry, BaseModel) else entry) - for entry in model - ] - elif expect_model and isinstance(_resp.media, expect_model): - _resp.media = model.dict() - skip_validation = True - - if self._data_set_manually(_resp): - skip_validation = True - - if expect_model and not skip_validation: - try: - expect_model.parse_obj(_resp.media) - except ValidationError as err: - resp_validation_error = err - _resp.status = HTTP_500 - _resp.media = err.errors() + try: + self.response_validation( + response_spec=resp, + falcon_response=_resp, + skip_validation=skip_validation, + ) + except ValidationError as err: + resp_validation_error = err + _resp.status = HTTP_500 + _resp.media = err.errors() after(_req, _resp, resp_validation_error, _self) @@ -338,22 +351,15 @@ async def validate( await func(*args, **kwargs) - if resp and resp.has_model(): - model = resp.find_model(_resp.status[:3]) - if model and isinstance(_resp.media, model): - _resp.media = _resp.media.dict() - skip_validation = True - - if self._data_set_manually(_resp): - skip_validation = True - - model = resp.find_model(_resp.status[:3]) - if model and not skip_validation: - try: - model.parse_obj(_resp.media) - except ValidationError as err: - resp_validation_error = err - _resp.status = HTTP_500 - _resp.media = err.errors() + try: + self.response_validation( + response_spec=resp, + falcon_response=_resp, + skip_validation=skip_validation, + ) + except ValidationError as err: + resp_validation_error = err + _resp.status = HTTP_500 + _resp.media = err.errors() after(_req, _resp, resp_validation_error, _self) diff --git a/spectree/plugins/quart_plugin.py b/spectree/plugins/quart_plugin.py index 0235d3e1..4894e974 100644 --- a/spectree/plugins/quart_plugin.py +++ b/spectree/plugins/quart_plugin.py @@ -234,7 +234,7 @@ async def validate( status, *rest, ) - if expect_model and isinstance(model, expect_model): + elif expect_model and isinstance(model, expect_model): skip_validation = True result = (model.dict(), status, *rest) diff --git a/tests/test_plugin_falcon_asgi.py b/tests/test_plugin_falcon_asgi.py index 7a51713a..57689db4 100644 --- a/tests/test_plugin_falcon_asgi.py +++ b/tests/test_plugin_falcon_asgi.py @@ -127,8 +127,9 @@ class ReturnListView: @api.validate(resp=Response(HTTP_200=List[JSON])) async def on_get(self, req, resp): + pre_serialize = bool(int(req.params.get("pre_serialize", 0))) data = [JSON(name="user1", limit=1), JSON(name="user2", limit=2)] - resp.media = [entry.dict() for entry in data] + resp.media = [entry.dict() if pre_serialize else entry for entry in data] class FileUploadView: @@ -200,8 +201,11 @@ def test_falcon_list_json_request_async(client): assert resp.status_code == 200 -def test_falcon_return_list_request_async(client): - resp = client.simulate_request("GET", "/api/return_list") +@pytest.mark.parametrize("pre_serialize", [False, True]) +def test_falcon_return_list_request_async(client, pre_serialize: bool): + resp = client.simulate_request( + "GET", f"/api/return_list?pre_serialize={int(pre_serialize)}" + ) assert resp.status_code == 200 assert resp.json == [ {"name": "user1", "limit": 1},