diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 1612af43..5d455868 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -18,28 +18,25 @@ jobs: python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install wheel - pip install -e .[asgi-file-uploads,tracing,telemetry,test,dev] + - name: Install Hatch + uses: pypa/hatch@install - name: Pytest - run: | - pytest --cov=ariadne --cov=tests + run: hatch test -c -py ${{ matrix.python-version }} - uses: codecov/codecov-action@v3 - - name: Linters - run: | - pylint ariadne tests - mypy ariadne tests_mypy --ignore-missing-imports --check-untyped-defs - black --check . + - name: Pylint + run: hatch run pylint --py-version=3.8 ariadne tests + - name: mypy + run: hatch run mypy ariadne tests_mypy --ignore-missing-imports --check-untyped-defs + - name: black + run: hatch run black -t py38 --check . - name: Benchmark run: | - pytest benchmark --benchmark-storage=file://benchmark/results --benchmark-compare + hatch run pytest benchmark --benchmark-storage=file://benchmark/results --benchmark-compare integration: diff --git a/ariadne/contrib/relay/__init__.py b/ariadne/contrib/relay/__init__.py new file mode 100644 index 00000000..07b04690 --- /dev/null +++ b/ariadne/contrib/relay/__init__.py @@ -0,0 +1,20 @@ +from ariadne.contrib.relay.arguments import ( + ConnectionArguments, +) +from ariadne.contrib.relay.connection import RelayConnection +from ariadne.contrib.relay.objects import ( + RelayNodeInterfaceType, + RelayObjectType, + RelayQueryType, +) +from ariadne.contrib.relay.types import ConnectionResolver, GlobalIDTuple + +__all__ = [ + "ConnectionArguments", + "RelayNodeInterfaceType", + "RelayConnection", + "RelayObjectType", + "RelayQueryType", + "ConnectionResolver", + "GlobalIDTuple", +] diff --git a/ariadne/contrib/relay/arguments.py b/ariadne/contrib/relay/arguments.py new file mode 100644 index 00000000..040b1fe7 --- /dev/null +++ b/ariadne/contrib/relay/arguments.py @@ -0,0 +1,54 @@ +from typing import Optional, Type, Union + +from typing_extensions import TypeAliasType + + +class ForwardConnectionArguments: + first: Optional[int] + after: Optional[str] + + def __init__( + self, *, first: Optional[int] = None, after: Optional[str] = None + ) -> None: + self.first = first + self.after = after + + +class BackwardConnectionArguments: + last: Optional[int] + before: Optional[str] + + def __init__( + self, *, last: Optional[int] = None, before: Optional[str] = None + ) -> None: + self.last = last + self.before = before + + +class ConnectionArguments: + def __init__( + self, + *, + first: Optional[int] = None, + after: Optional[str] = None, + last: Optional[int] = None, + before: Optional[str] = None, + ) -> None: + self.first = first + self.after = after + self.last = last + self.before = before + + +ConnectionArgumentsUnion = TypeAliasType( + "ConnectionArgumentsUnion", + Union[ForwardConnectionArguments, BackwardConnectionArguments, ConnectionArguments], +) +ConnectionArgumentsTypeUnion = TypeAliasType( + "ConnectionArgumentsTypeUnion", + Union[ + Type[ForwardConnectionArguments], + Type[BackwardConnectionArguments], + Type[ConnectionArguments], + ], +) diff --git a/ariadne/contrib/relay/connection.py b/ariadne/contrib/relay/connection.py new file mode 100644 index 00000000..889cd37c --- /dev/null +++ b/ariadne/contrib/relay/connection.py @@ -0,0 +1,35 @@ +from typing import Sequence + +from typing_extensions import Any + +from ariadne.contrib.relay.arguments import ConnectionArgumentsUnion + + +class RelayConnection: + def __init__( + self, + edges: Sequence[Any], + total: int, + has_next_page: bool, + has_previous_page: bool, + ) -> None: + self.edges = edges + self.total = total + self.has_next_page = has_next_page + self.has_previous_page = has_previous_page + + def get_cursor(self, node): + return node["id"] + + def get_page_info( + self, connection_arguments: ConnectionArgumentsUnion + ): # pylint: disable=unused-argument + return { + "hasNextPage": self.has_next_page, + "hasPreviousPage": self.has_previous_page, + "startCursor": self.get_cursor(self.edges[0]), + "endCursor": self.get_cursor(self.edges[-1]), + } + + def get_edges(self): + return [{"node": node, "cursor": self.get_cursor(node)} for node in self.edges] diff --git a/ariadne/contrib/relay/objects.py b/ariadne/contrib/relay/objects.py new file mode 100644 index 00000000..00c46cb2 --- /dev/null +++ b/ariadne/contrib/relay/objects.py @@ -0,0 +1,135 @@ +from base64 import b64decode +from inspect import iscoroutinefunction +from typing import Optional, Tuple + +from graphql.pyutils import is_awaitable +from graphql.type import GraphQLSchema + +from ariadne import InterfaceType, ObjectType +from ariadne.contrib.relay.arguments import ( + ConnectionArguments, + ConnectionArgumentsTypeUnion, +) +from ariadne.contrib.relay.types import ( + ConnectionResolver, + GlobalIDDecoder, + GlobalIDTuple, +) +from ariadne.types import Resolver + + +def decode_global_id(kwargs) -> GlobalIDTuple: + return GlobalIDTuple(*b64decode(kwargs["id"]).decode().split(":")) + + +class RelayObjectType(ObjectType): + _node_resolver: Optional[Resolver] = None + + def __init__( + self, + name: str, + connection_arguments_class: ConnectionArgumentsTypeUnion = ConnectionArguments, + ) -> None: + super().__init__(name) + self.connection_arguments_class = connection_arguments_class + + def resolve_wrapper(self, resolver: ConnectionResolver): + def wrapper(obj, info, *args, **kwargs): + connection_arguments = self.connection_arguments_class(**kwargs) + if iscoroutinefunction(resolver): + + async def async_my_extension(): + relay_connection = await resolver( + obj, info, connection_arguments, *args, **kwargs + ) + if is_awaitable(relay_connection): + relay_connection = await relay_connection + return { + "edges": relay_connection.get_edges(), + "pageInfo": relay_connection.get_page_info( + connection_arguments + ), + } + + return async_my_extension() + + relay_connection = resolver( + obj, info, connection_arguments, *args, **kwargs + ) + return { + "edges": relay_connection.get_edges(), + "pageInfo": relay_connection.get_page_info(connection_arguments), + } + + return wrapper + + def connection(self, name: str): + def decorator(resolver: ConnectionResolver) -> ConnectionResolver: + self.set_field(name, self.resolve_wrapper(resolver)) + return resolver + + return decorator + + def node_resolver(self, resolver: Resolver): + self._node_resolver = resolver + return resolver + + def bind_to_schema(self, schema: GraphQLSchema) -> None: + super().bind_to_schema(schema) + + if callable(self._node_resolver): + graphql_type = schema.type_map.get(self.name) + setattr( + graphql_type, + "__resolve_node__", + self._node_resolver, + ) + + +class RelayNodeInterfaceType(InterfaceType): + def __init__( + self, + type_resolver: Optional[Resolver] = None, + ) -> None: + super().__init__("Node", type_resolver) + + +class RelayQueryType(RelayObjectType): + def __init__( + self, + node: Optional[RelayNodeInterfaceType] = None, + global_id_decoder: GlobalIDDecoder = decode_global_id, + ) -> None: + super().__init__("Query") + if node is None: + node = RelayNodeInterfaceType() + self.node = node + self.set_field("node", self.resolve_node) + self.global_id_decoder = global_id_decoder + + @property + def bindables(self) -> Tuple["RelayQueryType", "RelayNodeInterfaceType"]: + return (self, self.node) + + def get_node_resolver(self, type_name, schema: GraphQLSchema) -> Resolver: + type_object = schema.get_type(type_name) + try: + return getattr(type_object, "__resolve_node__") + except AttributeError as exc: + raise ValueError(f"No node resolver for type {type_name}") from exc + + def resolve_node(self, obj, info, *args, **kwargs): + type_name, _ = self.global_id_decoder(kwargs) + + resolver = self.get_node_resolver(type_name, info.schema) + + if iscoroutinefunction(resolver): + + async def async_my_extension(): + result = await resolver(obj, info, *args, **kwargs) + if is_awaitable(result): + result = await result + return result + + return async_my_extension() + return resolver(obj, info, *args, **kwargs) diff --git a/ariadne/contrib/relay/types.py b/ariadne/contrib/relay/types.py new file mode 100644 index 00000000..8a1c3aad --- /dev/null +++ b/ariadne/contrib/relay/types.py @@ -0,0 +1,10 @@ +from collections import namedtuple +from typing import Any, Callable, Dict + +from typing_extensions import TypeVar + +from ariadne.contrib.relay.connection import RelayConnection + +ConnectionResolver = TypeVar("ConnectionResolver", bound=Callable[..., RelayConnection]) +GlobalIDTuple = namedtuple("GlobalIDTuple", ["type", "id"]) +GlobalIDDecoder = Callable[[Dict[str, Any]], GlobalIDTuple] diff --git a/ariadne/schema_visitor.py b/ariadne/schema_visitor.py index 7d76d7dc..dc0de650 100644 --- a/ariadne/schema_visitor.py +++ b/ariadne/schema_visitor.py @@ -720,7 +720,7 @@ def _heal_field(field, _): each(type_.fields, _heal_field) def heal_type( - type_: Union[GraphQLList, GraphQLNamedType, GraphQLNonNull] + type_: Union[GraphQLList, GraphQLNamedType, GraphQLNonNull], ) -> Union[GraphQLList, GraphQLNamedType, GraphQLNonNull]: # Unwrap the two known wrapper types if isinstance(type_, GraphQLList): diff --git a/ariadne/utils.py b/ariadne/utils.py index 81b267fa..3d0a326b 100644 --- a/ariadne/utils.py +++ b/ariadne/utils.py @@ -118,7 +118,7 @@ def gql(value: str) -> str: def unwrap_graphql_error( - error: Union[GraphQLError, Optional[Exception]] + error: Union[GraphQLError, Optional[Exception]], ) -> Optional[Exception]: """Recursively unwrap exception when its instance of GraphQLError. diff --git a/pyproject.toml b/pyproject.toml index 66b6f32d..5587d8d6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ dependencies = [ ] [project.optional-dependencies] -dev = ["black", "mypy", "pylint"] +dev = ["black<25", "mypy", "pylint"] test = [ "pytest", "pytest-asyncio", @@ -73,10 +73,23 @@ features = ["dev", "test"] [tool.hatch.envs.default.scripts] test = "coverage run -m pytest" +check = [ + "pylint --py-version=3.8 ariadne tests", + "mypy ariadne tests_mypy --ignore-missing-imports --check-untyped-defs", + "black --check .", + "hatch test -a -p", + "hatch test --cover", +] + +[tool.hatch.envs.hatch-test] +features = ["dev", "test"] + +[[tool.hatch.envs.hatch-test.matrix]] +python = ["3.8", "3.9", "3.10", "3.11", "3.12"] [tool.black] line-length = 88 -target-version = ['py36', 'py37', 'py38'] +target-version = ['py38'] include = '\.pyi?$' exclude = ''' /( @@ -96,6 +109,7 @@ exclude = ''' [tool.pytest.ini_options] asyncio_mode = "strict" +asyncio_default_fixture_loop_scope = "function" testpaths = ["tests"] [tool.coverage.run] diff --git a/tests/relay/__init__.py b/tests/relay/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/relay/conftest.py b/tests/relay/conftest.py new file mode 100644 index 00000000..17549b3e --- /dev/null +++ b/tests/relay/conftest.py @@ -0,0 +1,201 @@ +from base64 import b64decode + +import pytest +from ariadne.contrib.relay import ( + ConnectionArguments, + GlobalIDTuple, + RelayConnection, + RelayNodeInterfaceType, + RelayObjectType, + RelayQueryType, +) + + +@pytest.fixture +def relay_type_defs(): + return """\ +interface Node { + bid: ID! +} + +type Faction implements Node { + bid: ID! + name: String + ships(first: Int!, after: ID): ShipConnection +} + +type Ship implements Node { + bid: ID! + name: String +} + +type ShipConnection { + edges: [ShipEdge] + pageInfo: PageInfo! + ships: [Ship] + totalCount: Int +} + +type ShipEdge { + cursor: String! + node: Ship +} + +type PageInfo { + hasNextPage: Boolean! + hasPreviousPage: Boolean! + startCursor: String + endCursor: String +} + +type Query { + rebels: Faction + empire: Faction + node(bid: ID!): Node +} +""" + + +@pytest.fixture +def global_id_decoder(): + return lambda kwargs: GlobalIDTuple(*b64decode(kwargs["bid"]).decode().split(":")) + + +@pytest.fixture +def relay_node_interface(): + return RelayNodeInterfaceType() + + +@pytest.fixture +def relay_query(factions, relay_node_interface, global_id_decoder): + query = RelayQueryType( + node=relay_node_interface, + global_id_decoder=global_id_decoder, + ) + query.set_field("rebels", lambda *_: factions[0]) + query.set_field("empire", lambda *_: factions[1]) + query.node.set_field("bid", lambda obj, *_: obj["id"]) + return query + + +@pytest.fixture +def ships(): + return [ + { + "id": "U2hpcDox", + "name": "X-Wing", + "factionId": "RmFjdGlvbjox", + }, + { + "id": "U2hpcDoy", + "name": "Y-Wing", + "factionId": "RmFjdGlvbjox", + }, + { + "id": "U2hpcDoz", + "name": "A-Wing", + "factionId": "RmFjdGlvbjox", + }, + { + "id": "U2hpcDo0", + "name": "Millennium Falcon", + "factionId": "RmFjdGlvbjox", + }, + { + "id": "U2hpcDo1", + "name": "Home One", + "factionId": "RmFjdGlvbjox", + }, + { + "id": "U2hpcDo2", + "name": "TIE Fighter", + "factionId": "RmFjdGlvbjoy", + }, + { + "id": "U2hpcDo3", + "name": "TIE Bomber", + "factionId": "RmFjdGlvbjoy", + }, + { + "id": "U2hpcDo4", + "name": "TIE Interceptor", + "factionId": "RmFjdGlvbjoy", + }, + { + "id": "U2hpcDo5", + "name": "Darth Vader's TIE Advanced", + "factionId": "RmFjdGlvbjoy", + }, + ] + + +@pytest.fixture +def factions(): + return [ + { + "id": "RmFjdGlvbjox", + "name": "Alliance to Restore the Republic", + }, + {"id": "RmFjdGlvbjoy", "name": "Galactic Empire"}, + ] + + +@pytest.fixture +def relay_faction_object(factions): + faction = RelayObjectType("Faction") + faction.node_resolver( + lambda *_, bid: [ + {"__typename": "Faction", **faction} + for faction in factions + if faction["id"] == bid + ][0] + ) + return faction + + +@pytest.fixture +def relay_ship_object(ships): + ship = RelayObjectType("Ship") + ship.node_resolver( + lambda *_, bid: [ + {"__typename": "Ship", **ship} for ship in ships if ship["id"] == bid + ][0] + ) + return ship + + +@pytest.fixture +def ship_slice_resolver(ships): + # pylint: disable=unused-argument + def resolver( + faction_obj, info, connection_arguments: ConnectionArguments, **kwargs + ): + faction_ships = [ + ship for ship in ships if ship["factionId"] == faction_obj["id"] + ] + total = len(faction_ships) + if connection_arguments.after: + after_index = ( + faction_ships.index( + next( + ship + for ship in faction_ships + if ship["id"] == connection_arguments.after + ) + ) + + 1 + ) + else: + after_index = 0 + ships_slice = faction_ships[ + after_index : after_index + connection_arguments.first + ] + + return RelayConnection( + edges=ships_slice, + total=total, + has_next_page=after_index + connection_arguments.first < total, + has_previous_page=after_index > 0, + ) + + return resolver diff --git a/tests/relay/test_arguments.py b/tests/relay/test_arguments.py new file mode 100644 index 00000000..d5befe98 --- /dev/null +++ b/tests/relay/test_arguments.py @@ -0,0 +1,27 @@ +from ariadne.contrib.relay.arguments import ( + BackwardConnectionArguments, + ConnectionArguments, + ForwardConnectionArguments, +) + + +def test_connection_arguments(): + connection_arguments = ConnectionArguments( + first=10, after="cursor", last=5, before="cursor" + ) + assert connection_arguments.first == 10 + assert connection_arguments.after == "cursor" + assert connection_arguments.last == 5 + assert connection_arguments.before == "cursor" + + +def test_forward_connection_arguments(): + connection_arguments = ForwardConnectionArguments(first=10, after="cursor") + assert connection_arguments.first == 10 + assert connection_arguments.after == "cursor" + + +def test_backward_connection_arguments(): + connection_arguments = BackwardConnectionArguments(last=5, before="cursor") + assert connection_arguments.last == 5 + assert connection_arguments.before == "cursor" diff --git a/tests/relay/test_connection.py b/tests/relay/test_connection.py new file mode 100644 index 00000000..eec3040f --- /dev/null +++ b/tests/relay/test_connection.py @@ -0,0 +1,129 @@ +from graphql import graphql_sync + +from ariadne import make_executable_schema +from ariadne.contrib.relay import RelayConnection, RelayObjectType + + +def test_relay_connection(): + connection = RelayConnection( + edges=[{"id": "VXNlcjox", "name": "Alice"}, {"id": "VXNlcjoy", "name": "Bob"}], + total=2, + has_next_page=False, + has_previous_page=False, + ) + assert connection.total == 2 + assert connection.has_next_page is False + assert connection.has_previous_page is False + assert connection.get_cursor({"id": "VXNlcjox"}) == "VXNlcjox" + assert connection.get_page_info({}) == { + "hasNextPage": False, + "hasPreviousPage": False, + "startCursor": "VXNlcjox", + "endCursor": "VXNlcjoy", + } + assert connection.get_edges() == [ + {"node": {"id": "VXNlcjox", "name": "Alice"}, "cursor": "VXNlcjox"}, + {"node": {"id": "VXNlcjoy", "name": "Bob"}, "cursor": "VXNlcjoy"}, + ] + + +CONNECTION_QUERY = """\ +query GetShips { + rebels{ + bid + name + ships(first: 6) { + ...ShipConnectionFragment + } + moreShips: ships(first: 2, after: "U2hpcDoy") { + ...ShipConnectionFragment + } + } +} + + +fragment ShipConnectionFragment on ShipConnection { + pageInfo { + hasNextPage + hasPreviousPage + startCursor + endCursor + } + edges { + cursor + node { + bid + name + } + } +} +""" + + +def test_relay_query_with_connection(relay_type_defs, relay_query, ship_slice_resolver): + faction = RelayObjectType("Faction") + + faction.connection("ships")(ship_slice_resolver) + + schema = make_executable_schema( + relay_type_defs, + *relay_query.bindables, + faction, + ) + result = graphql_sync(schema, CONNECTION_QUERY) + + assert result.errors is None + assert result.data == { + "rebels": { + "bid": "RmFjdGlvbjox", + "name": "Alliance to Restore the Republic", + "ships": { + "pageInfo": { + "hasNextPage": False, + "hasPreviousPage": False, + "startCursor": "U2hpcDox", + "endCursor": "U2hpcDo1", + }, + "edges": [ + { + "cursor": "U2hpcDox", + "node": {"bid": "U2hpcDox", "name": "X-Wing"}, + }, + { + "cursor": "U2hpcDoy", + "node": {"bid": "U2hpcDoy", "name": "Y-Wing"}, + }, + { + "cursor": "U2hpcDoz", + "node": {"bid": "U2hpcDoz", "name": "A-Wing"}, + }, + { + "cursor": "U2hpcDo0", + "node": {"bid": "U2hpcDo0", "name": "Millennium Falcon"}, + }, + { + "cursor": "U2hpcDo1", + "node": {"bid": "U2hpcDo1", "name": "Home One"}, + }, + ], + }, + "moreShips": { + "pageInfo": { + "hasNextPage": True, + "hasPreviousPage": True, + "startCursor": "U2hpcDoz", + "endCursor": "U2hpcDo0", + }, + "edges": [ + { + "cursor": "U2hpcDoz", + "node": {"bid": "U2hpcDoz", "name": "A-Wing"}, + }, + { + "cursor": "U2hpcDo0", + "node": {"bid": "U2hpcDo0", "name": "Millennium Falcon"}, + }, + ], + }, + } + } diff --git a/tests/relay/test_objects.py b/tests/relay/test_objects.py new file mode 100644 index 00000000..2e75ef52 --- /dev/null +++ b/tests/relay/test_objects.py @@ -0,0 +1,239 @@ +import pytest +from graphql import graphql_sync +from pytest_mock import MockFixture + +from ariadne import make_executable_schema +from ariadne.contrib.relay.arguments import ( + ConnectionArguments, + ForwardConnectionArguments, +) +from ariadne.contrib.relay.connection import RelayConnection +from ariadne.contrib.relay.objects import ( + RelayObjectType, + RelayQueryType, + decode_global_id, +) +from ariadne.contrib.relay.types import ( + GlobalIDTuple, +) + + +@pytest.fixture +def friends_connection(): + return RelayConnection( + edges=[{"id": "VXNlcjox", "name": "Alice"}, {"id": "VXNlcjoy", "name": "Bob"}], + total=2, + has_next_page=False, + has_previous_page=False, + ) + + +def test_decode_global_id(): + assert decode_global_id({"id": "VXNlcjox"}) == GlobalIDTuple("User", "1") + + +def test_default_id_decoder(): + query = RelayQueryType() + assert query.global_id_decoder is decode_global_id + + +def test_missing_node_resolver(relay_type_defs, relay_query): + schema = make_executable_schema(relay_type_defs, *relay_query.bindables) + with pytest.raises(ValueError): + relay_query.get_node_resolver("NonExistingType", schema) + + +def test_node_resolver_storage(relay_type_defs, relay_query: RelayQueryType): + ship = RelayObjectType("Ship") + + def resolve_ship(*_): + pass + + ship.node_resolver(resolve_ship) + + schema = make_executable_schema(relay_type_defs, *relay_query.bindables, ship) + + assert relay_query.get_node_resolver("Ship", schema) is resolve_ship + + +def test_query_type_node_field_resolver(): + # pylint: disable=protected-access,comparison-with-callable + + query = RelayQueryType() + assert query._resolvers["node"] == query.resolve_node + + +def test_query_type_bindables(): + query = RelayQueryType() + assert query.bindables == (query, query.node) + + +def test_query_type_default_resolve_node(mocker: MockFixture, relay_type_defs): + query = RelayQueryType() + mock_resolver = mocker.Mock() + ship = RelayObjectType("Ship") + ship.node_resolver(mock_resolver) + schema = make_executable_schema(relay_type_defs, query, ship) + mock_info = mocker.Mock(schema=schema) + + assert ( + query.resolve_node(None, mock_info, id="U2hpcDox") == mock_resolver.return_value + ) + mock_resolver.assert_called_once_with(None, mock_info, id="U2hpcDox") + + +@pytest.mark.asyncio +async def test_query_type_default_async_resolve_node( + mocker: MockFixture, relay_type_defs +): + query = RelayQueryType() + ship = RelayObjectType("Ship") + mock_async_resolver = mocker.AsyncMock() + ship.node_resolver(mock_async_resolver) + schema = make_executable_schema(relay_type_defs, query, ship) + mock_info = mocker.Mock(schema=schema) + + awaitable_resolver = query.resolve_node(None, mock_info, id="U2hpcDox") + await awaitable_resolver + mock_async_resolver.assert_awaited_once_with(None, mock_info, id="U2hpcDox") + + +def test_relay_object_type(): + object_type = RelayObjectType("User") + assert object_type.connection_arguments_class == ConnectionArguments + + +def test_relay_object_resolve_wrapper(mocker: MockFixture, friends_connection): + mock_resolver = mocker.Mock(return_value=friends_connection) + mock_connection_arguments = mocker.Mock() + mock_connection_arguments_class = mocker.Mock( + return_value=mock_connection_arguments + ) + + object_type = RelayObjectType( + "User", connection_arguments_class=mock_connection_arguments_class + ) + wrapped_resolver = object_type.resolve_wrapper(mock_resolver) + + result = wrapped_resolver(None, None, first=10) + assert result == { + "edges": [ + {"node": {"id": "VXNlcjox", "name": "Alice"}, "cursor": "VXNlcjox"}, + {"node": {"id": "VXNlcjoy", "name": "Bob"}, "cursor": "VXNlcjoy"}, + ], + "pageInfo": { + "hasNextPage": False, + "hasPreviousPage": False, + "startCursor": "VXNlcjox", + "endCursor": "VXNlcjoy", + }, + } + + mock_resolver.assert_called_once_with( + None, None, mock_connection_arguments, first=10 + ) + mock_connection_arguments_class.assert_called_once_with(first=10) + + +@pytest.mark.asyncio +async def test_relay_object_resolve_wrapper_async(friends_connection): + async def resolver(*_, **__): + return friends_connection + + object_type = RelayObjectType("User") + wrapped_resolver = object_type.resolve_wrapper(resolver) + + result = await wrapped_resolver(None, None, first=10) + assert result == { + "edges": [ + {"node": {"id": "VXNlcjox", "name": "Alice"}, "cursor": "VXNlcjox"}, + {"node": {"id": "VXNlcjoy", "name": "Bob"}, "cursor": "VXNlcjoy"}, + ], + "pageInfo": { + "hasNextPage": False, + "hasPreviousPage": False, + "startCursor": "VXNlcjox", + "endCursor": "VXNlcjoy", + }, + } + + +def test_relay_object_resolve_wrapper_with_custom_arguments(mocker: MockFixture): + object_type = RelayObjectType( + "User", connection_arguments_class=ForwardConnectionArguments + ) + mock_resolver = mocker.Mock() + + wrapped_resolver = object_type.resolve_wrapper(mock_resolver) + wrapped_resolver(None, None, first=10, after="VXNlcjox") + + connection_arg_call = mock_resolver.call_args_list[0].args[2] + + assert connection_arg_call.first == 10 + assert connection_arg_call.after == "VXNlcjox" + + +def test_relay_object_connection_decorator(mocker: MockFixture): + # pylint: disable=protected-access + object_type = RelayObjectType("User") + mock_resolve_wrapper = mocker.patch.object(object_type, "resolve_wrapper") + + @object_type.connection("friends") + def resolve_friends(*_): + pass + + mock_resolve_wrapper.assert_called_once_with(resolve_friends) + + assert object_type._resolvers["friends"] == mock_resolve_wrapper.return_value + + +def test_relay_query( + relay_type_defs, + relay_query, +): + schema = make_executable_schema( + relay_type_defs, + *relay_query.bindables, + ) + result = graphql_sync(schema, "{ rebels { bid name } }") + + assert result.errors is None + assert result.data == { + "rebels": {"bid": "RmFjdGlvbjox", "name": "Alliance to Restore the Republic"} + } + + +def test_relay_node_query_ship( + relay_type_defs, + relay_query, + relay_ship_object, +): + schema = make_executable_schema( + relay_type_defs, + *relay_query.bindables, + relay_ship_object, + ) + result = graphql_sync( + schema, '{ node(bid: "U2hpcDoz") { ... on Ship { bid name } } }' + ) + + assert result.errors is None + assert result.data == {"node": {"bid": "U2hpcDoz", "name": "A-Wing"}} + + +def test_relay_node_query_faction( + relay_type_defs, + relay_query, + relay_faction_object, +): + schema = make_executable_schema( + relay_type_defs, + *relay_query.bindables, + relay_faction_object, + ) + result = graphql_sync( + schema, '{ node(bid: "RmFjdGlvbjoy") { ... on Faction { bid name } } }' + ) + + assert result.errors is None + assert result.data == {"node": {"bid": "RmFjdGlvbjoy", "name": "Galactic Empire"}}