Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(relay): Allow to customize max_results per connection in relay #3746

Merged
merged 2 commits into from
Jan 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
Release type: minor

Add the ability to override the "max results" a relay's connection can return on
a per-field basis.

The default value for this is defined in the schema's config, and set to `100`
unless modified by the user. Now, that per-field value will take precedence over
it.

For example:

```python
@strawerry.type
class Query:
# This will still use the default value in the schema's config
fruits: ListConnection[Fruit] = relay.connection()

# This will reduce the maximum number of results to 10
limited_fruits: ListConnection[Fruit] = relay.connection(max_results=10)

# This will increase the maximum number of results to 10
higher_limited_fruits: ListConnection[Fruit] = relay.connection(max_results=10_000)
Comment on lines +21 to +22
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a typo, right?

Suggested change
# This will increase the maximum number of results to 10
higher_limited_fruits: ListConnection[Fruit] = relay.connection(max_results=10_000)
# This will increase the maximum number of results to 10000
higher_limited_fruits: ListConnection[Fruit] = relay.connection(max_results=10_000)

```

Note that this only affects `ListConnection` and subclasses. If you are
implementing your own connection resolver, there's an extra keyword named
`max_results: int | None` that will be passed to it.
17 changes: 16 additions & 1 deletion docs/guides/relay.md
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,24 @@ It can be defined in the `Query` objects in 4 ways:
- `node: List[Optional[Node]]`: The same as `List[Node]`, but the returned list
can contain `null` values if the given objects don't exist.

### Max results for connections

The implementation of `relay.ListConnection` will limit the number of results to
the `relay_max_results` configuration in the
[schema's config](../types/schema-configurations.md) (which defaults to `100`).

That can also be configured on a per-field basis by passing `max_results` to the
`@connection` decorator. For example:

```python
@strawerry.type
class Query:
fruits: ListConnection[Fruit] = relay.connection(max_results=10_000)
```

### Custom connection pagination

The default `relay.Connection` class don't implement any pagination logic, and
The default `relay.Connection` class doesn't implement any pagination logic, and
should be used as a base class to implement your own pagination logic. All you
need to do is implement the `resolve_connection` classmethod.

Expand Down
11 changes: 10 additions & 1 deletion strawberry/relay/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,9 @@ async def resolve(resolved: Any = resolved_nodes) -> list[Node]:
class ConnectionExtension(FieldExtension):
connection_type: type[Connection[Node]]

def __init__(self, max_results: Optional[int] = None) -> None:
self.max_results = max_results

def apply(self, field: StrawberryField) -> None:
field.arguments = [
*field.arguments,
Expand Down Expand Up @@ -288,6 +291,7 @@ def resolve(
after=after,
first=first,
last=last,
max_results=self.max_results,
)

async def resolve_async(
Expand Down Expand Up @@ -316,6 +320,7 @@ async def resolve_async(
after=after,
first=first,
last=last,
max_results=self.max_results,
)

# If nodes was an AsyncIterable/AsyncIterator, resolve_connection
Expand Down Expand Up @@ -357,6 +362,7 @@ def connection(
metadata: Optional[Mapping[Any, Any]] = None,
directives: Optional[Sequence[object]] = (),
extensions: list[FieldExtension] = (), # type: ignore
max_results: Optional[int] = None,
# This init parameter is used by pyright to determine whether this field
# is added in the constructor or not. It is not used to change
# any behaviour at the moment.
Expand Down Expand Up @@ -389,6 +395,9 @@ def connection(
metadata: The metadata of the field.
directives: The directives to apply to the field.
extensions: The extensions to apply to the field.
max_results: The maximum number of results this connection can return.
Can be set to override the default value of 100 defined in the
schema configuration.
init: Used only for type checking purposes.

Examples:
Expand Down Expand Up @@ -451,7 +460,7 @@ def get_some_nodes(self, age: int) -> Iterable[SomeType]: ...
default_factory=default_factory,
metadata=metadata,
directives=directives or (),
extensions=[*extensions, ConnectionExtension()],
extensions=[*extensions, ConnectionExtension(max_results=max_results)],
)
if resolver is not None:
f = f(resolver)
Expand Down
5 changes: 5 additions & 0 deletions strawberry/relay/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,6 +717,7 @@ def resolve_connection(
after: Optional[str] = None,
first: Optional[int] = None,
last: Optional[int] = None,
max_results: Optional[int] = None,
**kwargs: Any,
) -> AwaitableOrValue[Self]:
"""Resolve a connection from nodes.
Expand All @@ -731,6 +732,7 @@ def resolve_connection(
after: Returns the items in the list that come after the specified cursor.
first: Returns the first n items from the list.
last: Returns the items in the list that come after the specified cursor.
max_results: The maximum number of results to resolve.
kwargs: Additional arguments passed to the resolver.

Returns:
Expand Down Expand Up @@ -767,6 +769,7 @@ def resolve_connection( # noqa: PLR0915
after: Optional[str] = None,
first: Optional[int] = None,
last: Optional[int] = None,
max_results: Optional[int] = None,
**kwargs: Any,
) -> AwaitableOrValue[Self]:
"""Resolve a connection from the list of nodes.
Expand All @@ -780,6 +783,7 @@ def resolve_connection( # noqa: PLR0915
after: Returns the items in the list that come after the specified cursor.
first: Returns the first n items from the list.
last: Returns the items in the list that come after the specified cursor.
max_results: The maximum number of results to resolve.
kwargs: Additional arguments passed to the resolver.

Returns:
Expand All @@ -794,6 +798,7 @@ def resolve_connection( # noqa: PLR0915
after=after,
first=first,
last=last,
max_results=max_results,
)

type_def = get_object_definition(cls)
Expand Down
7 changes: 6 additions & 1 deletion strawberry/relay/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,16 @@ def from_arguments(
after: str | None = None,
first: int | None = None,
last: int | None = None,
max_results: int | None = None,
) -> Self:
"""Get the slice metadata to use on ListConnection."""
from strawberry.relay.types import PREFIX

max_results = info.schema.config.relay_max_results
max_results = (
max_results
if max_results is not None
else info.schema.config.relay_max_results
)
start = 0
end: int | None = None

Expand Down
58 changes: 58 additions & 0 deletions tests/relay/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import strawberry
from strawberry.permission import BasePermission
from strawberry.relay import Connection, Node
from strawberry.relay.types import ListConnection
from strawberry.schema.config import StrawberryConfig


@strawberry.type
Expand All @@ -34,6 +36,8 @@ def resolve_connection(
before: Optional[str] = None,
first: Optional[int] = None,
last: Optional[int] = None,
max_results: Optional[int] = None,
**kwargs: Any,
) -> Optional[Self]:
return None

Expand Down Expand Up @@ -124,3 +128,57 @@ def users(self) -> Optional[list[User]]: # pragma: no cover
result = schema.execute_sync(query)
assert result.data == {"users": None}
assert result.errors[0].message == "Not allowed"


@pytest.mark.parametrize(
("field_max_results", "schema_max_results", "results", "expected"),
[
(5, 100, 5, 5),
(5, 2, 5, 5),
(5, 100, 10, 5),
(5, 2, 10, 5),
(5, 100, 0, 0),
(5, 2, 0, 0),
(None, 100, 5, 5),
(None, 2, 5, 2),
bellini666 marked this conversation as resolved.
Show resolved Hide resolved
],
bellini666 marked this conversation as resolved.
Show resolved Hide resolved
)
def test_max_results(
field_max_results: Optional[int],
schema_max_results: int,
results: int,
expected: int,
):
@strawberry.type
class User(Node):
id: strawberry.relay.NodeID[str]

@strawberry.type
class Query:
@strawberry.relay.connection(
ListConnection[User],
max_results=field_max_results,
)
def users(self) -> list[User]:
return [User(id=str(i)) for i in range(results)]

schema = strawberry.Schema(
query=Query,
config=StrawberryConfig(relay_max_results=schema_max_results),
)
query = """
query {
users {
edges {
node {
id
}
}
}
}
"""

result = schema.execute_sync(query)
assert result.data is not None
assert isinstance(result.data["users"]["edges"], list)
assert len(result.data["users"]["edges"]) == expected
Loading