diff --git a/src/flask_container_scaffold/util.py b/src/flask_container_scaffold/util.py index dcf7326..6322673 100644 --- a/src/flask_container_scaffold/util.py +++ b/src/flask_container_scaffold/util.py @@ -1,5 +1,4 @@ import configparser -import json from flask import request from pydantic import ValidationError @@ -68,20 +67,58 @@ def parse_input(logger, obj, default_return=BaseApiView): :returns: Instantiated object of type obj on success, or default_return on failure to parse. """ + args_dict = _preprocess_request() try: - if request.is_json: - parsed_args = obj.model_validate_json(json.dumps(request.json)) - else: - if request.args: - args = request.args - else: - args = request.form - parsed_args = obj.model_validate_json(json.dumps(args.to_dict())) + # Unpack the dict into keyword arguments + parsed_args = obj(**args_dict) except ValidationError as e: logger.error(f"Validation error is: {e}") errors_result = {} - errors_message = f"Errors detected: {e.error_count()}" + errors_message = f"Errors detected: {len(e.errors())}" for error in e.errors(): errors_result[error.get("loc")[0]] = error.get("msg") parsed_args = default_return(msg=errors_message, errors=errors_result) return parsed_args + + +def _preprocess_request() -> dict: + """ + Checks the various places in a request that could contain parameters, and + extracts them into a dictionary that can then be used for further parsing. + This dictionary should contain no duplicates, and chooses what to use based + on the following rules: + + 1. If a request contains both parameters embedded in the url (like + endpoint/1) AND + * json data, they will be preferred in this order: + - url-embedded + - json data + * query string and/or form data, they will be preferred in this order: + - url-embedded + - query string + - form data + 2. If a request contains both json data AND either query strings or form + data, only the json will be parsed. However, Flask currently prevents this + from happening, as it will not allow a user to pass both types of data at + the same time. The `curl` command is similarly mutually exclusive. + + :returns: dict containing the parsed parameters + """ + processed_request = {} + # Check for URL Parameters + if request.view_args: + processed_request = request.view_args + # If the request has json input, parse that and combine with URL + # parameters, preferring the latter. + if request.is_json: + processed_request = {**request.json, **processed_request} + else: + # Check in query-string for additional parameters. Prefer any that were + # previously found in URL parameters. + if request.args: + processed_request = {**request.args.to_dict(), **processed_request} + # Check in form for additional parameters. Prefer any that were previously + # found in URL parameters or query-string. + if request.form: + processed_request = {**request.form.to_dict(), **processed_request} + return processed_request diff --git a/tests/conftest.py b/tests/conftest.py index 76172ad..2fd0e2d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,6 +7,11 @@ @pytest.fixture def app(): app = Flask("testapp") + + @app.route('/endpoint/') + def some_endpoint(fake_id): + pass + yield app diff --git a/tests/unit/test_util.py b/tests/unit/test_util.py index 1baf925..00e5bd4 100644 --- a/tests/unit/test_util.py +++ b/tests/unit/test_util.py @@ -33,6 +33,7 @@ class FakeApiModelExtension(BaseApiView): class FakeModel(FakeApiModelExtension): + fake_id: int = 1 code: int = 0 name: str @@ -43,6 +44,11 @@ class FakeModel2(FakeApiModelExtension): status: str +class ComplexInput(FakeModel): + field1: str + field2: str + + class TestParseInput: def test_no_data(self, app): @@ -78,22 +84,44 @@ def test_no_data_custom_return(self, to_parse, required_attrs, app): assert retval.errors.get(missing_attr) == 'Field required' assert isinstance(retval, BaseApiView) - @pytest.mark.parametrize("input_type,input_val", - [('json', {'name': 'foo'}), - ('qs', 'name=foo'), - ('form', {'name': 'foo'})]) - def test_parses_json(self, input_type, input_val, app): + @pytest.mark.parametrize("input_val", + [{'name': 'foo'}, + {'name': 'foo', 'fake_id': 5}]) + def test_parses_url_params_json(self, input_val, app): """ - GIVEN a request with json, a query string or form data + GIVEN a request with a url parameter (such as endpoint/) WHEN we call parse_input on that request - THEN we get a populated object returned, of the type requested. + THEN we get a populated object returned with properly set + AND any json data appropriately parsed """ - context = {'json': app.test_request_context(json=input_val), - 'qs': app.test_request_context(query_string=input_val), - 'form': app.test_request_context(data=input_val)} - with context.get(input_type): + with app.test_request_context('endpoint/2', json=input_val): retval = parse_input(app.logger, FakeModel) + assert retval.fake_id == 2 assert retval.code == 0 assert retval.errors == {} assert retval.name == 'foo' assert isinstance(retval, FakeModel) + + @pytest.mark.parametrize("input_qs,input_form", + [('field1=foo&fake_id=8', + {'field2': 'foo', 'name': 'bob'}), + ('field1=foo&name=bob', + {'field2': 'foo', 'name': 'tim'}), + ('field1=foo&name=bob&field2=foo', + {})]) + def test_parses_url_params_non_json(self, input_qs, input_form, app): + """" + GIVEN a request with a url parameter (such as endpoint/) + WHEN we call parse_input on that request + THEN we get a populated object returned with properly set + AND any query strings or forms appropriately parsed + """ + + with app.test_request_context('endpoint/2', + query_string=input_qs, data=input_form): + retval = parse_input(app.logger, ComplexInput) + assert retval.fake_id == 2 + assert retval.code == 0 + assert retval.errors == {} + assert retval.name == 'bob' + assert isinstance(retval, ComplexInput)