-
Notifications
You must be signed in to change notification settings - Fork 54
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
97 additions
and
197 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,216 +1,116 @@ | ||
import json | ||
|
||
import marshmallow as ma | ||
from flask import Flask | ||
from marshmallow import Schema | ||
from webargs import fields | ||
from webargs.flaskparser import use_args | ||
from flask import jsonify | ||
from pytest import raises | ||
|
||
from flask_classful import FlaskView | ||
from flask_classful import route | ||
|
||
# we'll make a list to hold some quotes for our app | ||
quotes = [ | ||
"A noble spirit embiggens the smallest man! ~ Jebediah Springfield", | ||
"If there is a way to do it better... find it. ~ Thomas Edison", | ||
"No one knows what he can do till he tries. ~ Publilius Syrus", | ||
] | ||
|
||
app = Flask(__name__) | ||
app.config["DEBUG"] = True | ||
|
||
put_args = {"text": fields.Str(required=True)} | ||
|
||
|
||
class UserSchema(Schema): | ||
email = ma.fields.Str() | ||
|
||
class Meta: | ||
strict = True | ||
|
||
|
||
def make_user_schema(request): | ||
# Filter based on 'fields' query parameter | ||
only = request.args.get("fields", None) | ||
# Respect partial updates for PATCH requests | ||
partial = request.method == "PATCH" | ||
# Add current request to the schema's context | ||
return UserSchema(only=only, partial=partial, context={"request": request}) | ||
|
||
|
||
class UsersView(FlaskView): | ||
base_args = ["args"] | ||
|
||
@use_args(make_user_schema) | ||
def post(self, args): | ||
return args["email"] | ||
|
||
@use_args(make_user_schema) | ||
def put(self, args, id): | ||
return args["email"] | ||
|
||
@use_args(make_user_schema) | ||
def patch(self, args, id): | ||
return args["email"] | ||
|
||
|
||
class QuoteSchema(ma.Schema): | ||
id = ma.fields.Int() | ||
text = ma.fields.Str() | ||
|
||
class Meta: | ||
strict = True | ||
|
||
|
||
def make_quote_schema(request): | ||
# Filter based on 'fields' query parameter | ||
only = request.args.get("fields", None) | ||
# Respect partial updates for PATCH requests | ||
partial = request.method == "PATCH" | ||
# Add current request to the schema's context | ||
return QuoteSchema(only=only, partial=partial, context={"request": request}) | ||
|
||
|
||
class QuotesView(FlaskView): | ||
base_args = ["args"] | ||
|
||
def index(self): | ||
return "<br>".join(quotes) | ||
|
||
def get(self, id): | ||
quote_id = int(id) | ||
if quote_id < len(quotes) - 1: | ||
return quotes[quote_id] | ||
else: | ||
return "Not Found", 404 | ||
|
||
@use_args(put_args) | ||
def put(self, args, id): | ||
quote_id = int(id) | ||
if quote_id >= len(quotes) - 1: | ||
return "Not Found", 404 | ||
quotes[quote_id] = args["text"] | ||
return quotes[quote_id] | ||
|
||
@route("<id>/", methods=["PATCH"]) | ||
@use_args(make_quote_schema) | ||
def factory(self, args, id): | ||
quote_id = int(id) | ||
if quote_id >= len(quotes) - 1: | ||
return "Not Found", 404 | ||
quotes[quote_id] = args["text"] | ||
return quotes[quote_id] | ||
|
||
|
||
class UglyNameView(FlaskView): | ||
base_args = ["args"] | ||
route_base = "quotes-2" | ||
|
||
def index(self): | ||
return "<br>".join(quotes) | ||
|
||
def get(self, id): | ||
quote_id = int(id) | ||
if quote_id < len(quotes) - 1: | ||
return quotes[quote_id] | ||
else: | ||
return "Not Found", 404 | ||
|
||
@use_args(put_args) | ||
def put(self, args, id): | ||
quote_id = int(id) | ||
if quote_id >= len(quotes) - 1: | ||
return "Not Found", 404 | ||
quotes[quote_id] = args["text"] | ||
return quotes[quote_id] | ||
|
||
|
||
QuotesView.register(app) | ||
UglyNameView.register(app) | ||
UsersView.register(app) | ||
|
||
client = app.test_client() | ||
|
||
input_headers = [("Content-Type", "application/json")] | ||
input_data = {"text": "My quote"} | ||
|
||
|
||
def test_users_post(): | ||
resp = client.post( | ||
"users/", headers=input_headers, data=json.dumps({"email": "[email protected]"}) | ||
) | ||
class NoRouteBaseArgsView(FlaskView): | ||
route_base = "/route/without/args" | ||
|
||
def get(self, arg_1): | ||
return ( | ||
jsonify( | ||
{ | ||
"arg_1": arg_1, | ||
} | ||
), | ||
200, | ||
) | ||
|
||
|
||
class MultiRouteBaseArgsView(FlaskView): | ||
route_base = "/route/<arg_1>/with/<arg_2>/some_args" | ||
|
||
def get(self, arg_1, arg_2, arg_3): | ||
return ( | ||
jsonify( | ||
{ | ||
"arg_1": arg_1, | ||
"arg_2": arg_2, | ||
"arg_3": arg_3, | ||
} | ||
), | ||
200, | ||
) | ||
|
||
|
||
class OtherRouteBaseArgsView(FlaskView): | ||
route_base = "/route/<arg_1>/other" | ||
|
||
def get(self, arg_1, arg_2): | ||
return ( | ||
jsonify( | ||
{ | ||
"arg_1": arg_1, | ||
"arg_2": arg_2, | ||
} | ||
), | ||
200, | ||
) | ||
|
||
|
||
class ErroneousRouteBaseArgsView(FlaskView): | ||
route_base = "/route/<arg_1>/error" | ||
|
||
def get(self, arg_2): | ||
return ( | ||
jsonify( | ||
{ | ||
"arg_2": arg_2, | ||
} | ||
), | ||
200, | ||
) | ||
|
||
|
||
NoRouteBaseArgsView.register(app) | ||
MultiRouteBaseArgsView.register(app) | ||
OtherRouteBaseArgsView.register(app) | ||
ErroneousRouteBaseArgsView.register(app) | ||
|
||
|
||
def test_no_route_args(): | ||
_, base_args = NoRouteBaseArgsView.get_route_base() | ||
# No route base with args == no base args | ||
assert base_args == set() | ||
client = app.test_client() | ||
resp = client.get("/route/without/args/foo/") | ||
assert resp.status_code == 200 | ||
assert "[email protected]" == resp.data.decode("ascii") | ||
assert resp.json == {"arg_1": "foo"} | ||
|
||
|
||
def test_users_put(): | ||
resp = client.put( | ||
"users/1/", | ||
headers=input_headers, | ||
data=json.dumps({"email": "[email protected]"}), | ||
) | ||
assert resp.status_code == 200 | ||
assert "[email protected]" == resp.data.decode("ascii") | ||
def test_route_args_are_detected(): | ||
_, base_args = MultiRouteBaseArgsView.get_route_base() | ||
assert base_args == {"arg_1", "arg_2"} | ||
|
||
|
||
def test_users_patch(): | ||
resp = client.patch( | ||
"users/1/", | ||
headers=input_headers, | ||
data=json.dumps({"email": "[email protected]"}), | ||
) | ||
def test_multi_route_args_values(): | ||
client = app.test_client() | ||
resp = client.get("/route/foo/with/bar/some_args/baz/") | ||
assert resp.status_code == 200 | ||
assert "[email protected]" == resp.data.decode("ascii") | ||
|
||
assert resp.json == {"arg_1": "foo", "arg_2": "bar", "arg_3": "baz"} | ||
|
||
def test_quotes_index(): | ||
resp = client.get("/quotes/") | ||
num = len(str(resp.data).split("<br>")) | ||
assert 3 == num | ||
resp = client.get("/quotes") | ||
assert resp.status_code == 308 | ||
|
||
def test_route_args_are_independent_across_views(): | ||
_, base_args = OtherRouteBaseArgsView.get_route_base() | ||
# arg_2 does not leak from evaluating the previous view | ||
assert base_args == {"arg_1"} | ||
|
||
def test_quotes_get(): | ||
resp = client.get("/quotes/0/") | ||
assert quotes[0] == resp.data.decode("ascii") | ||
|
||
|
||
def test_quotes_put(): | ||
resp = client.put("/quotes/1/", headers=input_headers, data=json.dumps(input_data)) | ||
assert input_data["text"] == resp.data.decode("ascii") | ||
|
||
|
||
def test_quotes_factory(): | ||
resp = client.patch( | ||
"/quotes/1/", headers=input_headers, data=json.dumps(input_data) | ||
) | ||
assert input_data["text"] == resp.data.decode("ascii") | ||
|
||
|
||
def test_quotes2_index(): | ||
resp = client.get("/quotes-2/") | ||
num = len(str(resp.data).split("<br>")) | ||
assert 3 == num | ||
resp = client.get("/quotes-2") | ||
assert resp.status_code == 308 | ||
|
||
|
||
def test_quotes2_get(): | ||
resp = client.get("/quotes-2/0/") | ||
assert quotes[0] == resp.data.decode("ascii") | ||
assert UglyNameView.base_args.count(UglyNameView.route_base) == 0 | ||
|
||
|
||
def test_quotes2_put(): | ||
resp = client.put( | ||
"/quotes-2/1/", headers=input_headers, data=json.dumps(input_data) | ||
def test_missing_base_arg_in_method(): | ||
_, base_args = ErroneousRouteBaseArgsView.get_route_base() | ||
# Base arg is recognized | ||
assert base_args == {"arg_1"} | ||
# Rule is correctly generated | ||
assert ( | ||
ErroneousRouteBaseArgsView.build_rule("/", ErroneousRouteBaseArgsView.get) | ||
== ErroneousRouteBaseArgsView.route_base + "/<arg_2>" | ||
) | ||
assert input_data["text"] == resp.data.decode("ascii") | ||
assert UglyNameView.base_args.count(UglyNameView.route_base) == 0 | ||
|
||
|
||
# see: https://github.com/pallets-eco/flask-classful/pull/56#issuecomment-328985183 | ||
def test_unique_elements(): | ||
client.put("/quotes-2/1/", headers=input_headers, data=json.dumps(input_data)) | ||
assert UglyNameView.base_args.count(UglyNameView.route_base) == 0 | ||
# But calling the method fails because ErroneousRouteBaseArgsView.get is supplied with an unexpected "arg_1" argument | ||
client = app.test_client() | ||
with raises(TypeError): | ||
client.get("/route/foo/error/baz/") |