Skip to content

Commit

Permalink
Redo GraphQL enum fixing logic (#1138)
Browse files Browse the repository at this point in the history
rafalp authored Dec 15, 2023

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent 64f43cf commit c4cdefb
Showing 8 changed files with 1,327 additions and 476 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
# CHANGELOG

## 0.21 (2023-11-08)

- Deprecated `EnumType.bind_to_default_values` method. It will be removed in a future release.
- Added `repair_schema_default_enum_values` to public API.
- Removed `validate_schema_enum_values` and introduced `validate_schema_default_enum_values` in its place. This is a breaking change.


## 0.21 (2023-11-08)

- Added Python 3.12 to tested versions.
12 changes: 6 additions & 6 deletions ariadne/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .enums import (
EnumType,
set_default_enum_values_on_schema,
validate_schema_enum_values,
from .enums import EnumType
from .enums_default_values import (
repair_schema_default_enum_values,
validate_schema_default_enum_values,
)
from .executable_schema import make_executable_schema
from .extensions import ExtensionManager
@@ -70,12 +70,12 @@
"is_default_resolver",
"load_schema_from_path",
"make_executable_schema",
"repair_schema_default_enum_values",
"resolve_to",
"set_default_enum_values_on_schema",
"snake_case_fallback_resolvers",
"subscribe",
"type_implements_interface",
"unwrap_graphql_error",
"upload_scalar",
"validate_schema_enum_values",
"validate_schema_default_enum_values",
]
269 changes: 14 additions & 255 deletions ariadne/enums.py
Original file line number Diff line number Diff line change
@@ -2,49 +2,18 @@
from typing import (
Any,
Dict,
Generator,
List,
Mapping,
Optional,
Tuple,
Type,
TypeVar,
Union,
cast,
)
from functools import reduce, singledispatch
import operator
import warnings

from graphql.type import GraphQLEnumType, GraphQLNamedType, GraphQLSchema
from graphql.language.ast import (
EnumValueNode,
InputValueDefinitionNode,
ObjectValueNode,
)
from graphql.pyutils.undefined import Undefined
from graphql.type.definition import (
GraphQLArgument,
GraphQLField,
GraphQLInputField,
GraphQLInputObjectType,
GraphQLInputType,
GraphQLInterfaceType,
GraphQLNonNull,
GraphQLObjectType,
GraphQLScalarType,
)

from .types import SchemaBindable


T = TypeVar("T")
ArgumentWithKeys = Tuple[str, str, GraphQLArgument, Optional[List["str"]]]
InputFieldWithKeys = Tuple[str, str, GraphQLInputField, Optional[List["str"]]]
GraphQLNamedInputType = Union[
GraphQLScalarType, GraphQLEnumType, GraphQLInputObjectType
]


class EnumType(SchemaBindable):
"""Bindable mapping Python values to enumeration members in a GraphQL schema.
@@ -118,34 +87,27 @@ def bind_to_schema(self, schema: GraphQLSchema) -> None:
)
graphql_type.values[key].value = value

def bind_to_default_values(self, schema: GraphQLSchema) -> None:
def bind_to_default_values(self, _schema: GraphQLSchema) -> None:
"""Populates default values of input fields and args in the GraphQL schema.
This step is required because GraphQL query executor doesn't perform a
lookup for default values defined in schema. Instead it simply pulls the
value from fields and arguments `default_value` attribute, which is
`None` by default.
"""
for _, _, arg, key_list in find_enum_values_in_schema(schema):
type_ = resolve_null_type(arg.type)
type_ = cast(GraphQLNamedInputType, type_)

if (
key_list is None
and arg.default_value in self.values
and type_.name == self.name
):
type_ = resolve_null_type(arg.type)
arg.default_value = self.values[arg.default_value]
elif key_list is not None:
enum_value = get_value_from_mapping_value(arg.default_value, key_list)
type_ = cast(GraphQLEnumType, track_type_for_nested(arg, key_list))
> **Deprecated:** Ariadne versions before 0.22 used
`EnumType.bind_to_default_values` method to fix default enum values embedded
in the GraphQL schema. Ariadne 0.22 release introduces universal
`repair_schema_default_enum_values` utility in its place.
"""

if enum_value in self.values and type_.name == self.name:
set_leaf_value_in_mapping(
arg.default_value, key_list, self.values[enum_value]
)
warnings.warn(
(
"'EnumType.bind_to_default_values' was deprecated in Ariadne 0.22 and "
"will be removed in a future release."
),
DeprecationWarning,
)

def validate_graphql_type(self, graphql_type: Optional[GraphQLNamedType]) -> None:
"""Validates that schema's GraphQL type associated with this `EnumType`
@@ -157,206 +119,3 @@ def validate_graphql_type(self, graphql_type: Optional[GraphQLNamedType]) -> Non
"%s is defined in the schema, but it is instance of %s (expected %s)"
% (self.name, type(graphql_type).__name__, GraphQLEnumType.__name__)
)


def set_default_enum_values_on_schema(schema: GraphQLSchema):
"""Sets missing Python values for GraphQL enums in schema.
Recursively scans GraphQL schema for enums and their values. If `value`
attribute is empty, its populated with with a string of its GraphQL name.
This string is then used to represent enum's value in Python instead of `None`.
# Requires arguments
`schema`: a GraphQL schema to set enums default values in.
"""
for type_object in schema.type_map.values():
if isinstance(type_object, GraphQLEnumType):
set_default_enum_values(type_object)


def set_default_enum_values(graphql_type: GraphQLEnumType):
for key in graphql_type.values:
if graphql_type.values[key].value is None:
graphql_type.values[key].value = key


def validate_schema_enum_values(schema: GraphQLSchema) -> None:
"""Raises `ValueError` if GraphQL schema has input fields or arguments with
default values that are undefined enum values.
# Example schema with invalid field argument
This schema fails to validate because argument `role` on field `users`
specifies `REVIEWER` as default value and `REVIEWER` is not a member of
the `UserRole` enum:
```graphql
type Query {
users(role: UserRole = REVIEWER): [User!]!
}
enum UserRole {
MEMBER
MODERATOR
ADMIN
}
type User {
id: ID!
}
```
# Example schema with invalid input field
This schema fails to validate because field `role` on input `UserFilters`
specifies `REVIEWER` as default value and `REVIEWER` is not a member of
the `UserRole` enum:
```graphql
type Query {
users(filter: UserFilters): [User!]!
}
input UserFilters {
name: String
role: UserRole = REVIEWER
}
enum UserRole {
MEMBER
MODERATOR
ADMIN
}
type User {
id: ID!
}
```
"""

for type_name, field_name, arg, _ in find_enum_values_in_schema(schema):
if is_invalid_enum_value(arg):
raise ValueError(
f"Value for type: <{arg.type}> is invalid. "
f"Check InputField/Arguments for <{field_name}> in <{type_name}> "
"(Undefined enum value)."
)


def is_invalid_enum_value(field: Union[GraphQLInputField, GraphQLArgument]) -> bool:
if field.ast_node is None:
return False
return field.default_value is Undefined and field.ast_node.default_value is not None


def find_enum_values_in_schema(
schema: GraphQLSchema,
) -> Generator[Union[ArgumentWithKeys, InputFieldWithKeys], None, None]:
for name, type_ in schema.type_map.items():
result = enum_values_in_types(type_, name)
if result is not None:
yield from result


@singledispatch
def enum_values_in_types(
type_: GraphQLNamedType, # pylint: disable=unused-argument
name: str, # pylint: disable=unused-argument
) -> Optional[Generator[Union[ArgumentWithKeys, InputFieldWithKeys], None, None]]:
pass


@enum_values_in_types.register(GraphQLObjectType)
@enum_values_in_types.register(GraphQLInterfaceType)
def enum_values_in_object_type(
type_: Union[GraphQLObjectType, GraphQLInterfaceType],
field_name: str,
) -> Generator[ArgumentWithKeys, None, None]:
for field in type_.fields.values():
yield from enum_values_in_field_args(field_name, field)


@enum_values_in_types.register(GraphQLInputObjectType)
def enum_values_in_input_type(
type_: GraphQLInputObjectType,
field_name,
) -> Generator[InputFieldWithKeys, None, None]:
yield from _get_field_with_keys(field_name, type_.fields.items())


def enum_values_in_field_args(
field_name: str,
field: GraphQLField,
) -> Generator[ArgumentWithKeys, None, None]:
args = [
(name, arg)
for name, arg in field.args.items()
if isinstance(
arg.type, (GraphQLInputObjectType, GraphQLEnumType, GraphQLNonNull)
)
]

yield from _get_field_with_keys(field_name, args)


def _get_field_with_keys(field_name, fields):
for input_name, field in fields:
resolved_type = resolve_null_type(field.type)
if isinstance(resolved_type, GraphQLEnumType):
yield field_name, input_name, field, None

if isinstance(resolved_type, GraphQLInputObjectType):
if (
field.ast_node is not None
and field.ast_node.default_value is not None
and isinstance(field.ast_node.default_value, ObjectValueNode)
):
routes = get_enum_keys_from_ast(field.ast_node)
for route in routes:
yield field_name, input_name, field, route


def get_enum_keys_from_ast(ast_node: InputValueDefinitionNode) -> List[List["str"]]:
routes = []
object_node = cast(ObjectValueNode, ast_node.default_value)
nodes = [([field.name.value], field) for field in object_node.fields]

while nodes:
key_list, field = nodes.pop()
if isinstance(field.value, EnumValueNode):
routes.append(key_list)

if isinstance(field.value, ObjectValueNode):
for new_field in field.value.fields:
new_route = key_list[:]
new_route.append(new_field.name.value)
nodes.append((new_route, new_field))

return routes


def get_value_from_mapping_value(mapping: Mapping[T, Any], key_list: List[T]) -> Any:
return reduce(operator.getitem, key_list, mapping)


def set_leaf_value_in_mapping(
mapping: Mapping[T, Any], key_list: List[T], value: Any
) -> None:
get_value_from_mapping_value(mapping, key_list[:-1])[key_list[-1]] = value


def track_type_for_nested(
arg: Union[GraphQLArgument, GraphQLInputField], key_list: List[str]
) -> GraphQLInputType:
type_ = resolve_null_type(arg.type)

for elem in key_list:
if isinstance(type_, GraphQLInputObjectType):
type_ = type_.fields[elem].type
return type_


def resolve_null_type(type_: GraphQLInputType) -> GraphQLInputType:
return type_.of_type if isinstance(type_, GraphQLNonNull) else type_
144 changes: 144 additions & 0 deletions ariadne/enums_default_values.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
from graphql import GraphQLInputField, GraphQLSchema

from .enums_values_visitor import (
GraphQLASTEnumDefaultValueLocation,
GraphQLASTEnumsValuesVisitor,
GraphQLSchemaEnumDefaultValueLocation,
GraphQLSchemaEnumsValuesVisitor,
)


__all__ = [
"repair_schema_default_enum_values",
"validate_schema_default_enum_values",
]


def validate_schema_default_enum_values(schema: GraphQLSchema) -> None:
"""Raises `ValueError` if GraphQL schema has input fields or arguments with
default values that are undefined enum values.
# Example schema with invalid field argument
This schema fails to validate because argument `role` on field `users`
specifies `REVIEWER` as default value and `REVIEWER` is not a member of
the `UserRole` enum:
```graphql
type Query {
users(role: UserRole = REVIEWER): [User!]!
}
enum UserRole {
MEMBER
MODERATOR
ADMIN
}
type User {
id: ID!
}
```
# Example schema with invalid input field
This schema fails to validate because field `role` on input `UserFilters`
specifies `REVIEWER` as default value and `REVIEWER` is not a member of
the `UserRole` enum:
```graphql
type Query {
users(filter: UserFilters): [User!]!
}
input UserFilters {
name: String
role: UserRole = REVIEWER
}
enum UserRole {
MEMBER
MODERATOR
ADMIN
}
type User {
id: ID!
}
```
# Example schema with invalid default input field argument
This schema fails to validate because field `field` on input `ChildInput`
specifies `INVALID` as default value and `INVALID` is not a member of
the `Role` enum:
```graphql
type Query {
field(arg: Input = {field: {field: INVALID}}): String
}
input Input {
field: ChildInput
}
input ChildInput {
field: Role
}
enum Role {
USER
ADMIN
}
```
"""
GraphQLEnumsValuesValidatorVisitor(schema)


class GraphQLEnumsValuesValidatorVisitor(GraphQLASTEnumsValuesVisitor):
def visit_ast_enum_default_value(
self, location: "GraphQLASTEnumDefaultValueLocation"
):
valid_values = self.enum_values[location.enum_name]
if location.enum_value not in valid_values:
if location.arg_name:
raise ValueError(
f"Undefined enum value '{location.enum_value}' for enum "
f"'{location.enum_name}' in a default value of "
f"'{location.arg_name}' argument for '{location.field_name}' "
f"field on '{location.object_name}' type."
)

raise ValueError(
f"Undefined enum value '{location.enum_value}' for enum "
f"'{location.enum_name}' in a default value of "
f"'{location.field_name}' field on '{location.object_name}' type."
)


def repair_schema_default_enum_values(schema: GraphQLSchema) -> None:
"""Repairs Python values of default enums embedded in the GraphQL schema.
Default enum values in the GraphQL schemas are represented as strings with enum
member names in Python. Assigning custom Python values to members of the
`GraphQLEnumType` doesn't change those defaults.
This function walks the GraphQL schema, finds default enum values strings and,
if this string is a valid GraphQL member name, swaps it out for a valid Python
value.
"""
GraphQLSchemaEnumsValuesRepairVisitor(schema)


class GraphQLSchemaEnumsValuesRepairVisitor(GraphQLSchemaEnumsValuesVisitor):
def visit_schema_enum_default_value(
self, location: "GraphQLSchemaEnumDefaultValueLocation"
):
valid_values = self.enum_values[location.enum_name]
valid_default = valid_values[location.enum_value]
if location.default_value_path is not None:
location.default_value[location.default_value_path] = valid_default
elif location.arg_def:
location.arg_def.default_value = valid_default
elif isinstance(location.field_def, GraphQLInputField):
location.field_def.default_value = valid_default
561 changes: 561 additions & 0 deletions ariadne/enums_values_visitor.py

Large diffs are not rendered by default.

22 changes: 6 additions & 16 deletions ariadne/executable_schema.py
Original file line number Diff line number Diff line change
@@ -8,10 +8,10 @@
parse,
)

from .enums import (
EnumType,
set_default_enum_values_on_schema,
validate_schema_enum_values,
from .enums import EnumType
from .enums_default_values import (
repair_schema_default_enum_values,
validate_schema_default_enum_values,
)
from .schema_names import SchemaNameConverter, convert_schema_names
from .schema_visitor import SchemaDirectiveVisitor
@@ -345,14 +345,12 @@ def uppercase_resolved_value(*args, **kwargs):
if isinstance(bindable, SchemaBindable):
bindable.bind_to_schema(schema)

set_default_enum_values_on_schema(schema)

if directives:
SchemaDirectiveVisitor.visit_schema_directives(schema, directives)

assert_valid_schema(schema)
validate_schema_enum_values(schema)
repair_default_enum_values(schema, normalized_bindables)
validate_schema_default_enum_values(schema)
repair_schema_default_enum_values(schema)

if convert_names_case:
convert_schema_names(
@@ -394,11 +392,3 @@ def flatten_bindables(
new_bindables.append(bindable)

return new_bindables


def repair_default_enum_values(
schema: GraphQLSchema, bindables: List[SchemaBindable]
) -> None:
for bindable in bindables:
if isinstance(bindable, EnumType):
bindable.bind_to_default_values(schema)
200 changes: 1 addition & 199 deletions tests/test_enums.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,9 @@
import re
from enum import Enum, IntEnum

import pytest
from graphql import graphql_sync, build_schema, parse
from graphql.pyutils.undefined import Undefined
from graphql.utilities.build_ast_schema import build_ast_schema
from graphql import graphql_sync, build_schema

from ariadne import EnumType, QueryType, make_executable_schema
from ariadne import graphql_sync as ariadne_graphql_sync
from ariadne.executable_schema import join_type_defs
from ariadne.enums import find_enum_values_in_schema

enum_definition = """
enum Episode {
@@ -299,183 +293,6 @@ def resolve_test_enum(*_, i):
assert result.data["complex"]


def test_input_exc_schema_raises_exception_for_undefined_enum_value_in_flat_input():
input_schema = """
type Query {
complex(i: Test = { role: EMPIRE }): String
}
input Test {
ignore: String
role: Episode = TWO_TOWERS
}
"""
with pytest.raises(
ValueError,
match=re.escape(
"Value for type: <Episode> is invalid. "
"Check InputField/Arguments for <role> in <Test> (Undefined enum value)."
),
):
make_executable_schema([enum_definition, input_schema])


def test_input_exc_schema_raises_exception_for_undefined_enum_value_in_nested_object():
input_schema = """
type Query {
complex(i: Test = { role: EMPIRE }): String
}
input Test {
ignore: String
role: Episode = EMPIRE
}
input BetterTest {
newIgnore: String
test: Test = { role: ANDRZEJU }
}
"""

with pytest.raises(
ValueError,
match=re.escape(
"Value for type: <Test> is invalid. "
"Check InputField/Arguments for <test> in <BetterTest> (Undefined enum value)."
),
):
make_executable_schema([enum_definition, input_schema])


def test_input_exc_schema_raises_exception_for_undefined_enum_value_in_nested_field_arg():
input_schema = """
type Query {
complex(i: BetterTest = { test: { role: TWO_TOWERS } }): String
}
input Test {
ignore: String
role: Episode = EMPIRE
}
input BetterTest {
newIgnore: String
test: Test = { role: NEW_HOPE }
}
"""

with pytest.raises(
ValueError,
match=re.escape(
"Value for type: <BetterTest> is invalid. "
"Check InputField/Arguments for <i> in <Query> (Undefined enum value)."
),
):
make_executable_schema([enum_definition, input_schema])


def test_find_enum_values_in_schema_for_undefined_and_invalid_values():
input_schema = """
type Query {
complex(
hello: Episode = EMPIRE,
hi: Test = { role: JEDI, next_role: ENTERPRISE, ignore: "HI"},
bonjour: BetterTest = {newIgnore: "Witam", test: { role: NEWHOPE }}
): String
}
input Test {
ignore: String
role: Episode = EMPIRE
next_role: Episode
}
input BetterTest {
newIgnore: String
test: Test = { role: NEWHOPE }
}
interface Result {
hello(r: Episode = ENTERPRISE): String
}
"""
query_complex_keys = [["next_role"], ["role"], ["test", "role"]]
better_test_complex_keys = [["role"]]
number_of_defined_enum_values = 8

# 2 Undefined because of "ENTERPRISE" invalid value, and next_role in Test input
number_of_undefined_default_enum_values = 4

ast_document = parse(join_type_defs([enum_definition, input_schema]))
schema = build_ast_schema(ast_document)
enums_entities = list(find_enum_values_in_schema(schema))
keys_to_complex_inputs = [keys for *_, keys in enums_entities if keys is not None]

undefined = [
(*_, args, keys)
for *_, args, keys in enums_entities
if args.default_value is Undefined
]

assert keys_to_complex_inputs == query_complex_keys + better_test_complex_keys
assert len(enums_entities) == number_of_defined_enum_values
assert len(undefined) == number_of_undefined_default_enum_values


def test_enum_type_is_able_to_represent_enum_default_value_in_schema():
# regression test for: https://github.com/mirumee/ariadne/issues/293
# regression test for: https://github.com/mirumee/ariadne/issues/995

type_defs = """
enum Role {
ADMIN
USER
}
type Query {
hello(r: Role = USER): String
results: [Result!]!
}
interface Result {
hello(r: Role = USER): String
}
type User implements Result {
hello(r: Role = USER): String
}
"""

class Role(Enum):
ADMIN = "admin"
USER = "user"

def resolve_test_enum(*_, r):
return r == Role.USER

RoleGraphQLType = EnumType("Role", Role)
QueryGraphQLType = QueryType()

QueryGraphQLType.set_field("hello", resolve_test_enum)

schema = make_executable_schema(
type_defs,
QueryGraphQLType,
RoleGraphQLType,
)

query = "{__schema{types{name,fields{name,args{name,defaultValue}}}}}"
_, result = ariadne_graphql_sync(schema, {"query": query}, debug=True)
assert not result.get("errors")

types_map = {
result_type["name"]: result_type
for result_type in result["data"]["__schema"]["types"]
}
assert schema.type_map["Query"].fields["hello"].args["r"].default_value == Role.USER

result_hello_query = graphql_sync(schema, "{hello}")
assert types_map["Query"]["fields"][0]["args"][0]["defaultValue"] == "USER"
assert types_map["User"]["fields"][0]["args"][0]["defaultValue"] == "USER"
assert result_hello_query.data["hello"]
assert result_hello_query.errors is None


def test_python_enums_can_be_passed_directly_to_make_executable_schema():
class Episode(Enum):
NEWHOPE = "new-hope"
@@ -515,18 +332,3 @@ class UnknownEnum(str, Enum):

with pytest.raises(ValueError):
make_executable_schema([enum_definition, enum_field], query, UnknownEnum)


def test_schema_enum_values_fixer_handles_null_input_default():
# regression test for: https://github.com/mirumee/ariadne/issues/1074
make_executable_schema(
"""
input SearchInput {
name: String!
}
type Query {
search(filters: SearchInput = null): String
}
"""
)
588 changes: 588 additions & 0 deletions tests/test_graphql_enum_fixes.py

Large diffs are not rendered by default.

0 comments on commit c4cdefb

Please sign in to comment.