Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding contrib for GraphQL Relay #1214

Merged
merged 9 commits into from
Jan 31, 2025
27 changes: 12 additions & 15 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,28 +18,25 @@ jobs:
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install wheel
pip install -e .[asgi-file-uploads,tracing,telemetry,test,dev]
- name: Install Hatch
uses: pypa/hatch@install
- name: Pytest
run: |
pytest --cov=ariadne --cov=tests
run: hatch test -c -py ${{ matrix.python-version }}
- uses: codecov/codecov-action@v3
- name: Linters
run: |
pylint ariadne tests
mypy ariadne tests_mypy --ignore-missing-imports --check-untyped-defs
black --check .
- name: Pylint
run: hatch run pylint --py-version=3.8 ariadne tests
- name: mypy
run: hatch run mypy ariadne tests_mypy --ignore-missing-imports --check-untyped-defs
- name: black
run: hatch run black -t py38 --check .
- name: Benchmark
run: |
pytest benchmark --benchmark-storage=file://benchmark/results --benchmark-compare
hatch run pytest benchmark --benchmark-storage=file://benchmark/results --benchmark-compare

integration:

Expand Down
20 changes: 20 additions & 0 deletions ariadne/contrib/relay/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from ariadne.contrib.relay.arguments import (
ConnectionArguments,
)
from ariadne.contrib.relay.connection import RelayConnection
from ariadne.contrib.relay.objects import (
RelayNodeInterfaceType,
RelayObjectType,
RelayQueryType,
)
from ariadne.contrib.relay.types import ConnectionResolver, GlobalIDTuple

__all__ = [
"ConnectionArguments",
"RelayNodeInterfaceType",
"RelayConnection",
"RelayObjectType",
"RelayQueryType",
"ConnectionResolver",
"GlobalIDTuple",
]
54 changes: 54 additions & 0 deletions ariadne/contrib/relay/arguments.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from typing import Optional, Type, Union

from typing_extensions import TypeAliasType


class ForwardConnectionArguments:
first: Optional[int]
after: Optional[str]

def __init__(
self, *, first: Optional[int] = None, after: Optional[str] = None
) -> None:
self.first = first
self.after = after


class BackwardConnectionArguments:
last: Optional[int]
before: Optional[str]

def __init__(
self, *, last: Optional[int] = None, before: Optional[str] = None
) -> None:
self.last = last
self.before = before


class ConnectionArguments:
def __init__(
self,
*,
first: Optional[int] = None,
after: Optional[str] = None,
last: Optional[int] = None,
before: Optional[str] = None,
) -> None:
self.first = first
self.after = after
self.last = last
self.before = before


ConnectionArgumentsUnion = TypeAliasType(
"ConnectionArgumentsUnion",
Union[ForwardConnectionArguments, BackwardConnectionArguments, ConnectionArguments],
)
ConnectionArgumentsTypeUnion = TypeAliasType(
"ConnectionArgumentsTypeUnion",
Union[
Type[ForwardConnectionArguments],
Type[BackwardConnectionArguments],
Type[ConnectionArguments],
],
)
35 changes: 35 additions & 0 deletions ariadne/contrib/relay/connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from typing import Sequence

from typing_extensions import Any

from ariadne.contrib.relay.arguments import ConnectionArgumentsUnion


class RelayConnection:
def __init__(
self,
edges: Sequence[Any],
total: int,
has_next_page: bool,
has_previous_page: bool,
) -> None:
self.edges = edges
self.total = total
self.has_next_page = has_next_page
self.has_previous_page = has_previous_page

def get_cursor(self, node):
return node["id"]

def get_page_info(
self, connection_arguments: ConnectionArgumentsUnion
): # pylint: disable=unused-argument
return {
"hasNextPage": self.has_next_page,
"hasPreviousPage": self.has_previous_page,
"startCursor": self.get_cursor(self.edges[0]),
"endCursor": self.get_cursor(self.edges[-1]),
}

def get_edges(self):
return [{"node": node, "cursor": self.get_cursor(node)} for node in self.edges]
135 changes: 135 additions & 0 deletions ariadne/contrib/relay/objects.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
from base64 import b64decode
from inspect import iscoroutinefunction
from typing import Optional, Tuple

from graphql.pyutils import is_awaitable
from graphql.type import GraphQLSchema

from ariadne import InterfaceType, ObjectType
from ariadne.contrib.relay.arguments import (
ConnectionArguments,
ConnectionArgumentsTypeUnion,
)
from ariadne.contrib.relay.types import (
ConnectionResolver,
GlobalIDDecoder,
GlobalIDTuple,
)
from ariadne.types import Resolver


def decode_global_id(kwargs) -> GlobalIDTuple:
return GlobalIDTuple(*b64decode(kwargs["id"]).decode().split(":"))


class RelayObjectType(ObjectType):
_node_resolver: Optional[Resolver] = None

def __init__(
self,
name: str,
connection_arguments_class: ConnectionArgumentsTypeUnion = ConnectionArguments,
) -> None:
super().__init__(name)
self.connection_arguments_class = connection_arguments_class

def resolve_wrapper(self, resolver: ConnectionResolver):
def wrapper(obj, info, *args, **kwargs):
connection_arguments = self.connection_arguments_class(**kwargs)
if iscoroutinefunction(resolver):

async def async_my_extension():
relay_connection = await resolver(
obj, info, connection_arguments, *args, **kwargs
)
if is_awaitable(relay_connection):
relay_connection = await relay_connection
return {
"edges": relay_connection.get_edges(),
"pageInfo": relay_connection.get_page_info(
connection_arguments
),
}

return async_my_extension()

relay_connection = resolver(
obj, info, connection_arguments, *args, **kwargs
)
return {
"edges": relay_connection.get_edges(),
"pageInfo": relay_connection.get_page_info(connection_arguments),
}

return wrapper

def connection(self, name: str):
def decorator(resolver: ConnectionResolver) -> ConnectionResolver:
self.set_field(name, self.resolve_wrapper(resolver))
return resolver

return decorator

def node_resolver(self, resolver: Resolver):
self._node_resolver = resolver
return resolver

def bind_to_schema(self, schema: GraphQLSchema) -> None:
super().bind_to_schema(schema)

if callable(self._node_resolver):
graphql_type = schema.type_map.get(self.name)
setattr(
graphql_type,
"__resolve_node__",
self._node_resolver,
)


class RelayNodeInterfaceType(InterfaceType):
def __init__(
self,
type_resolver: Optional[Resolver] = None,
) -> None:
super().__init__("Node", type_resolver)


class RelayQueryType(RelayObjectType):
def __init__(
self,
node: Optional[RelayNodeInterfaceType] = None,
global_id_decoder: GlobalIDDecoder = decode_global_id,
) -> None:
super().__init__("Query")
if node is None:
node = RelayNodeInterfaceType()
self.node = node
self.set_field("node", self.resolve_node)
self.global_id_decoder = global_id_decoder

@property
def bindables(self) -> Tuple["RelayQueryType", "RelayNodeInterfaceType"]:
return (self, self.node)

def get_node_resolver(self, type_name, schema: GraphQLSchema) -> Resolver:
type_object = schema.get_type(type_name)
try:
return getattr(type_object, "__resolve_node__")
except AttributeError as exc:
raise ValueError(f"No node resolver for type {type_name}") from exc

def resolve_node(self, obj, info, *args, **kwargs):
type_name, _ = self.global_id_decoder(kwargs)

resolver = self.get_node_resolver(type_name, info.schema)

if iscoroutinefunction(resolver):

async def async_my_extension():
result = await resolver(obj, info, *args, **kwargs)
if is_awaitable(result):
result = await result
return result

return async_my_extension()
return resolver(obj, info, *args, **kwargs)
10 changes: 10 additions & 0 deletions ariadne/contrib/relay/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from collections import namedtuple
from typing import Any, Callable, Dict

from typing_extensions import TypeVar

from ariadne.contrib.relay.connection import RelayConnection

ConnectionResolver = TypeVar("ConnectionResolver", bound=Callable[..., RelayConnection])
GlobalIDTuple = namedtuple("GlobalIDTuple", ["type", "id"])
GlobalIDDecoder = Callable[[Dict[str, Any]], GlobalIDTuple]
2 changes: 1 addition & 1 deletion ariadne/schema_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,7 +720,7 @@ def _heal_field(field, _):
each(type_.fields, _heal_field)

def heal_type(
type_: Union[GraphQLList, GraphQLNamedType, GraphQLNonNull]
type_: Union[GraphQLList, GraphQLNamedType, GraphQLNonNull],
) -> Union[GraphQLList, GraphQLNamedType, GraphQLNonNull]:
# Unwrap the two known wrapper types
if isinstance(type_, GraphQLList):
Expand Down
2 changes: 1 addition & 1 deletion ariadne/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def gql(value: str) -> str:


def unwrap_graphql_error(
error: Union[GraphQLError, Optional[Exception]]
error: Union[GraphQLError, Optional[Exception]],
) -> Optional[Exception]:
"""Recursively unwrap exception when its instance of GraphQLError.

Expand Down
18 changes: 16 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ dependencies = [
]

[project.optional-dependencies]
dev = ["black", "mypy", "pylint"]
dev = ["black<25", "mypy", "pylint"]
test = [
"pytest",
"pytest-asyncio",
Expand Down Expand Up @@ -73,10 +73,23 @@ features = ["dev", "test"]

[tool.hatch.envs.default.scripts]
test = "coverage run -m pytest"
check = [
"pylint --py-version=3.8 ariadne tests",
"mypy ariadne tests_mypy --ignore-missing-imports --check-untyped-defs",
"black --check .",
"hatch test -a -p",
"hatch test --cover",
]

[tool.hatch.envs.hatch-test]
features = ["dev", "test"]

[[tool.hatch.envs.hatch-test.matrix]]
python = ["3.8", "3.9", "3.10", "3.11", "3.12"]

[tool.black]
line-length = 88
target-version = ['py36', 'py37', 'py38']
target-version = ['py38']
include = '\.pyi?$'
exclude = '''
/(
Expand All @@ -96,6 +109,7 @@ exclude = '''

[tool.pytest.ini_options]
asyncio_mode = "strict"
asyncio_default_fixture_loop_scope = "function"
testpaths = ["tests"]

[tool.coverage.run]
Expand Down
Empty file added tests/relay/__init__.py
Empty file.
Loading