diff --git a/aiohttp_swagger/helpers/validation.py b/aiohttp_swagger/helpers/validation.py index 9b647d8..b2e5d17 100644 --- a/aiohttp_swagger/helpers/validation.py +++ b/aiohttp_swagger/helpers/validation.py @@ -21,13 +21,13 @@ Response, json_response, ) -from collections import defaultdict +from collections import defaultdict, MutableMapping from jsonschema import ( - validate, ValidationError, FormatChecker, + Draft4Validator, + validators, ) -from jsonschema.validators import validator_for __all__ = ( @@ -54,12 +54,35 @@ def multi_dict_to_dict(mld: Mapping) -> Mapping: } -def validate_schema(obj: Mapping, schema: Mapping): - validate(obj, schema, format_checker=FormatChecker()) +def extend_with_default(validator_class): + validate_properties = validator_class.VALIDATORS["properties"] -def validate_multi_dict(obj, schema): - validate(multi_dict_to_dict(obj), schema, format_checker=FormatChecker()) + def set_defaults(validator, properties, instance, schema): + if isinstance(instance, MutableMapping): + for prop, sub_schema in properties.items(): + if "default" in sub_schema: + instance.setdefault(prop, sub_schema["default"]) + for error in validate_properties( + validator, properties, instance, schema): + yield error + + return validators.extend(validator_class, {"properties": set_defaults}) + + +json_schema_validator = extend_with_default(Draft4Validator) + + +def validate_schema(obj: Mapping, schema: Mapping) -> Mapping: + json_schema_validator(schema, format_checker=FormatChecker()).validate(obj) + return obj + + +def validate_multi_dict(obj, schema) -> Mapping: + _obj = multi_dict_to_dict(obj) + json_schema_validator( + schema, format_checker=FormatChecker()).validate(_obj) + return _obj def validate_content_type(swagger: Mapping, content_type: str): @@ -73,34 +96,34 @@ async def validate_request( request: Request, parameter_groups: Mapping, swagger: Mapping): + res = {} validate_content_type(swagger, request.content_type) - for group_name, group_schemas in parameter_groups.items(): + for group_name, group_schema in parameter_groups.items(): if group_name == 'header': - headers = request.headers - for schema in group_schemas: - validate_multi_dict(headers, schema) + res['headers'] = validate_multi_dict(request.headers, group_schema) if group_name == 'query': - query = request.query - for schema in group_schemas: - validate_multi_dict(query, schema) + res['query'] = validate_multi_dict(request.query, group_schema) if group_name == 'formData': try: data = await request.post() except ValueError: data = None - for schema in group_schemas: - validate_multi_dict(data, schema) + res['formData'] = validate_multi_dict(data, group_schema) if group_name == 'body': - try: - content = await request.json() - except json.JSONDecodeError: - content = None - for schema in group_schemas: - validate_schema(content, schema) + if request.content_type == 'application/json': + try: + content = await request.json() + except json.JSONDecodeError: + content = None + elif request.content_type.startswith('text'): + content = await request.text() + else: + content = await request.read() + res['body'] = validate_schema(content, group_schema) if group_name == 'path': params = dict(request.match_info) - for schema in group_schemas: - validate_schema(params, schema) + res['path'] = validate_schema(params, group_schema) + return res def adjust_swagger_item_to_json_schemes(*schemes: Mapping) -> Mapping: @@ -124,7 +147,7 @@ def adjust_swagger_item_to_json_schemes(*schemes: Mapping) -> Mapping: required_fields.append(name) if required_fields: new_schema['required'] = required_fields - validator_for(new_schema).check_schema(new_schema) + validators.validator_for(new_schema).check_schema(new_schema) return new_schema @@ -139,21 +162,21 @@ def adjust_swagger_body_item_to_json_schema(schema: Mapping) -> Mapping: new_schema, ] } - validator_for(new_schema).check_schema(new_schema) + validators.validator_for(new_schema).check_schema(new_schema) return new_schema def adjust_swagger_to_json_schema(parameter_groups: Iterable) -> Mapping: - res = defaultdict(list) + res = {} for group_name, group_schemas in parameter_groups: if group_name in ('query', 'header', 'path', 'formData'): json_schema = adjust_swagger_item_to_json_schemes(*group_schemas) - res[group_name].append(json_schema) + res[group_name] = json_schema else: # only one possible schema for in: body schema = list(group_schemas)[0] json_schema = adjust_swagger_body_item_to_json_schema(schema) - res[group_name].append(json_schema) + res[group_name] = json_schema return res @@ -189,7 +212,9 @@ def get_ref(ref: str): res = {} for key, value in schema.items(): if key == '$ref': - res.update(get_ref(value)) + ref_data = get_ref(value) + ref_data_resolved = dereference_schema(swagger, ref_data) + res.update(ref_data_resolved) else: res[key] = dereference_schema(swagger, value) return res @@ -216,7 +241,9 @@ async def _wrapper(*args, **kwargs) -> Response: request = args[0].request \ if isinstance(args[0], web.View) else args[0] try: - await validate_request(request, parameter_groups, schema) + validation = \ + await validate_request(request, parameter_groups, schema) + request.validation = validation except ValidationError as exc: logger.exception(exc) exc_dict = validation_exc_to_dict(exc) diff --git a/doc/source/customizing.rst b/doc/source/customizing.rst index ae50cf9..7cb5249 100644 --- a/doc/source/customizing.rst +++ b/doc/source/customizing.rst @@ -212,6 +212,17 @@ Global Swagger YAML :samp:`aiohttp-swagger` also allow to validate swagger schema against json schema: +Validated object would be added as **request.validation**. Default values also will be filled into object. + +.. code-block:: javascript + + { + 'query': {}, // validated request.query + 'path': {}, // validated request.path + 'body': {}, // validated request.json() + 'formData': {}, // validated post request.data() + 'headers': {}, // validated post request.headers + } .. code-block:: python diff --git a/tests/test_validation_body.py b/tests/test_validation_body.py new file mode 100644 index 0000000..4030594 --- /dev/null +++ b/tests/test_validation_body.py @@ -0,0 +1,170 @@ +import asyncio +import json + +import pytest +from aiohttp import web +from aiohttp_swagger import * + + +@asyncio.coroutine +@swagger_validation +def post1(request, *args, **kwargs): + """ + --- + description: Post resources + tags: + - Function View + produces: + - application/json + consumes: + - application/json + parameters: + - in: body + name: body + required: true + schema: + type: object + properties: + test: + type: string + default: default + minLength: 2 + test1: + type: string + default: default1 + minLength: 2 + responses: + "200": + description: successful operation. + "405": + description: invalid HTTP Method + """ + return web.json_response(data=request.validation['body']) + + +@asyncio.coroutine +@swagger_validation +def post2(request, *args, **kwargs): + """ + --- + description: Post resources + tags: + - Function View + produces: + - text/plain + consumes: + - text/plain + parameters: + - in: body + name: body + required: true + schema: + type: string + default: default + minLength: 2 + responses: + "200": + description: successful operation. + "405": + description: invalid HTTP Method + """ + return web.Response(text=request.validation['body']) + + +POST1_METHOD_PARAMETERS = [ + # success + ( + 'post', + '/example12', + {'test': 'default'}, + {'Content-Type': 'application/json'}, + 200 + ), + # success + ( + 'post', + '/example12', + {}, + {'Content-Type': 'application/json'}, + 200 + ), + # error + ( + 'post', + '/example12', + None, + {'Content-Type': 'application/json'}, + 400 + ), +] + +POST2_METHOD_PARAMETERS = [ + # success + ( + 'post', + '/example12', + '1234', + {'Content-Type': 'text/plain'}, + 200 + ), + ( + 'post', + '/example12', + None, + {'Content-Type': 'text/plain'}, + 400 + ), +] + + +@pytest.mark.parametrize("method,url,body,headers,response", + POST1_METHOD_PARAMETERS) +@asyncio.coroutine +def test_function_post1_method_body_validation( + test_client, loop, swagger_file, method, url, body, headers, response): + app = web.Application(loop=loop) + app.router.add_post("/example12", post1) + setup_swagger( + app, + swagger_merge_with_file=True, + swagger_validate_schema=True, + swagger_from_file=swagger_file, + ) + client = yield from test_client(app) + data = json.dumps(body) \ + if headers['Content-Type'] == 'application/json' else body + resp = yield from getattr(client, method)(url, data=data, headers=headers) + text = yield from resp.json() + assert resp.status == response, text + if response != 200: + assert 'error' in text + else: + assert 'error' not in text + assert 'test' in text + assert text['test'] == 'default' + assert text['test1'] == 'default1' + + +@pytest.mark.parametrize("method,url,body,headers,response", + POST2_METHOD_PARAMETERS) +@asyncio.coroutine +def test_function_post2_method_body_validation( + test_client, loop, swagger_file, method, url, body, headers, response): + app = web.Application(loop=loop) + app.router.add_post("/example12", post2) + setup_swagger( + app, + swagger_merge_with_file=True, + swagger_validate_schema=True, + swagger_from_file=swagger_file, + ) + client = yield from test_client(app) + data = json.dumps(body) \ + if headers['Content-Type'] == 'application/json' else body + resp = yield from getattr(client, method)(url, data=data, headers=headers) + text = yield from resp.text() + assert resp.status == response, text + if response != 200: + assert 'error' in text + else: + assert isinstance(text, str) diff --git a/tests/test_validation_defaults.py b/tests/test_validation_defaults.py new file mode 100644 index 0000000..00659b6 --- /dev/null +++ b/tests/test_validation_defaults.py @@ -0,0 +1,83 @@ +import asyncio + +import pytest +from aiohttp import web +from aiohttp_swagger import * + + +@asyncio.coroutine +@swagger_validation +def post(request, *args, **kwargs): + """ + --- + description: Post User data + tags: + - Function View + produces: + - application/json + consumes: + - application/json + parameters: + - in: query + name: test + type: string + minLength: 3 + required: true + default: test + - in: query + name: test1 + type: string + minLength: 3 + required: true + default: test1 + responses: + "200": + description: successful operation. + "405": + description: invalid HTTP Method + """ + return web.json_response(data=request.validation) + + +METHOD_PARAMETERS = [ + # too short test + ( + 'post', + '/example2?test=1', + {'Content-Type': 'application/json'}, + 400 + ), + # without test + ( + 'post', + '/example2', + {'Content-Type': 'application/json'}, + 200 + ), +] + + +@pytest.mark.parametrize("method,url,headers,response", METHOD_PARAMETERS) +@asyncio.coroutine +def test_function_post_with_defaults( + test_client, loop, swagger_ref_file, + method, url, headers, response): + app = web.Application(loop=loop) + app.router.add_post("/example2", post) + setup_swagger( + app, + swagger_merge_with_file=True, + swagger_validate_schema=True, + swagger_from_file=swagger_ref_file, + ) + client = yield from test_client(app) + resp = yield from getattr(client, method)(url, headers=headers) + data = yield from resp.json() + assert resp.status == response, data + if response != 200: + assert 'error' in data + else: + assert 'error' not in data + # both default parameters + assert data['query']['test1'] == 'test1' + assert data['query']['test'] == 'test'