diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..c2802a2940 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,11 @@ +Release type: minor + +This release separates the `relay.ListConnection` logic that calculates the +slice of the nodes into a separate function. + +This allows for easier reuse of that logic for other places/libraries. + +The new function lives in the `strawberry.relay.utils` module and is called +`get_slice_metadata`. + +This has no implications to end users. diff --git a/strawberry/relay/types.py b/strawberry/relay/types.py index 5ee9a92b6e..54d8f73816 100644 --- a/strawberry/relay/types.py +++ b/strawberry/relay/types.py @@ -38,7 +38,12 @@ from strawberry.utils.inspect import in_async_context from strawberry.utils.typing import eval_type, is_classvar -from .utils import from_base64, should_resolve_list_connection_edges, to_base64 +from .utils import ( + from_base64, + get_slice_metadata, + should_resolve_list_connection_edges, + to_base64, +) if TYPE_CHECKING: from strawberry.scalars import ID @@ -790,61 +795,13 @@ def resolve_connection( https://relay.dev/graphql/connections.htm#sec-Pagination-algorithm """ - max_results = info.schema.config.relay_max_results - start = 0 - end: Optional[int] = None - - if after: - after_type, after_parsed = from_base64(after) - if after_type != PREFIX: - # When the base64 hash doesnt exist, the after_type seems to return - # arrayconnEction instead of PREFIX. Let's raise a predictable - # instead of "An unknown error occurred." - raise TypeError("Argument 'after' contains a non-existing value.") - - start = int(after_parsed) + 1 - if before: - before_type, before_parsed = from_base64(before) - if before_type != PREFIX: - # When the base64 hash doesnt exist, the after_type seems to return - # arrayconnEction instead of PREFIX. Let's raise a predictable - # instead of "An unknown error occurred. - raise TypeError("Argument 'before' contains a non-existing value.") - end = int(before_parsed) - - if isinstance(first, int): - if first < 0: - raise ValueError("Argument 'first' must be a non-negative integer.") - - if first > max_results: - raise ValueError( - f"Argument 'first' cannot be higher than {max_results}." - ) - - if end is not None: - start = max(0, end - 1) - - end = start + first - if isinstance(last, int): - if last < 0: - raise ValueError("Argument 'last' must be a non-negative integer.") - - if last > max_results: - raise ValueError( - f"Argument 'last' cannot be higher than {max_results}." - ) - - if end is not None: - start = max(start, end - last) - else: - end = sys.maxsize - - if end is None: - end = start + max_results - - expected = end - start if end != sys.maxsize else None - # Overfetch by 1 to check if we have a next result - overfetch = end + 1 if end != sys.maxsize else end + slice_metadata = get_slice_metadata( + info, + before=before, + after=after, + first=first, + last=last, + ) type_def = get_object_definition(cls) assert type_def @@ -863,15 +820,17 @@ async def resolver() -> Self: try: iterator = cast( Union[AsyncIterator[NodeType], AsyncIterable[NodeType]], - cast(Sequence, nodes)[start:overfetch], + cast(Sequence, nodes)[ + slice_metadata.start : slice_metadata.overfetch + ], ) except TypeError: # TODO: Why mypy isn't narrowing this based on the if above? assert isinstance(nodes, (AsyncIterator, AsyncIterable)) iterator = aislice( nodes, - start, - overfetch, + slice_metadata.start, + slice_metadata.overfetch, ) # The slice above might return an object that now is not async @@ -880,7 +839,7 @@ async def resolver() -> Self: edges: List[Edge] = [ edge_class.resolve_edge( cls.resolve_node(v, info=info, **kwargs), - cursor=start + i, + cursor=slice_metadata.start + i, ) async for i, v in aenumerate(iterator) ] @@ -888,17 +847,20 @@ async def resolver() -> Self: edges: List[Edge] = [ # type: ignore[no-redef] edge_class.resolve_edge( cls.resolve_node(v, info=info, **kwargs), - cursor=start + i, + cursor=slice_metadata.start + i, ) for i, v in enumerate(iterator) ] - has_previous_page = start > 0 - if expected is not None and len(edges) == expected + 1: + has_previous_page = slice_metadata.start > 0 + if ( + slice_metadata.expected is not None + and len(edges) == slice_metadata.expected + 1 + ): # Remove the overfetched result edges = edges[:-1] has_next_page = True - elif end == sys.maxsize: + elif slice_metadata.end == sys.maxsize: # Last was asked without any after/before assert last is not None original_len = len(edges) @@ -923,14 +885,14 @@ async def resolver() -> Self: try: iterator = cast( Union[Iterator[NodeType], Iterable[NodeType]], - cast(Sequence, nodes)[start:overfetch], + cast(Sequence, nodes)[slice_metadata.start : slice_metadata.overfetch], ) except TypeError: assert isinstance(nodes, (Iterable, Iterator)) iterator = itertools.islice( nodes, - start, - overfetch, + slice_metadata.start, + slice_metadata.overfetch, ) if not should_resolve_list_connection_edges(info): @@ -947,17 +909,20 @@ async def resolver() -> Self: edges = [ edge_class.resolve_edge( cls.resolve_node(v, info=info, **kwargs), - cursor=start + i, + cursor=slice_metadata.start + i, ) for i, v in enumerate(iterator) ] - has_previous_page = start > 0 - if expected is not None and len(edges) == expected + 1: + has_previous_page = slice_metadata.start > 0 + if ( + slice_metadata.expected is not None + and len(edges) == slice_metadata.expected + 1 + ): # Remove the overfetched result edges = edges[:-1] has_next_page = True - elif end == sys.maxsize: + elif slice_metadata.end == sys.maxsize: # Last was asked without any after/before assert last is not None original_len = len(edges) diff --git a/strawberry/relay/utils.py b/strawberry/relay/utils.py index 894eb8d971..5f5103159e 100644 --- a/strawberry/relay/utils.py +++ b/strawberry/relay/utils.py @@ -1,11 +1,17 @@ +from __future__ import annotations +import sys + import base64 -from typing import Any, Tuple, Union +import dataclasses +from typing import TYPE_CHECKING, Any, Tuple, Union from typing_extensions import assert_never -from strawberry.types.info import Info from strawberry.types.nodes import InlineFragment, Selection from strawberry.types.types import StrawberryObjectDefinition +if TYPE_CHECKING: + from strawberry.types.info import Info + def from_base64(value: str) -> Tuple[str, str]: """Parse the base64 encoded relay value. @@ -102,3 +108,83 @@ def _check_selection(selection: Selection) -> bool: if _check_selection(selection): return True return False + + +@dataclasses.dataclass +class SliceMetadata: + start: int + end: int + expected: int | None + + @property + def overfetch(self) -> int: + # Overfetch by 1 to check if we have a next result + return self.end + 1 if self.end != sys.maxsize else self.end + + +def get_slice_metadata( + info: Info, + *, + before: str | None = None, + after: str | None = None, + first: int | None = None, + last: int | None = None, +) -> SliceMetadata: + """Get the slice metadata to use on ListConnection.""" + from strawberry.relay.types import PREFIX + + max_results = info.schema.config.relay_max_results + start = 0 + end: int | None = None + + if after: + after_type, after_parsed = from_base64(after) + if after_type != PREFIX: + # When the base64 hash doesnt exist, the after_type seems to return + # arrayconnEction instead of PREFIX. Let's raise a predictable + # instead of "An unknown error occurred." + raise TypeError("Argument 'after' contains a non-existing value.") + + start = int(after_parsed) + 1 + if before: + before_type, before_parsed = from_base64(before) + if before_type != PREFIX: + # When the base64 hash doesnt exist, the after_type seems to return + # arrayconnEction instead of PREFIX. Let's raise a predictable + # instead of "An unknown error occurred. + raise TypeError("Argument 'before' contains a non-existing value.") + end = int(before_parsed) + + if isinstance(first, int): + if first < 0: + raise ValueError("Argument 'first' must be a non-negative integer.") + + if first > max_results: + raise ValueError(f"Argument 'first' cannot be higher than {max_results}.") + + if end is not None: + start = max(0, end - 1) + + end = start + first + if isinstance(last, int): + if last < 0: + raise ValueError("Argument 'last' must be a non-negative integer.") + + if last > max_results: + raise ValueError(f"Argument 'last' cannot be higher than {max_results}.") + + if end is not None: + start = max(start, end - last) + else: + end = sys.maxsize + + if end is None: + end = start + max_results + + expected = end - start if end != sys.maxsize else None + + return SliceMetadata( + start=start, + end=end, + expected=expected, + ) diff --git a/tests/relay/test_utils.py b/tests/relay/test_utils.py index 8811df8f0a..f9365a385a 100644 --- a/tests/relay/test_utils.py +++ b/tests/relay/test_utils.py @@ -1,10 +1,22 @@ +from __future__ import annotations + +import sys from typing import Any +from unittest import mock import pytest -from strawberry.relay.utils import from_base64, to_base64 +from strawberry.relay.utils import ( + SliceMetadata, + from_base64, + get_slice_metadata, + to_base64, +) +from strawberry.schema.config import StrawberryConfig from strawberry.type import get_object_definition +from strawberry.relay.types import PREFIX + from .schema import Fruit @@ -61,3 +73,101 @@ def test_to_base64_with_typedef(): def test_to_base64_with_invalid_type(value: Any): with pytest.raises(ValueError): value = to_base64(value, "1") + + +@pytest.mark.parametrize( + ( + "before", + "after", + "first", + "last", + "max_results", + "expected", + "expected_overfetch", + ), + [ + ( + None, + None, + None, + None, + 100, + SliceMetadata(start=0, end=100, expected=100), + 101, + ), + ( + None, + None, + 10, + None, + 100, + SliceMetadata(start=0, end=10, expected=10), + 11, + ), + ( + None, + None, + None, + 10, + 100, + SliceMetadata(start=0, end=sys.maxsize, expected=None), + sys.maxsize, + ), + ( + 10, + None, + None, + None, + 100, + SliceMetadata(start=0, end=10, expected=10), + 11, + ), + ( + None, + 10, + None, + None, + 100, + SliceMetadata(start=11, end=111, expected=100), + 112, + ), + ( + 15, + None, + 10, + None, + 100, + SliceMetadata(start=14, end=24, expected=10), + 25, + ), + ( + None, + 15, + None, + 10, + 100, + SliceMetadata(start=16, end=sys.maxsize, expected=None), + sys.maxsize, + ), + ], +) +def test_get_slice_metadata( + before: str | None, + after: str | None, + first: int | None, + last: int | None, + max_results: int, + expected: SliceMetadata, + expected_overfetch: int, +): + info = mock.Mock() + info.schema.config = StrawberryConfig(relay_max_results=max_results) + slice_metadata = get_slice_metadata( + info, + before=before and to_base64(PREFIX, before), + after=after and to_base64(PREFIX, after), + first=first, + last=last, + ) + assert slice_metadata == expected + assert slice_metadata.overfetch == expected_overfetch