diff --git a/strawberry_django/optimizer.py b/strawberry_django/optimizer.py index 467206d0..a8ec95a7 100644 --- a/strawberry_django/optimizer.py +++ b/strawberry_django/optimizer.py @@ -2,6 +2,7 @@ import contextlib import contextvars +import copy import dataclasses import itertools from collections import defaultdict @@ -196,8 +197,10 @@ def with_prefix(self, prefix: str, *, info: GraphQLResolveInfo): if isinstance(p, str): prefetch_related.append(f"{prefix}{LOOKUP_SEP}{p}") elif isinstance(p, Prefetch): - p.add_prefix(prefix) - prefetch_related.append(p) + # add_prefix modifies the field's prefetch object, so we copy it before + p_copy = copy.copy(p) + p_copy.add_prefix(prefix) + prefetch_related.append(p_copy) else: # pragma:nocover assert_never(p) diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 5c8d504b..3e5846b9 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -2,11 +2,15 @@ from typing import Any, List, cast import pytest +import strawberry +from django.db.models import Prefetch from django.utils import timezone from strawberry.relay import to_base64 +import strawberry_django from strawberry_django.optimizer import DjangoOptimizerExtension +from . import utils from .projects.faker import ( IssueFactory, MilestoneFactory, @@ -15,7 +19,7 @@ TagFactory, UserFactory, ) -from .projects.models import Assignee, Issue +from .projects.models import Assignee, Issue, Milestone, Project from .utils import GraphQLTestClient, assert_num_queries @@ -816,3 +820,72 @@ def test_query_annotate_with_callable(db, gql_client: GraphQLTestClient): asserts_errors=False, ) assert res.errors + + +@pytest.mark.django_db(transaction=True) +def test_user_query_with_prefetch(): + @strawberry_django.type( + Project, + ) + class ProjectTypeWithPrefetch: + @strawberry_django.field( + prefetch_related=[ + Prefetch( + "milestones", + queryset=Milestone.objects.all(), + to_attr="prefetched_milestones", + ), + ], + ) + def custom_field(self, info) -> str: + if hasattr(self, "prefetched_milestones"): + return "prefetched" + return "not prefetched" + + @strawberry_django.type( + Milestone, + ) + class MilestoneTypeWithNestedPrefetch: + project: ProjectTypeWithPrefetch + + MilestoneFactory.create() + + @strawberry.type + class Query: + milestones: List[MilestoneTypeWithNestedPrefetch] = strawberry_django.field() + + query = utils.generate_query(Query, enable_optimizer=True) + query_str = """ + query TestQuery { + milestones { + project { + customField + } + } + } + """ + assert DjangoOptimizerExtension.enabled.get() + result = query(query_str) + + assert not result.errors + assert result.data == { + "milestones": [ + { + "project": { + "customField": "prefetched", + }, + }, + ], + } + + result2 = query(query_str) + assert not result2.errors + assert result2.data == { + "milestones": [ + { + "project": { + "customField": "prefetched", + }, + }, + ], + } diff --git a/tests/utils.py b/tests/utils.py index 2a4c4128..b054eceb 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -15,6 +15,7 @@ from strawberry.test.client import Response from strawberry.utils.inspect import in_async_context +from strawberry_django.optimizer import DjangoOptimizerExtension from strawberry_django.test.client import TestClient _client: contextvars.ContextVar["GraphQLTestClient"] = contextvars.ContextVar( @@ -22,7 +23,7 @@ ) -def generate_query(query=None, mutation=None): +def generate_query(query=None, mutation=None, enable_optimizer=False): append_mutation = mutation and not query if query is None: @@ -31,7 +32,11 @@ class Query: x: int query = Query - schema = strawberry.Schema(query=query, mutation=mutation) + extensions = [] + + if enable_optimizer: + extensions = [DjangoOptimizerExtension()] + schema = strawberry.Schema(query=query, mutation=mutation, extensions=extensions) def process_result(result): return result