Skip to content

Commit

Permalink
introduce FieldGroup type
Browse files Browse the repository at this point in the history
  • Loading branch information
Cito committed Apr 7, 2024
1 parent e8559b0 commit f4d5501
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 28 deletions.
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@
AwaitableOrValue
EnterLeaveVisitor
ExperimentalIncrementalExecutionResults
FieldGroup
FormattedSourceLocation
GraphQLAbstractType
GraphQLCompositeType
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ exclude_lines = [
"pragma: no cover",
"except ImportError:",
"# Python <",
'sys\.version_info <',
"raise NotImplementedError",
"assert False,",
'\s+next\($',
Expand Down
22 changes: 17 additions & 5 deletions src/graphql/execution/collect_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

from __future__ import annotations

import sys
from collections import defaultdict
from typing import Any, NamedTuple
from typing import Any, List, NamedTuple

from ..language import (
FieldNode,
Expand All @@ -25,20 +26,31 @@
from ..utilities.type_from_ast import type_from_ast
from .values import get_directive_values

__all__ = ["collect_fields", "collect_subfields", "FieldsAndPatches"]
try:
from typing import TypeAlias
except ImportError: # Python < 3.10
from typing_extensions import TypeAlias


__all__ = ["collect_fields", "collect_subfields", "FieldGroup", "FieldsAndPatches"]

if sys.version_info < (3, 9):
FieldGroup: TypeAlias = List[FieldNode]
else: # Python >= 3.9
FieldGroup: TypeAlias = list[FieldNode]


class PatchFields(NamedTuple):
"""Optionally labelled set of fields to be used as a patch."""

label: str | None
fields: dict[str, list[FieldNode]]
fields: dict[str, FieldGroup]


class FieldsAndPatches(NamedTuple):
"""Tuple of collected fields and patches to be applied."""

fields: dict[str, list[FieldNode]]
fields: dict[str, FieldGroup]
patches: list[PatchFields]


Expand Down Expand Up @@ -81,7 +93,7 @@ def collect_subfields(
variable_values: dict[str, Any],
operation: OperationDefinitionNode,
return_type: GraphQLObjectType,
field_nodes: list[FieldNode],
field_nodes: FieldGroup,
) -> FieldsAndPatches:
"""Collect subfields.
Expand Down
50 changes: 27 additions & 23 deletions src/graphql/execution/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
from ..error import GraphQLError, GraphQLFormattedError, located_error
from ..language import (
DocumentNode,
FieldNode,
FragmentDefinitionNode,
OperationDefinitionNode,
OperationType,
Expand Down Expand Up @@ -75,7 +74,12 @@
is_object_type,
)
from .async_iterables import map_async_iterable
from .collect_fields import FieldsAndPatches, collect_fields, collect_subfields
from .collect_fields import (
FieldGroup,
FieldsAndPatches,
collect_fields,
collect_subfields,
)
from .middleware import MiddlewareManager
from .values import get_argument_values, get_directive_values, get_variable_values

Expand Down Expand Up @@ -837,7 +841,7 @@ def execute_fields_serially(
parent_type: GraphQLObjectType,
source_value: Any,
path: Path | None,
fields: dict[str, list[FieldNode]],
fields: dict[str, FieldGroup],
) -> AwaitableOrValue[dict[str, Any]]:
"""Execute the given fields serially.
Expand All @@ -847,7 +851,7 @@ def execute_fields_serially(
is_awaitable = self.is_awaitable

def reducer(
results: dict[str, Any], field_item: tuple[str, list[FieldNode]]
results: dict[str, Any], field_item: tuple[str, FieldGroup]
) -> AwaitableOrValue[dict[str, Any]]:
response_name, field_nodes = field_item
field_path = Path(path, response_name, parent_type.name)
Expand Down Expand Up @@ -877,7 +881,7 @@ def execute_fields(
parent_type: GraphQLObjectType,
source_value: Any,
path: Path | None,
fields: dict[str, list[FieldNode]],
fields: dict[str, FieldGroup],
async_payload_record: AsyncPayloadRecord | None = None,
) -> AwaitableOrValue[dict[str, Any]]:
"""Execute the given fields concurrently.
Expand Down Expand Up @@ -927,7 +931,7 @@ def execute_field(
self,
parent_type: GraphQLObjectType,
source: Any,
field_nodes: list[FieldNode],
field_nodes: FieldGroup,
path: Path,
async_payload_record: AsyncPayloadRecord | None = None,
) -> AwaitableOrValue[Any]:
Expand Down Expand Up @@ -996,7 +1000,7 @@ async def await_completed() -> Any:
def build_resolve_info(
self,
field_def: GraphQLField,
field_nodes: list[FieldNode],
field_nodes: FieldGroup,
parent_type: GraphQLObjectType,
path: Path,
) -> GraphQLResolveInfo:
Expand Down Expand Up @@ -1024,7 +1028,7 @@ def build_resolve_info(
def complete_value(
self,
return_type: GraphQLOutputType,
field_nodes: list[FieldNode],
field_nodes: FieldGroup,
info: GraphQLResolveInfo,
path: Path,
result: Any,
Expand Down Expand Up @@ -1113,7 +1117,7 @@ def complete_value(
async def complete_awaitable_value(
self,
return_type: GraphQLOutputType,
field_nodes: list[FieldNode],
field_nodes: FieldGroup,
info: GraphQLResolveInfo,
path: Path,
result: Any,
Expand Down Expand Up @@ -1143,7 +1147,7 @@ async def complete_awaitable_value(
return completed

def get_stream_values(
self, field_nodes: list[FieldNode], path: Path
self, field_nodes: FieldGroup, path: Path
) -> StreamArguments | None:
"""Get stream values.
Expand Down Expand Up @@ -1182,7 +1186,7 @@ def get_stream_values(
async def complete_async_iterator_value(
self,
item_type: GraphQLOutputType,
field_nodes: list[FieldNode],
field_nodes: FieldGroup,
info: GraphQLResolveInfo,
path: Path,
iterator: AsyncIterator[Any],
Expand Down Expand Up @@ -1269,7 +1273,7 @@ async def complete_async_iterator_value(
def complete_list_value(
self,
return_type: GraphQLList[GraphQLOutputType],
field_nodes: list[FieldNode],
field_nodes: FieldGroup,
info: GraphQLResolveInfo,
path: Path,
result: AsyncIterable[Any] | Iterable[Any],
Expand Down Expand Up @@ -1367,7 +1371,7 @@ def complete_list_item_value(
complete_results: list[Any],
errors: list[GraphQLError],
item_type: GraphQLOutputType,
field_nodes: list[FieldNode],
field_nodes: FieldGroup,
info: GraphQLResolveInfo,
item_path: Path,
async_payload_record: AsyncPayloadRecord | None,
Expand Down Expand Up @@ -1442,7 +1446,7 @@ def complete_leaf_value(return_type: GraphQLLeafType, result: Any) -> Any:
def complete_abstract_value(
self,
return_type: GraphQLAbstractType,
field_nodes: list[FieldNode],
field_nodes: FieldGroup,
info: GraphQLResolveInfo,
path: Path,
result: Any,
Expand Down Expand Up @@ -1496,7 +1500,7 @@ def ensure_valid_runtime_type(
self,
runtime_type_name: Any,
return_type: GraphQLAbstractType,
field_nodes: list[FieldNode],
field_nodes: FieldGroup,
info: GraphQLResolveInfo,
result: Any,
) -> GraphQLObjectType:
Expand Down Expand Up @@ -1557,7 +1561,7 @@ def ensure_valid_runtime_type(
def complete_object_value(
self,
return_type: GraphQLObjectType,
field_nodes: list[FieldNode],
field_nodes: FieldGroup,
info: GraphQLResolveInfo,
path: Path,
result: Any,
Expand Down Expand Up @@ -1593,7 +1597,7 @@ async def execute_subfields_async() -> dict[str, Any]:
def collect_and_execute_subfields(
self,
return_type: GraphQLObjectType,
field_nodes: list[FieldNode],
field_nodes: FieldGroup,
path: Path,
result: Any,
async_payload_record: AsyncPayloadRecord | None,
Expand All @@ -1619,7 +1623,7 @@ def collect_and_execute_subfields(
return sub_fields

def collect_subfields(
self, return_type: GraphQLObjectType, field_nodes: list[FieldNode]
self, return_type: GraphQLObjectType, field_nodes: FieldGroup
) -> FieldsAndPatches:
"""Collect subfields.
Expand Down Expand Up @@ -1688,7 +1692,7 @@ def execute_deferred_fragment(
self,
parent_type: GraphQLObjectType,
source_value: Any,
fields: dict[str, list[FieldNode]],
fields: dict[str, FieldGroup],
label: str | None = None,
path: Path | None = None,
parent_context: AsyncPayloadRecord | None = None,
Expand Down Expand Up @@ -1724,7 +1728,7 @@ def execute_stream_field(
path: Path,
item_path: Path,
item: AwaitableOrValue[Any],
field_nodes: list[FieldNode],
field_nodes: FieldGroup,
info: GraphQLResolveInfo,
item_type: GraphQLOutputType,
label: str | None = None,
Expand Down Expand Up @@ -1817,7 +1821,7 @@ async def await_completed_items() -> list[Any] | None:
async def execute_stream_iterator_item(
self,
iterator: AsyncIterator[Any],
field_nodes: list[FieldNode],
field_nodes: FieldGroup,
info: GraphQLResolveInfo,
item_type: GraphQLOutputType,
async_payload_record: StreamRecord,
Expand Down Expand Up @@ -1851,7 +1855,7 @@ async def execute_stream_iterator(
self,
initial_index: int,
iterator: AsyncIterator[Any],
field_modes: list[FieldNode],
field_modes: FieldGroup,
info: GraphQLResolveInfo,
item_type: GraphQLOutputType,
path: Path,
Expand Down Expand Up @@ -2238,7 +2242,7 @@ def handle_field_error(


def invalid_return_type_error(
return_type: GraphQLObjectType, result: Any, field_nodes: list[FieldNode]
return_type: GraphQLObjectType, result: Any, field_nodes: FieldGroup
) -> GraphQLError:
"""Create a GraphQLError for an invalid return type."""
return GraphQLError(
Expand Down

0 comments on commit f4d5501

Please sign in to comment.