From a721659e8e0a45c820409648ee23508d67b604c1 Mon Sep 17 00:00:00 2001 From: Take Weiland Date: Sun, 22 Dec 2024 21:00:57 +0100 Subject: [PATCH 1/2] Add new ordering method allowing ordering by multiple fields --- strawberry_django/fields/field.py | 5 + strawberry_django/ordering.py | 169 +++++++++++- strawberry_django/settings.py | 4 + strawberry_django/type.py | 5 + tests/test_legacy_order.py | 444 ++++++++++++++++++++++++++++++ tests/test_ordering.py | 262 +++--------------- tests/test_settings.py | 2 + tests/types.py | 1 + 8 files changed, 671 insertions(+), 221 deletions(-) create mode 100644 tests/test_legacy_order.py diff --git a/strawberry_django/fields/field.py b/strawberry_django/fields/field.py index f86c35c5..3dfc9b8d 100644 --- a/strawberry_django/fields/field.py +++ b/strawberry_django/fields/field.py @@ -548,6 +548,7 @@ def field( pagination: bool | UnsetType = UNSET, filters: type | UnsetType | None = UNSET, order: type | UnsetType | None = UNSET, + ordering: type | UnsetType | None = UNSET, only: TypeOrSequence[str] | None = None, select_related: TypeOrSequence[str] | None = None, prefetch_related: TypeOrSequence[PrefetchType] | None = None, @@ -576,6 +577,7 @@ def field( pagination: bool | UnsetType = UNSET, filters: type | UnsetType | None = UNSET, order: type | UnsetType | None = UNSET, + ordering: type | UnsetType | None = UNSET, only: TypeOrSequence[str] | None = None, select_related: TypeOrSequence[str] | None = None, prefetch_related: TypeOrSequence[PrefetchType] | None = None, @@ -604,6 +606,7 @@ def field( pagination: bool | UnsetType = UNSET, filters: type | UnsetType | None = UNSET, order: type | UnsetType | None = UNSET, + ordering: type | UnsetType | None = UNSET, only: TypeOrSequence[str] | None = None, select_related: TypeOrSequence[str] | None = None, prefetch_related: TypeOrSequence[PrefetchType] | None = None, @@ -631,6 +634,7 @@ def field( pagination: bool | UnsetType = UNSET, filters: type | UnsetType | None = UNSET, order: type | UnsetType | None = UNSET, + ordering: type | UnsetType | None = UNSET, only: TypeOrSequence[str] | None = None, select_related: TypeOrSequence[str] | None = None, prefetch_related: TypeOrSequence[PrefetchType] | None = None, @@ -672,6 +676,7 @@ def field( filters=filters, pagination=pagination, order=order, + ordering=ordering, extensions=extensions, only=only, select_related=select_related, diff --git a/strawberry_django/ordering.py b/strawberry_django/ordering.py index 86d0bac0..5602bec6 100644 --- a/strawberry_django/ordering.py +++ b/strawberry_django/ordering.py @@ -2,20 +2,24 @@ import dataclasses import enum +import functools from typing import ( TYPE_CHECKING, Callable, Optional, TypeVar, cast, + Any, + Mapping, ) import strawberry from django.db.models import F, OrderBy, QuerySet +from graphql import VariableNode from graphql.language.ast import ObjectValueNode from strawberry import UNSET from strawberry.types import has_object_definition -from strawberry.types.base import WithStrawberryObjectDefinition +from strawberry.types.base import WithStrawberryObjectDefinition, StrawberryOptional from strawberry.types.field import StrawberryField, field from strawberry.types.unset import UnsetType from strawberry.utils.str_converters import to_camel_case @@ -30,6 +34,7 @@ from strawberry_django.utils.typing import is_auto from .arguments import argument +from .settings import strawberry_django_settings if TYPE_CHECKING: from collections.abc import Collection, Sequence @@ -194,17 +199,103 @@ def parse_and_fill(field: ObjectValueNode, seq: dict[str, OrderSequence]): return queryset.order_by(*args) +def process_ordering_default( + ordering: Collection[WithStrawberryObjectDefinition] | None, + info: Info | None, + queryset: _QS, + prefix: str = "", +) -> tuple[_QS, Collection[F | OrderBy | str]]: + args = [] + + for o in ordering: + for f in o.__strawberry_definition__.fields: + f_value = getattr(o, f.name, UNSET) + if f_value is UNSET or ( + f_value is None and not f.metadata.get(WITH_NONE_META) + ): + continue + + if isinstance(f, FilterOrderField) and f.base_resolver: + res = f.base_resolver( + o, + info, + value=f_value, + queryset=queryset, + prefix=prefix, + ) + if isinstance(res, tuple): + queryset, subargs = res + else: + subargs = res + args.extend(subargs) + elif isinstance(f_value, Ordering): + args.append(f_value.resolve(f"{prefix}{f.name}")) + else: + ordering_cls = f.type + if isinstance(ordering_cls, StrawberryOptional): + ordering_cls = ordering_cls.of_type + assert isinstance(ordering_cls, type) and has_object_definition( + ordering_cls + ) + queryset, subargs = process_ordering( + ordering_cls, + (f_value,), + info, + queryset, + prefix=f"{prefix}{f.name}__", + ) + args.extend(subargs) + + return queryset, args + + +def process_ordering( + ordering_cls: type[WithStrawberryObjectDefinition], + ordering: Collection[WithStrawberryObjectDefinition] | None, + info: Info | None, + queryset: _QS, + prefix: str = "", +) -> tuple[_QS, Collection[F | OrderBy | str]]: + if callable( + order_method := getattr(ordering_cls, "process_ordering", None), + ): + return order_method(order, info, queryset=queryset, prefix=prefix) + else: + return process_ordering_default(ordering, info, queryset, prefix) + + +def apply_ordering( + ordering_cls: type[WithStrawberryObjectDefinition], + ordering: Collection[WithStrawberryObjectDefinition] | None, + info: Info | None, + queryset: _QS, +) -> _QS: + queryset, args = process_ordering(ordering_cls, ordering, info, queryset) + if args: + queryset = queryset.order_by(*args) + return queryset + + class StrawberryDjangoFieldOrdering(StrawberryDjangoFieldBase): - def __init__(self, order: type | UnsetType | None = UNSET, **kwargs): + def __init__( + self, + order: type | UnsetType | None = UNSET, + ordering: type | UnsetType | None = UNSET, + **kwargs, + ): if order and not has_object_definition(order): raise TypeError("order needs to be a strawberry type") + if ordering and not has_object_definition(ordering): + raise TypeError("ordering needs to be a strawberry type") self.order = order + self.ordering = ordering super().__init__(**kwargs) def __copy__(self) -> Self: new_field = super().__copy__() new_field.order = self.order + new_field.ordering = self.ordering return new_field @property @@ -214,6 +305,12 @@ def arguments(self) -> list[StrawberryArgument]: order = self.get_order() if order and order is not UNSET: arguments.append(argument("order", order, is_optional=True)) + if self.base_resolver is None: + ordering = self.get_ordering() + if ordering is not None: + arguments.append( + argument("ordering", ordering, is_list=True, default=[]) + ) return super().arguments + arguments @arguments.setter @@ -236,16 +333,82 @@ def get_order(self) -> type[WithStrawberryObjectDefinition] | None: return order if order is not UNSET else None + def get_ordering(self) -> type[WithStrawberryObjectDefinition] | None: + ordering = self.ordering + if ordering is None: + return None + + if ordering is UNSET: + django_type = self.django_type + ordering = ( + django_type.__strawberry_django_definition__.ordering + if django_type not in (None, UNSET) + else None + ) + + return ordering + def get_queryset( self, queryset: _QS, info: Info, *, order: WithStrawberryObjectDefinition | None = None, + ordering: list[WithStrawberryObjectDefinition] | None = None, **kwargs, ) -> _QS: queryset = super().get_queryset(queryset, info, **kwargs) - return apply(order, queryset, info=info) + queryset = apply(order, queryset, info=info) + if ordering_cls := self.get_ordering(): + queryset = apply_ordering(ordering_cls, ordering, info, queryset) + return queryset + + +@dataclass_transform( + order_default=True, + field_specifiers=( + StrawberryField, + field, + ), +) +def ordering( + model: type[Model], + *, + name: str | None = None, + one_of: bool | None = None, + description: str | None = None, + directives: Sequence[object] | None = (), +) -> Callable[[_T], _T]: + def wrapper(cls): + nonlocal one_of + try: + cls.__annotations__ # noqa: B018 + except AttributeError: + # FIXME: Manual creation for python 3.9 (remove when 3.9 is dropped) + cls.__annotations__ = {} + + for fname, type_ in cls.__annotations__.items(): + if is_auto(type_): + type_ = Ordering # noqa: PLW2901 + + cls.__annotations__[fname] = Optional[type_] + + field_ = cls.__dict__.get(fname) + if not isinstance(field_, StrawberryField): + setattr(cls, fname, UNSET) + + if one_of is None: + one_of = strawberry_django_settings()["ORDERING_DEFAULT_ONE_OF"] + + return strawberry.input( + cls, + name=name, + one_of=one_of, + description=description, + directives=directives, + ) + + return wrapper @dataclass_transform( diff --git a/strawberry_django/settings.py b/strawberry_django/settings.py index 4390a17f..71f86281 100644 --- a/strawberry_django/settings.py +++ b/strawberry_django/settings.py @@ -46,6 +46,9 @@ class StrawberryDjangoSettings(TypedDict): #: to set it to unlimited. PAGINATION_DEFAULT_LIMIT: Optional[int] + #: Whether ordering inputs are marked with oneOf directive by default. + ORDERING_DEFAULT_ONE_OF: bool + DEFAULT_DJANGO_SETTINGS = StrawberryDjangoSettings( FIELD_DESCRIPTION_FROM_HELP_TEXT=False, @@ -57,6 +60,7 @@ class StrawberryDjangoSettings(TypedDict): DEFAULT_PK_FIELD_NAME="pk", USE_DEPRECATED_FILTERS=False, PAGINATION_DEFAULT_LIMIT=100, + ORDERING_DEFAULT_ONE_OF=False, ) diff --git a/strawberry_django/type.py b/strawberry_django/type.py index d4383ea1..271289c5 100644 --- a/strawberry_django/type.py +++ b/strawberry_django/type.py @@ -75,6 +75,7 @@ def _process_type( field_cls: type[StrawberryDjangoField] = StrawberryDjangoField, filters: Optional[type] = None, order: Optional[type] = None, + ordering: Optional[type] = None, pagination: bool = False, partial: bool = False, is_filter: Union[Literal["lookups"], bool] = False, @@ -133,6 +134,7 @@ def _process_type( is_filter=is_filter, filters=filters, order=order, + ordering=ordering, pagination=pagination, disable_optimization=disable_optimization, store=OptimizerStore.with_hints( @@ -409,6 +411,7 @@ class StrawberryDjangoDefinition(Generic[_O, _M]): is_filter: Union[Literal["lookups"], bool] = False filters: Optional[type] = None order: Optional[type] = None + ordering: Optional[type] = None pagination: bool = False field_cls: type[StrawberryDjangoField] = StrawberryDjangoField disable_optimization: bool = False @@ -434,6 +437,7 @@ def type( # noqa: A001 extend: bool = False, filters: Optional[type] = None, order: Optional[type] = None, + ordering: Optional[type] = None, pagination: bool = False, only: Optional[TypeOrSequence[str]] = None, select_related: Optional[TypeOrSequence[str]] = None, @@ -471,6 +475,7 @@ def wrapper(cls: _T) -> _T: filters=filters, pagination=pagination, order=order, + ordering=ordering, only=only, select_related=select_related, prefetch_related=prefetch_related, diff --git a/tests/test_legacy_order.py b/tests/test_legacy_order.py new file mode 100644 index 00000000..7d953078 --- /dev/null +++ b/tests/test_legacy_order.py @@ -0,0 +1,444 @@ +# ruff: noqa: TRY002, B904, BLE001, F811, PT012 +from typing import Any, Optional, cast +from unittest import mock + +import pytest +import strawberry +from django.db.models import Case, Count, Value, When +from pytest_mock import MockFixture +from strawberry import auto +from strawberry.annotation import StrawberryAnnotation +from strawberry.exceptions import MissingArgumentsAnnotationsError +from strawberry.types import get_object_definition +from strawberry.types.base import ( + StrawberryOptional, + WithStrawberryObjectDefinition, + get_object_definition, +) +from strawberry.types.field import StrawberryField + +import strawberry_django +from strawberry_django.exceptions import ( + ForbiddenFieldArgumentError, + MissingFieldArgumentError, +) +from strawberry_django.fields.field import StrawberryDjangoField +from strawberry_django.fields.filter_order import ( + FilterOrderField, + FilterOrderFieldResolver, +) +from strawberry_django.ordering import Ordering, OrderSequence, process_order +from tests import models, utils +from tests.types import Fruit + + +@strawberry_django.ordering.order(models.Color) +class ColorOrder: + pk: auto + + @strawberry_django.order_field + def name(self, prefix, value: auto): + return [value.resolve(f"{prefix}name")] + + +@strawberry_django.ordering.order(models.Fruit) +class FruitOrder: + color_id: auto + name: auto + sweetness: auto + color: Optional[ColorOrder] + + @strawberry_django.order_field + def types_number(self, queryset, prefix, value: auto): + return queryset.annotate( + count=Count(f"{prefix}types__id"), + count_nulls=Case( + When(count=0, then=Value(None)), + default="count", + ), + ), [value.resolve("count_nulls")] + + +@strawberry_django.type(models.Fruit, order=FruitOrder) +class FruitWithOrder: + id: auto + name: auto + + +@strawberry.type +class Query: + fruits: list[Fruit] = strawberry_django.field(order=FruitOrder) + + +@pytest.fixture +def query(): + return utils.generate_query(Query) + + +def test_field_order_definition(): + field = StrawberryDjangoField(type_annotation=StrawberryAnnotation(FruitWithOrder)) + assert field.get_order() == FruitOrder + field = StrawberryDjangoField( + type_annotation=StrawberryAnnotation(FruitWithOrder), + filters=None, + ) + assert field.get_filters() is None + + +def test_asc(query, fruits): + result = query("{ fruits(order: { name: ASC }) { id name } }") + assert not result.errors + assert result.data["fruits"] == [ + {"id": "3", "name": "banana"}, + {"id": "2", "name": "raspberry"}, + {"id": "1", "name": "strawberry"}, + ] + + +def test_desc(query, fruits): + result = query("{ fruits(order: { name: DESC }) { id name } }") + assert not result.errors + assert result.data["fruits"] == [ + {"id": "1", "name": "strawberry"}, + {"id": "2", "name": "raspberry"}, + {"id": "3", "name": "banana"}, + ] + + +def test_relationship(query, fruits): + def add_color(fruit, color_name): + fruit.color = models.Color.objects.create(name=color_name) + fruit.save() + + color_names = ["red", "dark red", "yellow"] + for fruit, color_name in zip(fruits, color_names): + add_color(fruit, color_name) + result = query( + "{ fruits(order: { color: { name: DESC } }) { id name color { name } } }", + ) + assert not result.errors + assert result.data["fruits"] == [ + {"id": "3", "name": "banana", "color": {"name": "yellow"}}, + {"id": "1", "name": "strawberry", "color": {"name": "red"}}, + {"id": "2", "name": "raspberry", "color": {"name": "dark red"}}, + ] + + +def test_arguments_order_respected(query, db): + yellow = models.Color.objects.create(name="yellow") + red = models.Color.objects.create(name="red") + + f1 = models.Fruit.objects.create( + name="strawberry", + sweetness=1, + color=red, + ) + f2 = models.Fruit.objects.create( + name="banana", + sweetness=2, + color=yellow, + ) + f3 = models.Fruit.objects.create( + name="apple", + sweetness=0, + color=red, + ) + + result = query("{ fruits(order: { name: ASC, sweetness: ASC }) { id } }") + assert not result.errors + assert result.data["fruits"] == [{"id": str(f.pk)} for f in [f3, f2, f1]] + + result = query("{ fruits(order: { sweetness: DESC, name: ASC }) { id } }") + assert not result.errors + assert result.data["fruits"] == [{"id": str(f.pk)} for f in [f2, f1, f3]] + + result = query("{ fruits(order: { color: {name: ASC}, name: ASC }) { id } }") + assert not result.errors + assert result.data["fruits"] == [{"id": str(f.pk)} for f in [f3, f1, f2]] + + result = query("{ fruits(order: { color: {pk: ASC}, name: ASC }) { id } }") + assert not result.errors + assert result.data["fruits"] == [{"id": str(f.pk)} for f in [f2, f3, f1]] + + result = query("{ fruits(order: { colorId: ASC, name: ASC }) { id } }") + assert not result.errors + assert result.data["fruits"] == [{"id": str(f.pk)} for f in [f2, f3, f1]] + + result = query("{ fruits(order: { name: ASC, colorId: ASC }) { id } }") + assert not result.errors + assert result.data["fruits"] == [{"id": str(f.pk)} for f in [f3, f2, f1]] + + +def test_order_sequence(): + f1 = StrawberryField(graphql_name="sOmEnAmE", python_name="some_name") + f2 = StrawberryField(python_name="some_name") + + assert OrderSequence.get_graphql_name(None, f1) == "sOmEnAmE" + assert OrderSequence.get_graphql_name(None, f2) == "someName" + + assert OrderSequence.sorted(None, None, fields=[f1, f2]) == [f1, f2] + + sequence = {"someName": OrderSequence(0, None), "sOmEnAmE": OrderSequence(1, None)} + assert OrderSequence.sorted(None, sequence, fields=[f1, f2]) == [f1, f2] + + +def test_order_type(): + @strawberry_django.ordering.order(models.Fruit) + class FruitOrder: + color_id: auto + name: auto + sweetness: auto + + @strawberry_django.order_field + def custom_order(self, value: auto, prefix: str): + pass + + annotated_type = StrawberryOptional(Ordering._enum_definition) # type: ignore + + assert [ + ( + f.name, + f.__class__, + f.type, + f.base_resolver.__class__ if f.base_resolver else None, + ) + for f in get_object_definition(FruitOrder, strict=True).fields + ] == [ + ("color_id", StrawberryField, annotated_type, None), + ("name", StrawberryField, annotated_type, None), + ("sweetness", StrawberryField, annotated_type, None), + ( + "custom_order", + FilterOrderField, + annotated_type, + FilterOrderFieldResolver, + ), + ] + + +def test_order_field_missing_prefix(): + with pytest.raises( + MissingFieldArgumentError, match=r".*\"prefix\".*\"field_method\".*" + ): + + @strawberry_django.order_field + def field_method(): + pass + + +def test_order_field_missing_value(): + with pytest.raises( + MissingFieldArgumentError, match=r".*\"value\".*\"field_method\".*" + ): + + @strawberry_django.order_field + def field_method(prefix): + pass + + +def test_order_field_missing_value_annotation(): + with pytest.raises( + MissingArgumentsAnnotationsError, + match=r"Missing annotation.*\"value\".*\"field_method\".*", + ): + + @strawberry_django.order_field + def field_method(prefix, value): + pass + + +def test_order_field(): + try: + + @strawberry_django.order_field + def field_method(self, root, info, prefix, value: auto, sequence, queryset): + pass + + except Exception as exc: + raise pytest.fail(f"DID RAISE {exc}") + + +def test_order_field_forbidden_param_annotation(): + with pytest.raises( + MissingArgumentsAnnotationsError, + match=r".*\"forbidden_param\".*\"field_method\".*", + ): + + @strawberry_django.order_field + def field_method(prefix, value: auto, sequence, queryset, forbidden_param): + pass + + +def test_order_field_forbidden_param(): + with pytest.raises( + ForbiddenFieldArgumentError, + match=r".*\"forbidden_param\".*\"field_method\".*", + ): + + @strawberry_django.order_field + def field_method(prefix, value: auto, sequence, queryset, forbidden_param: str): + pass + + +def test_order_field_missing_queryset(): + with pytest.raises(MissingFieldArgumentError, match=r".*\"queryset\".*\"order\".*"): + + @strawberry_django.order_field + def order(prefix): + pass + + +def test_order_field_value_forbidden_on_object(): + with pytest.raises(ForbiddenFieldArgumentError, match=r".*\"value\".*\"order\".*"): + + @strawberry_django.order_field + def field_method(prefix, queryset, value: auto): + pass + + @strawberry_django.order_field + def order(prefix, queryset, value: auto): + pass + + +def test_order_field_on_object(): + try: + + @strawberry_django.order_field + def order(self, root, info, prefix, sequence, queryset): + pass + + except Exception as exc: + raise pytest.fail(f"DID RAISE {exc}") + + +def test_order_field_method(): + @strawberry_django.ordering.order(models.Fruit) + class Order: + @strawberry_django.order_field + def custom_order(self, root, info, prefix, value: auto, sequence, queryset): + assert self == order, "Unexpected self passed" + assert root == order, "Unexpected root passed" + assert info == fake_info, "Unexpected info passed" + assert prefix == "ROOT", "Unexpected prefix passed" + assert value == Ordering.ASC, "Unexpected value passed" + assert sequence == sequence_inner, "Unexpected sequence passed" + assert queryset == qs, "Unexpected queryset passed" + raise Exception("WAS CALLED") + + order = cast("WithStrawberryObjectDefinition", Order(custom_order=Ordering.ASC)) # type: ignore + schema = strawberry.Schema(query=Query) + fake_info: Any = type("FakeInfo", (), {"schema": schema}) + qs: Any = object() + sequence_inner: Any = object() + sequence = {"customOrder": OrderSequence(0, children=sequence_inner)} + + with pytest.raises(Exception, match="WAS CALLED"): + process_order(order, fake_info, qs, prefix="ROOT", sequence=sequence) + + +def test_order_method_not_called_when_not_decorated(mocker: MockFixture): + @strawberry_django.ordering.order(models.Fruit) + class Order: + def order(self, root, info, prefix, value: auto, sequence, queryset): + pytest.fail("Should not have been called") + + mock_order_method = mocker.spy(Order, "order") + + process_order( + cast("WithStrawberryObjectDefinition", Order()), mock.Mock(), mock.Mock() + ) + + mock_order_method.assert_not_called() + + +def test_order_field_not_called(mocker: MockFixture): + @strawberry_django.ordering.order(models.Fruit) + class Order: + order: Ordering = Ordering.ASC + + # Calling this and no error being raised is the test, as the wrong behavior would + # be for the field to be called like a method + process_order( + cast("WithStrawberryObjectDefinition", Order()), mock.Mock(), mock.Mock() + ) + + +def test_order_object_method(): + @strawberry_django.ordering.order(models.Fruit) + class Order: + @strawberry_django.order_field + def order(self, root, info, prefix, sequence, queryset): + assert self == order_, "Unexpected self passed" + assert root == order_, "Unexpected root passed" + assert info == fake_info, "Unexpected info passed" + assert prefix == "ROOT", "Unexpected prefix passed" + assert sequence == sequence_, "Unexpected sequence passed" + assert queryset == qs, "Unexpected queryset passed" + return queryset, ["name"] + + order_ = cast("WithStrawberryObjectDefinition", Order()) + schema = strawberry.Schema(query=Query) + fake_info: Any = type("FakeInfo", (), {"schema": schema}) + qs: Any = object() + sequence_: Any = {"customOrder": OrderSequence(0)} + + order = process_order(order_, fake_info, qs, prefix="ROOT", sequence=sequence_)[1] + assert "name" in order, "order was not called" + + +def test_order_nulls(query, db, fruits): + t1 = models.FruitType.objects.create(name="Type1") + t2 = models.FruitType.objects.create(name="Type2") + + f1, f2, f3 = models.Fruit.objects.all() + + f2.types.add(t1) + f3.types.add(t1, t2) + + result = query("{ fruits(order: { typesNumber: ASC }) { id } }") + assert not result.errors + assert result.data["fruits"] == [ + {"id": str(f1.id)}, + {"id": str(f2.id)}, + {"id": str(f3.id)}, + ] + + result = query("{ fruits(order: { typesNumber: DESC }) { id } }") + assert not result.errors + assert result.data["fruits"] == [ + {"id": str(f3.id)}, + {"id": str(f2.id)}, + {"id": str(f1.id)}, + ] + + result = query("{ fruits(order: { typesNumber: ASC_NULLS_FIRST }) { id } }") + assert not result.errors + assert result.data["fruits"] == [ + {"id": str(f1.id)}, + {"id": str(f2.id)}, + {"id": str(f3.id)}, + ] + + result = query("{ fruits(order: { typesNumber: ASC_NULLS_LAST }) { id } }") + assert not result.errors + assert result.data["fruits"] == [ + {"id": str(f2.id)}, + {"id": str(f3.id)}, + {"id": str(f1.id)}, + ] + + result = query("{ fruits(order: { typesNumber: DESC_NULLS_LAST }) { id } }") + assert not result.errors + assert result.data["fruits"] == [ + {"id": str(f3.id)}, + {"id": str(f2.id)}, + {"id": str(f1.id)}, + ] + + result = query("{ fruits(order: { typesNumber: DESC_NULLS_FIRST }) { id } }") + assert not result.errors + assert result.data["fruits"] == [ + {"id": str(f1.id)}, + {"id": str(f3.id)}, + {"id": str(f2.id)}, + ] diff --git a/tests/test_ordering.py b/tests/test_ordering.py index 7d953078..4cbb8bac 100644 --- a/tests/test_ordering.py +++ b/tests/test_ordering.py @@ -1,38 +1,30 @@ # ruff: noqa: TRY002, B904, BLE001, F811, PT012 -from typing import Any, Optional, cast -from unittest import mock +from typing import Optional import pytest import strawberry from django.db.models import Case, Count, Value, When -from pytest_mock import MockFixture from strawberry import auto from strawberry.annotation import StrawberryAnnotation -from strawberry.exceptions import MissingArgumentsAnnotationsError -from strawberry.types import get_object_definition from strawberry.types.base import ( StrawberryOptional, - WithStrawberryObjectDefinition, get_object_definition, ) from strawberry.types.field import StrawberryField import strawberry_django -from strawberry_django.exceptions import ( - ForbiddenFieldArgumentError, - MissingFieldArgumentError, -) + from strawberry_django.fields.field import StrawberryDjangoField from strawberry_django.fields.filter_order import ( FilterOrderField, FilterOrderFieldResolver, ) -from strawberry_django.ordering import Ordering, OrderSequence, process_order +from strawberry_django.ordering import Ordering from tests import models, utils from tests.types import Fruit -@strawberry_django.ordering.order(models.Color) +@strawberry_django.ordering.ordering(models.Color) class ColorOrder: pk: auto @@ -41,7 +33,7 @@ def name(self, prefix, value: auto): return [value.resolve(f"{prefix}name")] -@strawberry_django.ordering.order(models.Fruit) +@strawberry_django.ordering.ordering(models.Fruit) class FruitOrder: color_id: auto name: auto @@ -59,7 +51,7 @@ def types_number(self, queryset, prefix, value: auto): ), [value.resolve("count_nulls")] -@strawberry_django.type(models.Fruit, order=FruitOrder) +@strawberry_django.type(models.Fruit, ordering=FruitOrder) class FruitWithOrder: id: auto name: auto @@ -67,7 +59,7 @@ class FruitWithOrder: @strawberry.type class Query: - fruits: list[Fruit] = strawberry_django.field(order=FruitOrder) + fruits: list[Fruit] = strawberry_django.field(ordering=FruitOrder) @pytest.fixture @@ -77,16 +69,16 @@ def query(): def test_field_order_definition(): field = StrawberryDjangoField(type_annotation=StrawberryAnnotation(FruitWithOrder)) - assert field.get_order() == FruitOrder + assert field.get_ordering() == FruitOrder field = StrawberryDjangoField( type_annotation=StrawberryAnnotation(FruitWithOrder), - filters=None, + ordering=None, ) - assert field.get_filters() is None + assert field.get_ordering() is None def test_asc(query, fruits): - result = query("{ fruits(order: { name: ASC }) { id name } }") + result = query("{ fruits(ordering: [{ name: ASC }]) { id name } }") assert not result.errors assert result.data["fruits"] == [ {"id": "3", "name": "banana"}, @@ -96,7 +88,7 @@ def test_asc(query, fruits): def test_desc(query, fruits): - result = query("{ fruits(order: { name: DESC }) { id name } }") + result = query("{ fruits(ordering: [{ name: DESC }]) { id name } }") assert not result.errors assert result.data["fruits"] == [ {"id": "1", "name": "strawberry"}, @@ -105,6 +97,21 @@ def test_desc(query, fruits): ] +def test_multi_order(query, db): + for fruit in ("strawberry", "banana", "raspberry"): + models.Fruit.objects.create(name=fruit, sweetness=7) + + result = query( + "{ fruits(ordering: [{ sweetness: ASC }, { name: ASC }]) { id name sweetness } }" + ) + assert not result.errors + assert result.data["fruits"] == [ + {"id": "2", "name": "banana", "sweetness": 7}, + {"id": "3", "name": "raspberry", "sweetness": 7}, + {"id": "1", "name": "strawberry", "sweetness": 7}, + ] + + def test_relationship(query, fruits): def add_color(fruit, color_name): fruit.color = models.Color.objects.create(name=color_name) @@ -114,7 +121,7 @@ def add_color(fruit, color_name): for fruit, color_name in zip(fruits, color_names): add_color(fruit, color_name) result = query( - "{ fruits(order: { color: { name: DESC } }) { id name color { name } } }", + "{ fruits(ordering: [{ color: { name: DESC } }]) { id name color { name } } }", ) assert not result.errors assert result.data["fruits"] == [ @@ -124,7 +131,7 @@ def add_color(fruit, color_name): ] -def test_arguments_order_respected(query, db): +def test_multi_order_respected(query, db): yellow = models.Color.objects.create(name="yellow") red = models.Color.objects.create(name="red") @@ -144,46 +151,35 @@ def test_arguments_order_respected(query, db): color=red, ) - result = query("{ fruits(order: { name: ASC, sweetness: ASC }) { id } }") + result = query("{ fruits(ordering: [{ name: ASC }, { sweetness: ASC }]) { id } }") assert not result.errors assert result.data["fruits"] == [{"id": str(f.pk)} for f in [f3, f2, f1]] - result = query("{ fruits(order: { sweetness: DESC, name: ASC }) { id } }") + result = query("{ fruits(ordering: [{ sweetness: DESC }, { name: ASC }]) { id } }") assert not result.errors assert result.data["fruits"] == [{"id": str(f.pk)} for f in [f2, f1, f3]] - result = query("{ fruits(order: { color: {name: ASC}, name: ASC }) { id } }") + result = query( + "{ fruits(ordering: [{ color: {name: ASC} }, { name: ASC }]) { id } }" + ) assert not result.errors assert result.data["fruits"] == [{"id": str(f.pk)} for f in [f3, f1, f2]] - result = query("{ fruits(order: { color: {pk: ASC}, name: ASC }) { id } }") + result = query("{ fruits(ordering: [{ color: {pk: ASC} }, { name: ASC }]) { id } }") assert not result.errors assert result.data["fruits"] == [{"id": str(f.pk)} for f in [f2, f3, f1]] - result = query("{ fruits(order: { colorId: ASC, name: ASC }) { id } }") + result = query("{ fruits(ordering: [{ colorId: ASC }, { name: ASC }]) { id } }") assert not result.errors assert result.data["fruits"] == [{"id": str(f.pk)} for f in [f2, f3, f1]] - result = query("{ fruits(order: { name: ASC, colorId: ASC }) { id } }") + result = query("{ fruits(ordering: [{ name: ASC }, { colorId: ASC }]) { id } }") assert not result.errors assert result.data["fruits"] == [{"id": str(f.pk)} for f in [f3, f2, f1]] -def test_order_sequence(): - f1 = StrawberryField(graphql_name="sOmEnAmE", python_name="some_name") - f2 = StrawberryField(python_name="some_name") - - assert OrderSequence.get_graphql_name(None, f1) == "sOmEnAmE" - assert OrderSequence.get_graphql_name(None, f2) == "someName" - - assert OrderSequence.sorted(None, None, fields=[f1, f2]) == [f1, f2] - - sequence = {"someName": OrderSequence(0, None), "sOmEnAmE": OrderSequence(1, None)} - assert OrderSequence.sorted(None, sequence, fields=[f1, f2]) == [f1, f2] - - def test_order_type(): - @strawberry_django.ordering.order(models.Fruit) + @strawberry_django.ordering.ordering(models.Fruit) class FruitOrder: color_id: auto name: auto @@ -216,176 +212,6 @@ def custom_order(self, value: auto, prefix: str): ] -def test_order_field_missing_prefix(): - with pytest.raises( - MissingFieldArgumentError, match=r".*\"prefix\".*\"field_method\".*" - ): - - @strawberry_django.order_field - def field_method(): - pass - - -def test_order_field_missing_value(): - with pytest.raises( - MissingFieldArgumentError, match=r".*\"value\".*\"field_method\".*" - ): - - @strawberry_django.order_field - def field_method(prefix): - pass - - -def test_order_field_missing_value_annotation(): - with pytest.raises( - MissingArgumentsAnnotationsError, - match=r"Missing annotation.*\"value\".*\"field_method\".*", - ): - - @strawberry_django.order_field - def field_method(prefix, value): - pass - - -def test_order_field(): - try: - - @strawberry_django.order_field - def field_method(self, root, info, prefix, value: auto, sequence, queryset): - pass - - except Exception as exc: - raise pytest.fail(f"DID RAISE {exc}") - - -def test_order_field_forbidden_param_annotation(): - with pytest.raises( - MissingArgumentsAnnotationsError, - match=r".*\"forbidden_param\".*\"field_method\".*", - ): - - @strawberry_django.order_field - def field_method(prefix, value: auto, sequence, queryset, forbidden_param): - pass - - -def test_order_field_forbidden_param(): - with pytest.raises( - ForbiddenFieldArgumentError, - match=r".*\"forbidden_param\".*\"field_method\".*", - ): - - @strawberry_django.order_field - def field_method(prefix, value: auto, sequence, queryset, forbidden_param: str): - pass - - -def test_order_field_missing_queryset(): - with pytest.raises(MissingFieldArgumentError, match=r".*\"queryset\".*\"order\".*"): - - @strawberry_django.order_field - def order(prefix): - pass - - -def test_order_field_value_forbidden_on_object(): - with pytest.raises(ForbiddenFieldArgumentError, match=r".*\"value\".*\"order\".*"): - - @strawberry_django.order_field - def field_method(prefix, queryset, value: auto): - pass - - @strawberry_django.order_field - def order(prefix, queryset, value: auto): - pass - - -def test_order_field_on_object(): - try: - - @strawberry_django.order_field - def order(self, root, info, prefix, sequence, queryset): - pass - - except Exception as exc: - raise pytest.fail(f"DID RAISE {exc}") - - -def test_order_field_method(): - @strawberry_django.ordering.order(models.Fruit) - class Order: - @strawberry_django.order_field - def custom_order(self, root, info, prefix, value: auto, sequence, queryset): - assert self == order, "Unexpected self passed" - assert root == order, "Unexpected root passed" - assert info == fake_info, "Unexpected info passed" - assert prefix == "ROOT", "Unexpected prefix passed" - assert value == Ordering.ASC, "Unexpected value passed" - assert sequence == sequence_inner, "Unexpected sequence passed" - assert queryset == qs, "Unexpected queryset passed" - raise Exception("WAS CALLED") - - order = cast("WithStrawberryObjectDefinition", Order(custom_order=Ordering.ASC)) # type: ignore - schema = strawberry.Schema(query=Query) - fake_info: Any = type("FakeInfo", (), {"schema": schema}) - qs: Any = object() - sequence_inner: Any = object() - sequence = {"customOrder": OrderSequence(0, children=sequence_inner)} - - with pytest.raises(Exception, match="WAS CALLED"): - process_order(order, fake_info, qs, prefix="ROOT", sequence=sequence) - - -def test_order_method_not_called_when_not_decorated(mocker: MockFixture): - @strawberry_django.ordering.order(models.Fruit) - class Order: - def order(self, root, info, prefix, value: auto, sequence, queryset): - pytest.fail("Should not have been called") - - mock_order_method = mocker.spy(Order, "order") - - process_order( - cast("WithStrawberryObjectDefinition", Order()), mock.Mock(), mock.Mock() - ) - - mock_order_method.assert_not_called() - - -def test_order_field_not_called(mocker: MockFixture): - @strawberry_django.ordering.order(models.Fruit) - class Order: - order: Ordering = Ordering.ASC - - # Calling this and no error being raised is the test, as the wrong behavior would - # be for the field to be called like a method - process_order( - cast("WithStrawberryObjectDefinition", Order()), mock.Mock(), mock.Mock() - ) - - -def test_order_object_method(): - @strawberry_django.ordering.order(models.Fruit) - class Order: - @strawberry_django.order_field - def order(self, root, info, prefix, sequence, queryset): - assert self == order_, "Unexpected self passed" - assert root == order_, "Unexpected root passed" - assert info == fake_info, "Unexpected info passed" - assert prefix == "ROOT", "Unexpected prefix passed" - assert sequence == sequence_, "Unexpected sequence passed" - assert queryset == qs, "Unexpected queryset passed" - return queryset, ["name"] - - order_ = cast("WithStrawberryObjectDefinition", Order()) - schema = strawberry.Schema(query=Query) - fake_info: Any = type("FakeInfo", (), {"schema": schema}) - qs: Any = object() - sequence_: Any = {"customOrder": OrderSequence(0)} - - order = process_order(order_, fake_info, qs, prefix="ROOT", sequence=sequence_)[1] - assert "name" in order, "order was not called" - - def test_order_nulls(query, db, fruits): t1 = models.FruitType.objects.create(name="Type1") t2 = models.FruitType.objects.create(name="Type2") @@ -395,7 +221,7 @@ def test_order_nulls(query, db, fruits): f2.types.add(t1) f3.types.add(t1, t2) - result = query("{ fruits(order: { typesNumber: ASC }) { id } }") + result = query("{ fruits(ordering: [{ typesNumber: ASC }]) { id } }") assert not result.errors assert result.data["fruits"] == [ {"id": str(f1.id)}, @@ -403,7 +229,7 @@ def test_order_nulls(query, db, fruits): {"id": str(f3.id)}, ] - result = query("{ fruits(order: { typesNumber: DESC }) { id } }") + result = query("{ fruits(ordering: [{ typesNumber: DESC }]) { id } }") assert not result.errors assert result.data["fruits"] == [ {"id": str(f3.id)}, @@ -411,7 +237,7 @@ def test_order_nulls(query, db, fruits): {"id": str(f1.id)}, ] - result = query("{ fruits(order: { typesNumber: ASC_NULLS_FIRST }) { id } }") + result = query("{ fruits(ordering: [{ typesNumber: ASC_NULLS_FIRST }]) { id } }") assert not result.errors assert result.data["fruits"] == [ {"id": str(f1.id)}, @@ -419,7 +245,7 @@ def test_order_nulls(query, db, fruits): {"id": str(f3.id)}, ] - result = query("{ fruits(order: { typesNumber: ASC_NULLS_LAST }) { id } }") + result = query("{ fruits(ordering: [{ typesNumber: ASC_NULLS_LAST }]) { id } }") assert not result.errors assert result.data["fruits"] == [ {"id": str(f2.id)}, @@ -427,7 +253,7 @@ def test_order_nulls(query, db, fruits): {"id": str(f1.id)}, ] - result = query("{ fruits(order: { typesNumber: DESC_NULLS_LAST }) { id } }") + result = query("{ fruits(ordering: [{ typesNumber: DESC_NULLS_LAST }]) { id } }") assert not result.errors assert result.data["fruits"] == [ {"id": str(f3.id)}, @@ -435,7 +261,7 @@ def test_order_nulls(query, db, fruits): {"id": str(f1.id)}, ] - result = query("{ fruits(order: { typesNumber: DESC_NULLS_FIRST }) { id } }") + result = query("{ fruits(ordering: [{ typesNumber: DESC_NULLS_FIRST }]) { id } }") assert not result.errors assert result.data["fruits"] == [ {"id": str(f1.id)}, diff --git a/tests/test_settings.py b/tests/test_settings.py index 3c7cf1d6..0217a7a8 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -31,6 +31,7 @@ def test_non_defaults(): DEFAULT_PK_FIELD_NAME="id", USE_DEPRECATED_FILTERS=True, PAGINATION_DEFAULT_LIMIT=250, + ORDERING_DEFAULT_ONE_OF=True, ), ): assert ( @@ -45,5 +46,6 @@ def test_non_defaults(): DEFAULT_PK_FIELD_NAME="id", USE_DEPRECATED_FILTERS=True, PAGINATION_DEFAULT_LIMIT=250, + ORDERING_DEFAULT_ONE_OF=True, ) ) diff --git a/tests/types.py b/tests/types.py index 42648a36..68d4eff9 100644 --- a/tests/types.py +++ b/tests/types.py @@ -15,6 +15,7 @@ class Fruit: color: Color | None types: list[FruitType] picture: auto + sweetness: auto @strawberry_django.type(models.Color) From cb815bd97c66bfd07ee7732b562ec900c68f7ca5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 22 Dec 2024 20:13:36 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- strawberry_django/ordering.py | 9 ++------- tests/test_ordering.py | 1 - 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/strawberry_django/ordering.py b/strawberry_django/ordering.py index 5602bec6..c9fd72de 100644 --- a/strawberry_django/ordering.py +++ b/strawberry_django/ordering.py @@ -2,24 +2,20 @@ import dataclasses import enum -import functools from typing import ( TYPE_CHECKING, Callable, Optional, TypeVar, cast, - Any, - Mapping, ) import strawberry from django.db.models import F, OrderBy, QuerySet -from graphql import VariableNode from graphql.language.ast import ObjectValueNode from strawberry import UNSET from strawberry.types import has_object_definition -from strawberry.types.base import WithStrawberryObjectDefinition, StrawberryOptional +from strawberry.types.base import StrawberryOptional, WithStrawberryObjectDefinition from strawberry.types.field import StrawberryField, field from strawberry.types.unset import UnsetType from strawberry.utils.str_converters import to_camel_case @@ -260,8 +256,7 @@ def process_ordering( order_method := getattr(ordering_cls, "process_ordering", None), ): return order_method(order, info, queryset=queryset, prefix=prefix) - else: - return process_ordering_default(ordering, info, queryset, prefix) + return process_ordering_default(ordering, info, queryset, prefix) def apply_ordering( diff --git a/tests/test_ordering.py b/tests/test_ordering.py index 4cbb8bac..b7bf7353 100644 --- a/tests/test_ordering.py +++ b/tests/test_ordering.py @@ -13,7 +13,6 @@ from strawberry.types.field import StrawberryField import strawberry_django - from strawberry_django.fields.field import StrawberryDjangoField from strawberry_django.fields.filter_order import ( FilterOrderField,