diff --git a/CHANGES.rst b/CHANGES.rst index ad37fed5..38427e90 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,6 +4,7 @@ Changelog To be released ------------------ - Add support for `Python 3.13` (GH-#628) +- Add support for `prefetch_related` to `InheritanceManager` 5.0.0 (2024-09-01) ------------------ diff --git a/docs/managers.rst b/docs/managers.rst index c5bb2ce5..58b26a14 100644 --- a/docs/managers.rst +++ b/docs/managers.rst @@ -84,6 +84,14 @@ If you don't explicitly call ``select_subclasses()`` or ``get_subclass()``, an ``InheritanceManager`` behaves identically to a normal ``Manager``; so it's safe to use as your default manager for the model. +``InheritanceManager`` supports ``prefetch_related``, even in subclasses: + +.. code-block:: python + + places = Place.objects.select_subclasses().prefetch_related('bar__manager') + # every Bar instance in places will have it's manager relation prefetched + + .. _contributed by Jeff Elmore: https://jeffelmore.org/2010/11/11/automatic-downcasting-of-inherited-models-in-django/ JoinQueryset diff --git a/model_utils/managers.py b/model_utils/managers.py index 4cb1ffcd..9ef361c7 100644 --- a/model_utils/managers.py +++ b/model_utils/managers.py @@ -5,9 +5,15 @@ from django.core.exceptions import ObjectDoesNotExist from django.db import connection, models +from django.db.models import Q from django.db.models.constants import LOOKUP_SEP from django.db.models.fields.related import OneToOneField, OneToOneRel -from django.db.models.query import ModelIterable, QuerySet +from django.db.models.query import ( + ModelIterable, + Prefetch, + QuerySet, + prefetch_related_objects, +) from django.db.models.sql.datastructures import Join ModelT = TypeVar('ModelT', bound=models.Model, covariant=True) @@ -65,6 +71,8 @@ def __iter__(self): class InheritanceQuerySetMixin(Generic[ModelT]): + _prefetch_related_lookups: Sequence[str | Prefetch] + _result_cache: list[ModelT] model: type[ModelT] subclasses: Sequence[str] @@ -105,6 +113,79 @@ def select_subclasses(self, *subclasses: str | type[models.Model]) -> Inheritanc new_qs.subclasses = selected_subclasses return new_qs + def _prefetch_related_objects(self): + # Step 1: Find the base objects + # self._result_cache contains the subclasses as returned by InheritanceIterable + # walk up the path_to_parent to get to the parent model for each + _base_objs = [] + sub_obj: ModelT + for sub_obj in self._result_cache: + for p in sub_obj._meta.get_path_to_parent(self.model): + sub_obj = getattr(sub_obj, p.join_field.name) + _base_objs.append(sub_obj) + + # Step 2: Prefetch using the base objects + # This satisfies the requirement of prefetch_related_objects that the list be homogeneous + # This allows the user to use prefetch_related(subclass__subclass_relation) + # Because InheritanceIterable transforms the result into "subclass", then subclass_relation will + # be prefetched on that subclass object, as expected + prefetch_related_objects(_base_objs, *self._prefetch_related_lookups) + + # Step 3: Copy down the prefetched objects + # Assuming we have the inheritance C extends B extends A + # If a relation is prefetched at B, we must put those prefetched objects into the C's + # _prefetched_objects_cache as well, so that when a C object is returned (which is obtained + # by InheritanceIterable via base_obj.b.c) and the user does sub_obj.m2m_field.all() + # then Django's ManyRelatedManager will look into base_obj.b.c._prefetched_objects_cache + # but prefetch_related_objects above has put the prefetched objects into base_obj.b._prefetched_objects_cache + # The same goes for _state.fields_cache, which is used by ForeignKeys + # ForeignKeys already make an attempt to look at the parent's fields_cache, but it only works for one level + # Additionally, copy any to_attr prefetches down as well + for sub_obj, base_obj in zip(self._result_cache, _base_objs): + # get the base caches or create a new a blank one if there isn't any + prefetch_cache = getattr(base_obj, '_prefetched_objects_cache', None) + if prefetch_cache is not None: + prefetch_cache = dict(prefetch_cache) + else: + prefetch_cache = {} + fields_cache = dict(base_obj._state.fields_cache) + + current = base_obj + current_path = [] + prefetch_attrs = {} + for p in sub_obj._meta.get_path_from_parent(self.model): + join_field_name = p.join_field.name + current_path.append(join_field_name) + current = getattr(current, join_field_name) + child_cache: dict | None = getattr(current, '_prefetched_objects_cache', None) + if child_cache is not None: + # The child already has its own cache, add it to the running list of prefetches + prefetch_cache.update(child_cache) + if prefetch_cache: + # If we have something prefetched at this level or above, put it in this sub_obj + current._prefetched_objects_cache = prefetch_cache + # prepare a fresh dict for the next level down + prefetch_cache = dict(prefetch_cache) + + child_fields_cache = current._state.fields_cache + if child_fields_cache: + fields_cache.update(child_fields_cache) + if fields_cache: + current._state.fields_cache = fields_cache + fields_cache = dict(fields_cache) + + for prefetch in self._prefetch_related_lookups: + if isinstance(prefetch, Prefetch) and prefetch.to_attr: + prefetch_path = prefetch.prefetch_to.split(LOOKUP_SEP)[:-1] + if current_path == prefetch_path: + # The prefetch was at this level exactly, get the prefetch from the object + prefetch_attrs[prefetch] = getattr(current, prefetch.to_attr) + elif current_path[:len(prefetch_path)] == prefetch_path: + # the prefetch was for a parent of this one, get it from the running cache + setattr(current, prefetch.to_attr, prefetch_attrs[prefetch]) + + self._prefetch_done = True + def _chain(self, **kwargs: object) -> InheritanceQuerySet[ModelT]: update = {} for name in ['subclasses', '_annotated']: @@ -165,18 +246,10 @@ def _get_ancestors_path(self, model: type[models.Model]) -> str: raise ValueError( f"{model!r} is not a subclass of {self.model!r}") - ancestry: list[str] = [] - # should be a OneToOneField or None - parent_link = model._meta.get_ancestor_link(self.model) - - while parent_link is not None: - related = parent_link.remote_field - ancestry.insert(0, related.get_accessor_name()) - - parent_model = related.model - parent_link = parent_model._meta.get_ancestor_link(self.model) - - return LOOKUP_SEP.join(ancestry) + return LOOKUP_SEP.join( + p.join_field.get_accessor_name() + for p in model._meta.get_path_from_parent(self.model) + ) def _get_sub_obj_recurse(self, obj: models.Model, s: str) -> ModelT | None: rel, _, s = s.partition(LOOKUP_SEP) @@ -212,18 +285,18 @@ def instance_of(self, *models: type[ModelT]) -> InheritanceQuerySet[ModelT]: # Due to https://code.djangoproject.com/ticket/16572, we # can't really do this for anything other than children (ie, # no grandchildren+). - where_queries = [] + conditions = [] for model in models: - where_queries.append('(' + ' AND '.join([ - '"{}"."{}" IS NOT NULL'.format( - model._meta.db_table, - field.column, - ) for field in model._meta.parents.values() - ]) + ')') + path_from_parent = LOOKUP_SEP.join( + p.join_field.get_accessor_name() for p in model._meta.get_path_from_parent(self.model) + ) + conditions.append( + (path_from_parent + LOOKUP_SEP + 'isnull', False) + ) return cast( 'InheritanceQuerySet[ModelT]', - self.select_subclasses(*models).extra(where=[' OR '.join(where_queries)]) + self.select_subclasses(*models).filter(Q(*conditions, _connector=Q.OR)) ) diff --git a/tests/models.py b/tests/models.py index 4d345050..3385e59c 100644 --- a/tests/models.py +++ b/tests/models.py @@ -44,6 +44,8 @@ class InheritanceManagerTestParent(models.Model): related_self = models.OneToOneField( "self", related_name="imtests_self", null=True, on_delete=models.CASCADE) + normal_relation_parent = models.ForeignKey('InheritanceManagerNonChild', null=True, on_delete=models.CASCADE) + normal_many_relation_parent = models.ManyToManyField('InheritanceManagerNonChild') objects: ClassVar[InheritanceManager[InheritanceManagerTestParent]] = InheritanceManager() def __str__(self) -> str: @@ -56,6 +58,8 @@ def __str__(self) -> str: class InheritanceManagerTestChild1(InheritanceManagerTestParent): non_related_field_using_descriptor_2 = models.FileField(upload_to="test") normal_field_2 = models.TextField() + normal_relation = models.ForeignKey('InheritanceManagerNonChild', null=True, on_delete=models.CASCADE) + normal_many_relation = models.ManyToManyField('InheritanceManagerNonChild') objects: ClassVar[InheritanceManager[InheritanceManagerTestParent]] = InheritanceManager() @@ -95,6 +99,10 @@ class InheritanceManagerTestChild4(InheritanceManagerTestParent): parent_link=True, on_delete=models.CASCADE) +class InheritanceManagerNonChild(models.Model): + name = models.CharField(max_length=255) + + class TimeStamp(TimeStampedModel): test_field = models.PositiveSmallIntegerField(default=0) diff --git a/tests/test_managers/test_inheritance_manager.py b/tests/test_managers/test_inheritance_manager.py index 68e8a743..2e301b64 100644 --- a/tests/test_managers/test_inheritance_manager.py +++ b/tests/test_managers/test_inheritance_manager.py @@ -3,10 +3,12 @@ from typing import TYPE_CHECKING from django.db import models +from django.db.models import Prefetch from django.test import TestCase from model_utils.managers import InheritanceManager from tests.models import ( + InheritanceManagerNonChild, InheritanceManagerTestChild1, InheritanceManagerTestChild2, InheritanceManagerTestChild3, @@ -541,3 +543,144 @@ def test_annotate_with_named_arguments_before_select_subclasses(self) -> None: def test_clone_when_inheritance_queryset_selects_subclasses_should_clone_them_too(self) -> None: qs = InheritanceManagerTestParent.objects.select_subclasses() self.assertEqual(qs.subclasses, qs._clone().subclasses) + + +class InheritanceManagerPrefetchForeignKeyTests(TestCase): + + def setUp(self) -> None: + self.related1 = InheritanceManagerNonChild.objects.create() + self.related2 = InheritanceManagerNonChild.objects.create() + self.related3 = InheritanceManagerNonChild.objects.create() + self.related4 = InheritanceManagerNonChild.objects.create() + self.related5 = InheritanceManagerNonChild.objects.create() + self.c1 = InheritanceManagerTestChild1.objects.create( + normal_relation=self.related1, + normal_relation_parent=self.related1, + ) + self.c1.normal_many_relation.set([self.related1, self.related2]) + self.c1.normal_many_relation_parent.set([self.related1, self.related2]) + + self.c2 = InheritanceManagerTestChild1.objects.create( + normal_relation=self.related2, + normal_relation_parent=self.related2, + ) + self.c2.normal_many_relation.set([self.related3, self.related4]) + self.c2.normal_many_relation_parent.set([self.related3, self.related4]) + + self.gc1 = InheritanceManagerTestGrandChild1.objects.create( + normal_relation=self.related1, + normal_relation_parent=self.related1, + ) + self.gc1.normal_many_relation.set([self.related1, self.related2]) + self.gc1.normal_many_relation_parent.set([self.related1, self.related2]) + + self.gc2 = InheritanceManagerTestGrandChild1.objects.create( + normal_relation=self.related2, + normal_relation_parent=self.related2, + ) + self.gc2.normal_many_relation.set([self.related3, self.related4]) + self.gc2.normal_many_relation_parent.set([self.related3, self.related4]) + + self.gc3 = InheritanceManagerTestGrandChild1_2.objects.create( + normal_relation=self.related1, + normal_relation_parent=self.related1, + ) + self.gc3.normal_many_relation.set([self.related1, self.related2]) + self.gc3.normal_many_relation_parent.set([self.related1, self.related2]) + + self.gc4 = InheritanceManagerTestGrandChild1.objects.create( + normal_relation=self.related3, + normal_relation_parent=self.related3, + ) + self.gc4.normal_many_relation.set([self.related5, self.related1]) + self.gc4.normal_many_relation_parent.set([self.related5, self.related1]) + + self.c3 = InheritanceManagerTestChild2.objects.create() + + def test_prefetch_related_works_with_fk_in_parent(self) -> None: + with self.assertNumQueries(2): + result = list( + InheritanceManagerTestParent.objects.select_subclasses().prefetch_related( + 'normal_relation_parent' + ).order_by('pk') + ) + self.assertEqual(result[0], self.c1) + self.assertEqual(result[0].normal_relation_parent.pk, self.related1.pk) + self.assertEqual(result[1], self.c2) + self.assertEqual(result[1].normal_relation_parent.pk, self.related2.pk) + self.assertEqual(result[2], self.gc1) + self.assertEqual(result[2].normal_relation_parent.pk, self.related1.pk) + self.assertEqual(result[3], self.gc2) + self.assertEqual(result[3].normal_relation_parent.pk, self.related2.pk) + self.assertEqual(result[4], self.gc3) + self.assertEqual(result[4].normal_relation_parent.pk, self.related1.pk) + self.assertEqual(result[5], self.gc4) + self.assertEqual(result[5].normal_relation_parent.pk, self.related3.pk) + self.assertEqual(result[6], self.c3) + + def test_prefetch_related_works_with_m2m_in_parent(self) -> None: + with self.assertNumQueries(2): + result = list( + InheritanceManagerTestParent.objects.select_subclasses().prefetch_related( + 'normal_many_relation_parent', + ).order_by('pk') + ) + + self.assertEqual(set(result[0].normal_many_relation_parent.all()), {self.related1, self.related2}) + self.assertEqual(set(result[1].normal_many_relation_parent.all()), {self.related3, self.related4}) + self.assertEqual(set(result[2].normal_many_relation_parent.all()), {self.related1, self.related2}) + self.assertEqual(set(result[3].normal_many_relation_parent.all()), {self.related3, self.related4}) + self.assertEqual(set(result[4].normal_many_relation_parent.all()), {self.related1, self.related2}) + self.assertEqual(set(result[5].normal_many_relation_parent.all()), {self.related5, self.related1}) + + def test_prefetch_related_works_with_fk_in_subclass(self) -> None: + with self.assertNumQueries(2): + result = list( + InheritanceManagerTestParent.objects.select_subclasses().prefetch_related( + 'inheritancemanagertestchild1__normal_relation' + ).order_by('pk') + ) + self.assertEqual(result[0], self.c1) + self.assertEqual(result[0].normal_relation.pk, self.related1.pk) + self.assertEqual(result[1], self.c2) + self.assertEqual(result[1].normal_relation.pk, self.related2.pk) + self.assertEqual(result[2], self.gc1) + self.assertEqual(result[2].normal_relation.pk, self.related1.pk) + self.assertEqual(result[3], self.gc2) + self.assertEqual(result[3].normal_relation.pk, self.related2.pk) + self.assertEqual(result[4], self.gc3) + self.assertEqual(result[4].normal_relation.pk, self.related1.pk) + self.assertEqual(result[5], self.gc4) + self.assertEqual(result[5].normal_relation.pk, self.related3.pk) + self.assertEqual(result[6], self.c3) + + def test_prefetch_related_works_with_m2m_in_subclass(self) -> None: + with self.assertNumQueries(2): + result = list( + InheritanceManagerTestParent.objects.select_subclasses().prefetch_related( + 'inheritancemanagertestchild1__normal_many_relation', + ).order_by('pk') + ) + + self.assertEqual(set(result[0].normal_many_relation.all()), {self.related1, self.related2}) + self.assertEqual(set(result[1].normal_many_relation.all()), {self.related3, self.related4}) + self.assertEqual(set(result[2].normal_many_relation.all()), {self.related1, self.related2}) + self.assertEqual(set(result[3].normal_many_relation.all()), {self.related3, self.related4}) + self.assertEqual(set(result[4].normal_many_relation.all()), {self.related1, self.related2}) + self.assertEqual(set(result[5].normal_many_relation.all()), {self.related5, self.related1}) + + def test_prefetch_related_works_with_m2m_to_attr(self) -> None: + with self.assertNumQueries(2): + result = list( + InheritanceManagerTestParent.objects.select_subclasses().prefetch_related( + Prefetch('inheritancemanagertestchild1__normal_many_relation', to_attr='prefetched_many_relation') + ).order_by('pk') + ) + + self.assertEqual(set(result[0].prefetched_many_relation), {self.related1, self.related2}) + self.assertEqual(set(result[1].prefetched_many_relation), {self.related3, self.related4}) + self.assertEqual(set(result[2].prefetched_many_relation), {self.related1, self.related2}) + self.assertEqual(set(result[3].prefetched_many_relation), {self.related3, self.related4}) + self.assertEqual(set(result[4].prefetched_many_relation), {self.related1, self.related2}) + self.assertEqual(set(result[5].prefetched_many_relation), {self.related5, self.related1}) + self.assertFalse(hasattr(result[6], 'prefetched_many_relation'))