Skip to content

Commit

Permalink
feat(relay): Separate slice metadata to a separate function to allow …
Browse files Browse the repository at this point in the history
…reusage
  • Loading branch information
bellini666 committed Jun 1, 2024
1 parent 2c90ef6 commit 08bfdd2
Show file tree
Hide file tree
Showing 4 changed files with 246 additions and 74 deletions.
11 changes: 11 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
Release type: minor

This release separates the `relay.ListConnection` logic that calculates the
slice of the nodes into a separate function.

This allows for easier reuse of that logic for other places/libraries.

The new function lives in the `strawberry.relay.utils` module and is called
`get_slice_metadata`.

This has no implications to end users.
107 changes: 36 additions & 71 deletions strawberry/relay/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,12 @@
from strawberry.utils.inspect import in_async_context
from strawberry.utils.typing import eval_type, is_classvar

from .utils import from_base64, should_resolve_list_connection_edges, to_base64
from .utils import (
from_base64,
get_slice_metadata,
should_resolve_list_connection_edges,
to_base64,
)

if TYPE_CHECKING:
from strawberry.scalars import ID
Expand Down Expand Up @@ -790,61 +795,13 @@ def resolve_connection(
https://relay.dev/graphql/connections.htm#sec-Pagination-algorithm
"""
max_results = info.schema.config.relay_max_results
start = 0
end: Optional[int] = None

if after:
after_type, after_parsed = from_base64(after)
if after_type != PREFIX:
# When the base64 hash doesnt exist, the after_type seems to return
# arrayconnEction instead of PREFIX. Let's raise a predictable
# instead of "An unknown error occurred."
raise TypeError("Argument 'after' contains a non-existing value.")

start = int(after_parsed) + 1
if before:
before_type, before_parsed = from_base64(before)
if before_type != PREFIX:
# When the base64 hash doesnt exist, the after_type seems to return
# arrayconnEction instead of PREFIX. Let's raise a predictable
# instead of "An unknown error occurred.
raise TypeError("Argument 'before' contains a non-existing value.")
end = int(before_parsed)

if isinstance(first, int):
if first < 0:
raise ValueError("Argument 'first' must be a non-negative integer.")

if first > max_results:
raise ValueError(
f"Argument 'first' cannot be higher than {max_results}."
)

if end is not None:
start = max(0, end - 1)

end = start + first
if isinstance(last, int):
if last < 0:
raise ValueError("Argument 'last' must be a non-negative integer.")

if last > max_results:
raise ValueError(
f"Argument 'last' cannot be higher than {max_results}."
)

if end is not None:
start = max(start, end - last)
else:
end = sys.maxsize

if end is None:
end = start + max_results

expected = end - start if end != sys.maxsize else None
# Overfetch by 1 to check if we have a next result
overfetch = end + 1 if end != sys.maxsize else end
slice_metadata = get_slice_metadata(
info,
before=before,
after=after,
first=first,
last=last,
)

type_def = get_object_definition(cls)
assert type_def
Expand All @@ -863,15 +820,17 @@ async def resolver() -> Self:
try:
iterator = cast(
Union[AsyncIterator[NodeType], AsyncIterable[NodeType]],
cast(Sequence, nodes)[start:overfetch],
cast(Sequence, nodes)[
slice_metadata.start : slice_metadata.overfetch
],
)
except TypeError:
# TODO: Why mypy isn't narrowing this based on the if above?
assert isinstance(nodes, (AsyncIterator, AsyncIterable))
iterator = aislice(
nodes,
start,
overfetch,
slice_metadata.start,
slice_metadata.overfetch,
)

# The slice above might return an object that now is not async
Expand All @@ -880,25 +839,28 @@ async def resolver() -> Self:
edges: List[Edge] = [
edge_class.resolve_edge(
cls.resolve_node(v, info=info, **kwargs),
cursor=start + i,
cursor=slice_metadata.start + i,
)
async for i, v in aenumerate(iterator)
]
else:
edges: List[Edge] = [ # type: ignore[no-redef]
edge_class.resolve_edge(
cls.resolve_node(v, info=info, **kwargs),
cursor=start + i,
cursor=slice_metadata.start + i,
)
for i, v in enumerate(iterator)
]

has_previous_page = start > 0
if expected is not None and len(edges) == expected + 1:
has_previous_page = slice_metadata.start > 0
if (
slice_metadata.expected is not None
and len(edges) == slice_metadata.expected + 1
):
# Remove the overfetched result
edges = edges[:-1]
has_next_page = True
elif end == sys.maxsize:
elif slice_metadata.end == sys.maxsize:
# Last was asked without any after/before
assert last is not None
original_len = len(edges)
Expand All @@ -923,14 +885,14 @@ async def resolver() -> Self:
try:
iterator = cast(
Union[Iterator[NodeType], Iterable[NodeType]],
cast(Sequence, nodes)[start:overfetch],
cast(Sequence, nodes)[slice_metadata.start : slice_metadata.overfetch],
)
except TypeError:
assert isinstance(nodes, (Iterable, Iterator))
iterator = itertools.islice(
nodes,
start,
overfetch,
slice_metadata.start,
slice_metadata.overfetch,
)

if not should_resolve_list_connection_edges(info):
Expand All @@ -947,17 +909,20 @@ async def resolver() -> Self:
edges = [
edge_class.resolve_edge(
cls.resolve_node(v, info=info, **kwargs),
cursor=start + i,
cursor=slice_metadata.start + i,
)
for i, v in enumerate(iterator)
]

has_previous_page = start > 0
if expected is not None and len(edges) == expected + 1:
has_previous_page = slice_metadata.start > 0
if (
slice_metadata.expected is not None
and len(edges) == slice_metadata.expected + 1
):
# Remove the overfetched result
edges = edges[:-1]
has_next_page = True
elif end == sys.maxsize:
elif slice_metadata.end == sys.maxsize:
# Last was asked without any after/before
assert last is not None
original_len = len(edges)
Expand Down
90 changes: 88 additions & 2 deletions strawberry/relay/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
from __future__ import annotations
import sys

import base64
from typing import Any, Tuple, Union
import dataclasses
from typing import TYPE_CHECKING, Any, Tuple, Union
from typing_extensions import assert_never

from strawberry.types.info import Info
from strawberry.types.nodes import InlineFragment, Selection
from strawberry.types.types import StrawberryObjectDefinition

if TYPE_CHECKING:
from strawberry.types.info import Info


def from_base64(value: str) -> Tuple[str, str]:
"""Parse the base64 encoded relay value.
Expand Down Expand Up @@ -102,3 +108,83 @@ def _check_selection(selection: Selection) -> bool:
if _check_selection(selection):
return True
return False


@dataclasses.dataclass
class SliceMetadata:
start: int
end: int
expected: int | None

@property
def overfetch(self) -> int:
# Overfetch by 1 to check if we have a next result
return self.end + 1 if self.end != sys.maxsize else self.end


def get_slice_metadata(
info: Info,
*,
before: str | None = None,
after: str | None = None,
first: int | None = None,
last: int | None = None,
) -> SliceMetadata:
"""Get the slice metadata to use on ListConnection."""
from strawberry.relay.types import PREFIX

max_results = info.schema.config.relay_max_results
start = 0
end: int | None = None

if after:
after_type, after_parsed = from_base64(after)
if after_type != PREFIX:
# When the base64 hash doesnt exist, the after_type seems to return
# arrayconnEction instead of PREFIX. Let's raise a predictable
# instead of "An unknown error occurred."
raise TypeError("Argument 'after' contains a non-existing value.")

start = int(after_parsed) + 1
if before:
before_type, before_parsed = from_base64(before)
if before_type != PREFIX:
# When the base64 hash doesnt exist, the after_type seems to return
# arrayconnEction instead of PREFIX. Let's raise a predictable
# instead of "An unknown error occurred.
raise TypeError("Argument 'before' contains a non-existing value.")
end = int(before_parsed)

if isinstance(first, int):
if first < 0:
raise ValueError("Argument 'first' must be a non-negative integer.")

if first > max_results:
raise ValueError(f"Argument 'first' cannot be higher than {max_results}.")

if end is not None:
start = max(0, end - 1)

end = start + first
if isinstance(last, int):
if last < 0:
raise ValueError("Argument 'last' must be a non-negative integer.")

if last > max_results:
raise ValueError(f"Argument 'last' cannot be higher than {max_results}.")

if end is not None:
start = max(start, end - last)
else:
end = sys.maxsize

if end is None:
end = start + max_results

expected = end - start if end != sys.maxsize else None

return SliceMetadata(
start=start,
end=end,
expected=expected,
)
Loading

0 comments on commit 08bfdd2

Please sign in to comment.