From 06f62c74a37fc20d3122e7528add8e6c6119e591 Mon Sep 17 00:00:00 2001 From: Thiago Bellini Ribeiro Date: Tue, 7 Jan 2025 19:01:14 +0100 Subject: [PATCH] fix(mutations): Make sure we skip refetch when the optimizer is disabled --- strawberry_django/mutations/fields.py | 2 +- tests/test_input_mutations.py | 10 +++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/strawberry_django/mutations/fields.py b/strawberry_django/mutations/fields.py index a27ebaba..e1d3ee5e 100644 --- a/strawberry_django/mutations/fields.py +++ b/strawberry_django/mutations/fields.py @@ -235,7 +235,7 @@ def arguments(self, value: list[StrawberryArgument]): return args_prop.fset(self, value) # type: ignore def refetch(self, resolved: _T, *, info: Info | None) -> _T: - if not DjangoOptimizerExtension.enabled or info is None: + if not DjangoOptimizerExtension.enabled.get() or info is None: return resolved if isinstance(resolved, list) and resolved: diff --git a/tests/test_input_mutations.py b/tests/test_input_mutations.py index 68fa69ae..8c657c27 100644 --- a/tests/test_input_mutations.py +++ b/tests/test_input_mutations.py @@ -4,6 +4,7 @@ from django.core.exceptions import ValidationError from strawberry.relay import from_base64, to_base64 +from strawberry_django.optimizer import DjangoOptimizerExtension from tests.utils import GraphQLTestClient, assert_num_queries from .projects.faker import ( @@ -1028,7 +1029,7 @@ def test_input_nested_update_mutation(db, gql_client: GraphQLTestClient): @pytest.mark.django_db(transaction=True) def test_input_update_m2m_set_not_null_mutation(db, gql_client: GraphQLTestClient): query = """ - mutation UpdateProject ($input: ProjectInputPartial!) { + mutation UpdateProject ($input: ProjectInputPartial!, $optimizerEnabled: Boolean!) { updateProject (input: $input) { __typename ... on OperationInfo { @@ -1042,7 +1043,7 @@ def test_input_update_m2m_set_not_null_mutation(db, gql_client: GraphQLTestClien id name dueDate - isDelayed + isDelayed @include(if: $optimizerEnabled) milestones { id name @@ -1059,7 +1060,9 @@ def test_input_update_m2m_set_not_null_mutation(db, gql_client: GraphQLTestClien milestone_1_id = to_base64("MilestoneType", milestone_1.pk) MilestoneFactory.create(project=project) - with assert_num_queries(14): + # For mutations, having the optimizer enabled is expected to generate one extra + # query for the refetch of the object + with assert_num_queries(14 if DjangoOptimizerExtension.enabled.get() else 13): res = gql_client.query( query, { @@ -1067,6 +1070,7 @@ def test_input_update_m2m_set_not_null_mutation(db, gql_client: GraphQLTestClien "id": to_base64("ProjectType", project.pk), "milestones": [{"id": milestone_1_id}], }, + "optimizerEnabled": DjangoOptimizerExtension.enabled.get(), }, )