Skip to content

Commit

Permalink
refactor handleFieldError
Browse files Browse the repository at this point in the history
  • Loading branch information
Cito committed Apr 7, 2024
1 parent afa6b93 commit bce7a3d
Showing 1 changed file with 70 additions and 44 deletions.
114 changes: 70 additions & 44 deletions src/graphql/execution/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,7 +944,6 @@ def execute_field(
calling its resolve function, then calls complete_value to await coroutine
objects, serialize scalars, or execute the sub-selection-set for objects.
"""
errors = async_payload_record.errors if async_payload_record else self.errors
field_name = field_group[0].name.value
field_def = self.schema.get_field(parent_type, field_name)
if not field_def:
Expand Down Expand Up @@ -983,16 +982,26 @@ async def await_completed() -> Any:
try:
return await completed
except Exception as raw_error:
error = located_error(raw_error, field_group, path.as_list())
handle_field_error(error, return_type, errors)
self.handle_field_error(
raw_error,
return_type,
field_group,
path,
async_payload_record,
)
self.filter_subsequent_payloads(path, async_payload_record)
return None

return await_completed()

except Exception as raw_error:
error = located_error(raw_error, field_group, path.as_list())
handle_field_error(error, return_type, errors)
self.handle_field_error(
raw_error,
return_type,
field_group,
path,
async_payload_record,
)
self.filter_subsequent_payloads(path, async_payload_record)
return None

Expand Down Expand Up @@ -1026,6 +1035,28 @@ def build_resolve_info(
self.is_awaitable,
)

def handle_field_error(
self,
raw_error: Exception,
return_type: GraphQLOutputType,
field_group: FieldGroup,
path: Path,
async_payload_record: AsyncPayloadRecord | None = None,
) -> None:
"""Handle error properly according to the field type."""
error = located_error(raw_error, field_group, path.as_list())

# If the field type is non-nullable, then it is resolved without any protection
# from errors, however it still properly locates the error.
if is_non_null_type(return_type):
raise error

errors = async_payload_record.errors if async_payload_record else self.errors

# Otherwise, error protection is applied, logging the error and resolving a
# null value for this field if one is encountered.
errors.append(error)

def complete_value(
self,
return_type: GraphQLOutputType,
Expand Down Expand Up @@ -1138,11 +1169,9 @@ async def complete_awaitable_value(
if self.is_awaitable(completed):
completed = await completed
except Exception as raw_error:
errors = (
async_payload_record.errors if async_payload_record else self.errors
self.handle_field_error(
raw_error, return_type, field_group, path, async_payload_record
)
error = located_error(raw_error, field_group, path.as_list())
handle_field_error(error, return_type, errors)
self.filter_subsequent_payloads(path, async_payload_record)
completed = None
return completed
Expand Down Expand Up @@ -1198,7 +1227,6 @@ async def complete_async_iterator_value(
Complete an async iterator value by completing the result and calling
recursively until all the results are completed.
"""
errors = async_payload_record.errors if async_payload_record else self.errors
stream = self.get_stream_values(field_group, path)
complete_list_item_value = self.complete_list_item_value
awaitable_indices: list[int] = []
Expand Down Expand Up @@ -1236,14 +1264,14 @@ async def complete_async_iterator_value(
except StopAsyncIteration:
break
except Exception as raw_error:
error = located_error(raw_error, field_group, item_path.as_list())
handle_field_error(error, item_type, errors)
self.handle_field_error(
raw_error, item_type, field_group, item_path, async_payload_record
)
completed_results.append(None)
break
if complete_list_item_value(
value,
completed_results,
errors,
item_type,
field_group,
info,
Expand Down Expand Up @@ -1285,7 +1313,6 @@ def complete_list_value(
Complete a list value by completing each item in the list with the inner type.
"""
item_type = return_type.of_type
errors = async_payload_record.errors if async_payload_record else self.errors

if isinstance(result, AsyncIterable):
iterator = result.__aiter__()
Expand Down Expand Up @@ -1336,7 +1363,6 @@ def complete_list_value(
if complete_list_item_value(
item,
completed_results,
errors,
item_type,
field_group,
info,
Expand Down Expand Up @@ -1370,7 +1396,6 @@ def complete_list_item_value(
self,
item: Any,
complete_results: list[Any],
errors: list[GraphQLError],
item_type: GraphQLOutputType,
field_group: FieldGroup,
info: GraphQLResolveInfo,
Expand Down Expand Up @@ -1407,10 +1432,13 @@ async def await_completed() -> Any:
try:
return await completed_item
except Exception as raw_error:
error = located_error(
raw_error, field_group, item_path.as_list()
self.handle_field_error(
raw_error,
item_type,
field_group,
item_path,
async_payload_record,
)
handle_field_error(error, item_type, errors)
self.filter_subsequent_payloads(item_path, async_payload_record)
return None

Expand All @@ -1420,8 +1448,13 @@ async def await_completed() -> Any:
complete_results.append(completed_item)

except Exception as raw_error:
error = located_error(raw_error, field_group, item_path.as_list())
handle_field_error(error, item_type, errors)
self.handle_field_error(
raw_error,
item_type,
field_group,
item_path,
async_payload_record,
)
self.filter_subsequent_payloads(item_path, async_payload_record)
complete_results.append(None)

Expand Down Expand Up @@ -1787,12 +1820,12 @@ async def await_completed_items() -> list[Any] | None:
try:
return [await completed_item]
except Exception as raw_error: # pragma: no cover
# noinspection PyShadowingNames
error = located_error(
raw_error, field_group, item_path.as_list()
)
handle_field_error(
error, item_type, async_payload_record.errors
self.handle_field_error(
raw_error,
item_type,
field_group,
item_path,
async_payload_record,
)
self.filter_subsequent_payloads(
item_path, async_payload_record
Expand All @@ -1808,8 +1841,13 @@ async def await_completed_items() -> list[Any] | None:
completed_items = [completed_item]

except Exception as raw_error:
error = located_error(raw_error, field_group, item_path.as_list())
handle_field_error(error, item_type, async_payload_record.errors)
self.handle_field_error(
raw_error,
item_type,
field_group,
item_path,
async_payload_record,
)
self.filter_subsequent_payloads(item_path, async_payload_record)
completed_items = [None]

Expand Down Expand Up @@ -1850,8 +1888,9 @@ async def execute_stream_iterator_item(
raise StopAsyncIteration from raw_error

except Exception as raw_error:
error = located_error(raw_error, field_group, item_path.as_list())
handle_field_error(error, item_type, async_payload_record.errors)
self.handle_field_error(
raw_error, item_type, field_group, item_path, async_payload_record
)
self.filter_subsequent_payloads(item_path, async_payload_record)

async def execute_stream_iterator(
Expand Down Expand Up @@ -2231,19 +2270,6 @@ def execute_sync(
return cast(ExecutionResult, result)


def handle_field_error(
error: GraphQLError, return_type: GraphQLOutputType, errors: list[GraphQLError]
) -> None:
"""Handle error properly according to the field type."""
# If the field type is non-nullable, then it is resolved without any protection
# from errors, however it still properly locates the error.
if is_non_null_type(return_type):
raise error
# Otherwise, error protection is applied, logging the error and resolving a
# null value for this field if one is encountered.
errors.append(error)


def invalid_return_type_error(
return_type: GraphQLObjectType, result: Any, field_group: FieldGroup
) -> GraphQLError:
Expand Down

0 comments on commit bce7a3d

Please sign in to comment.