Skip to content

Add support for prefetch_related to InheritanceManager #639

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
Open
1 change: 1 addition & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Changelog
To be released
------------------
- Add support for `Python 3.13` (GH-#628)
- Add support for `prefetch_related` to `InheritanceManager`

5.0.0 (2024-09-01)
------------------
Expand Down
8 changes: 8 additions & 0 deletions docs/managers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,14 @@ If you don't explicitly call ``select_subclasses()`` or ``get_subclass()``,
an ``InheritanceManager`` behaves identically to a normal ``Manager``; so
it's safe to use as your default manager for the model.

``InheritanceManager`` supports ``prefetch_related``, even in subclasses:

.. code-block:: python

places = Place.objects.select_subclasses().prefetch_related('bar__manager')
# every Bar instance in places will have it's manager relation prefetched


.. _contributed by Jeff Elmore: https://jeffelmore.org/2010/11/11/automatic-downcasting-of-inherited-models-in-django/

JoinQueryset
Expand Down
115 changes: 94 additions & 21 deletions model_utils/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,15 @@

from django.core.exceptions import ObjectDoesNotExist
from django.db import connection, models
from django.db.models import Q
from django.db.models.constants import LOOKUP_SEP
from django.db.models.fields.related import OneToOneField, OneToOneRel
from django.db.models.query import ModelIterable, QuerySet
from django.db.models.query import (
ModelIterable,
Prefetch,
QuerySet,
prefetch_related_objects,
)
from django.db.models.sql.datastructures import Join

ModelT = TypeVar('ModelT', bound=models.Model, covariant=True)
Expand Down Expand Up @@ -65,6 +71,8 @@ def __iter__(self):

class InheritanceQuerySetMixin(Generic[ModelT]):

_prefetch_related_lookups: Sequence[str | Prefetch]
_result_cache: list[ModelT]
model: type[ModelT]
subclasses: Sequence[str]

Expand Down Expand Up @@ -105,6 +113,79 @@ def select_subclasses(self, *subclasses: str | type[models.Model]) -> Inheritanc
new_qs.subclasses = selected_subclasses
return new_qs

def _prefetch_related_objects(self):
# Step 1: Find the base objects
# self._result_cache contains the subclasses as returned by InheritanceIterable
# walk up the path_to_parent to get to the parent model for each
_base_objs = []
sub_obj: ModelT
for sub_obj in self._result_cache:
for p in sub_obj._meta.get_path_to_parent(self.model):
sub_obj = getattr(sub_obj, p.join_field.name)
_base_objs.append(sub_obj)

# Step 2: Prefetch using the base objects
# This satisfies the requirement of prefetch_related_objects that the list be homogeneous
# This allows the user to use prefetch_related(subclass__subclass_relation)
# Because InheritanceIterable transforms the result into "subclass", then subclass_relation will
# be prefetched on that subclass object, as expected
prefetch_related_objects(_base_objs, *self._prefetch_related_lookups)

# Step 3: Copy down the prefetched objects
# Assuming we have the inheritance C extends B extends A
# If a relation is prefetched at B, we must put those prefetched objects into the C's
# _prefetched_objects_cache as well, so that when a C object is returned (which is obtained
# by InheritanceIterable via base_obj.b.c) and the user does sub_obj.m2m_field.all()
# then Django's ManyRelatedManager will look into base_obj.b.c._prefetched_objects_cache
# but prefetch_related_objects above has put the prefetched objects into base_obj.b._prefetched_objects_cache
# The same goes for _state.fields_cache, which is used by ForeignKeys
# ForeignKeys already make an attempt to look at the parent's fields_cache, but it only works for one level
# Additionally, copy any to_attr prefetches down as well
for sub_obj, base_obj in zip(self._result_cache, _base_objs):
# get the base caches or create a new a blank one if there isn't any
prefetch_cache = getattr(base_obj, '_prefetched_objects_cache', None)
if prefetch_cache is not None:
prefetch_cache = dict(prefetch_cache)
else:
prefetch_cache = {}
fields_cache = dict(base_obj._state.fields_cache)

current = base_obj
current_path = []
prefetch_attrs = {}
for p in sub_obj._meta.get_path_from_parent(self.model):
join_field_name = p.join_field.name
current_path.append(join_field_name)
current = getattr(current, join_field_name)
child_cache: dict | None = getattr(current, '_prefetched_objects_cache', None)
if child_cache is not None:
# The child already has its own cache, add it to the running list of prefetches
prefetch_cache.update(child_cache)
if prefetch_cache:
# If we have something prefetched at this level or above, put it in this sub_obj
current._prefetched_objects_cache = prefetch_cache
# prepare a fresh dict for the next level down
prefetch_cache = dict(prefetch_cache)

child_fields_cache = current._state.fields_cache
if child_fields_cache:
fields_cache.update(child_fields_cache)
if fields_cache:
current._state.fields_cache = fields_cache
fields_cache = dict(fields_cache)

for prefetch in self._prefetch_related_lookups:
if isinstance(prefetch, Prefetch) and prefetch.to_attr:
prefetch_path = prefetch.prefetch_to.split(LOOKUP_SEP)[:-1]
if current_path == prefetch_path:
# The prefetch was at this level exactly, get the prefetch from the object
prefetch_attrs[prefetch] = getattr(current, prefetch.to_attr)
elif current_path[:len(prefetch_path)] == prefetch_path:
# the prefetch was for a parent of this one, get it from the running cache
setattr(current, prefetch.to_attr, prefetch_attrs[prefetch])

self._prefetch_done = True

def _chain(self, **kwargs: object) -> InheritanceQuerySet[ModelT]:
update = {}
for name in ['subclasses', '_annotated']:
Expand Down Expand Up @@ -165,18 +246,10 @@ def _get_ancestors_path(self, model: type[models.Model]) -> str:
raise ValueError(
f"{model!r} is not a subclass of {self.model!r}")

ancestry: list[str] = []
# should be a OneToOneField or None
parent_link = model._meta.get_ancestor_link(self.model)

while parent_link is not None:
related = parent_link.remote_field
ancestry.insert(0, related.get_accessor_name())

parent_model = related.model
parent_link = parent_model._meta.get_ancestor_link(self.model)

return LOOKUP_SEP.join(ancestry)
return LOOKUP_SEP.join(
p.join_field.get_accessor_name()
for p in model._meta.get_path_from_parent(self.model)
)

def _get_sub_obj_recurse(self, obj: models.Model, s: str) -> ModelT | None:
rel, _, s = s.partition(LOOKUP_SEP)
Expand Down Expand Up @@ -212,18 +285,18 @@ def instance_of(self, *models: type[ModelT]) -> InheritanceQuerySet[ModelT]:
# Due to https://code.djangoproject.com/ticket/16572, we
# can't really do this for anything other than children (ie,
# no grandchildren+).
where_queries = []
conditions = []
for model in models:
where_queries.append('(' + ' AND '.join([
'"{}"."{}" IS NOT NULL'.format(
model._meta.db_table,
field.column,
) for field in model._meta.parents.values()
]) + ')')
path_from_parent = LOOKUP_SEP.join(
p.join_field.get_accessor_name() for p in model._meta.get_path_from_parent(self.model)
)
conditions.append(
(path_from_parent + LOOKUP_SEP + 'isnull', False)
)

return cast(
'InheritanceQuerySet[ModelT]',
self.select_subclasses(*models).extra(where=[' OR '.join(where_queries)])
self.select_subclasses(*models).filter(Q(*conditions, _connector=Q.OR))
)


Expand Down
8 changes: 8 additions & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ class InheritanceManagerTestParent(models.Model):
related_self = models.OneToOneField(
"self", related_name="imtests_self", null=True,
on_delete=models.CASCADE)
normal_relation_parent = models.ForeignKey('InheritanceManagerNonChild', null=True, on_delete=models.CASCADE)
normal_many_relation_parent = models.ManyToManyField('InheritanceManagerNonChild')
objects: ClassVar[InheritanceManager[InheritanceManagerTestParent]] = InheritanceManager()

def __str__(self) -> str:
Expand All @@ -56,6 +58,8 @@ def __str__(self) -> str:
class InheritanceManagerTestChild1(InheritanceManagerTestParent):
non_related_field_using_descriptor_2 = models.FileField(upload_to="test")
normal_field_2 = models.TextField()
normal_relation = models.ForeignKey('InheritanceManagerNonChild', null=True, on_delete=models.CASCADE)
normal_many_relation = models.ManyToManyField('InheritanceManagerNonChild')
objects: ClassVar[InheritanceManager[InheritanceManagerTestParent]] = InheritanceManager()


Expand Down Expand Up @@ -95,6 +99,10 @@ class InheritanceManagerTestChild4(InheritanceManagerTestParent):
parent_link=True, on_delete=models.CASCADE)


class InheritanceManagerNonChild(models.Model):
name = models.CharField(max_length=255)


class TimeStamp(TimeStampedModel):
test_field = models.PositiveSmallIntegerField(default=0)

Expand Down
143 changes: 143 additions & 0 deletions tests/test_managers/test_inheritance_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
from typing import TYPE_CHECKING

from django.db import models
from django.db.models import Prefetch
from django.test import TestCase

from model_utils.managers import InheritanceManager
from tests.models import (
InheritanceManagerNonChild,
InheritanceManagerTestChild1,
InheritanceManagerTestChild2,
InheritanceManagerTestChild3,
Expand Down Expand Up @@ -541,3 +543,144 @@ def test_annotate_with_named_arguments_before_select_subclasses(self) -> None:
def test_clone_when_inheritance_queryset_selects_subclasses_should_clone_them_too(self) -> None:
qs = InheritanceManagerTestParent.objects.select_subclasses()
self.assertEqual(qs.subclasses, qs._clone().subclasses)


class InheritanceManagerPrefetchForeignKeyTests(TestCase):

def setUp(self) -> None:
self.related1 = InheritanceManagerNonChild.objects.create()
self.related2 = InheritanceManagerNonChild.objects.create()
self.related3 = InheritanceManagerNonChild.objects.create()
self.related4 = InheritanceManagerNonChild.objects.create()
self.related5 = InheritanceManagerNonChild.objects.create()
self.c1 = InheritanceManagerTestChild1.objects.create(
normal_relation=self.related1,
normal_relation_parent=self.related1,
)
self.c1.normal_many_relation.set([self.related1, self.related2])
self.c1.normal_many_relation_parent.set([self.related1, self.related2])

self.c2 = InheritanceManagerTestChild1.objects.create(
normal_relation=self.related2,
normal_relation_parent=self.related2,
)
self.c2.normal_many_relation.set([self.related3, self.related4])
self.c2.normal_many_relation_parent.set([self.related3, self.related4])

self.gc1 = InheritanceManagerTestGrandChild1.objects.create(
normal_relation=self.related1,
normal_relation_parent=self.related1,
)
self.gc1.normal_many_relation.set([self.related1, self.related2])
self.gc1.normal_many_relation_parent.set([self.related1, self.related2])

self.gc2 = InheritanceManagerTestGrandChild1.objects.create(
normal_relation=self.related2,
normal_relation_parent=self.related2,
)
self.gc2.normal_many_relation.set([self.related3, self.related4])
self.gc2.normal_many_relation_parent.set([self.related3, self.related4])

self.gc3 = InheritanceManagerTestGrandChild1_2.objects.create(
normal_relation=self.related1,
normal_relation_parent=self.related1,
)
self.gc3.normal_many_relation.set([self.related1, self.related2])
self.gc3.normal_many_relation_parent.set([self.related1, self.related2])

self.gc4 = InheritanceManagerTestGrandChild1.objects.create(
normal_relation=self.related3,
normal_relation_parent=self.related3,
)
self.gc4.normal_many_relation.set([self.related5, self.related1])
self.gc4.normal_many_relation_parent.set([self.related5, self.related1])

self.c3 = InheritanceManagerTestChild2.objects.create()

def test_prefetch_related_works_with_fk_in_parent(self) -> None:
with self.assertNumQueries(2):
result = list(
InheritanceManagerTestParent.objects.select_subclasses().prefetch_related(
'normal_relation_parent'
).order_by('pk')
)
self.assertEqual(result[0], self.c1)
self.assertEqual(result[0].normal_relation_parent.pk, self.related1.pk)
self.assertEqual(result[1], self.c2)
self.assertEqual(result[1].normal_relation_parent.pk, self.related2.pk)
self.assertEqual(result[2], self.gc1)
self.assertEqual(result[2].normal_relation_parent.pk, self.related1.pk)
self.assertEqual(result[3], self.gc2)
self.assertEqual(result[3].normal_relation_parent.pk, self.related2.pk)
self.assertEqual(result[4], self.gc3)
self.assertEqual(result[4].normal_relation_parent.pk, self.related1.pk)
self.assertEqual(result[5], self.gc4)
self.assertEqual(result[5].normal_relation_parent.pk, self.related3.pk)
self.assertEqual(result[6], self.c3)

def test_prefetch_related_works_with_m2m_in_parent(self) -> None:
with self.assertNumQueries(2):
result = list(
InheritanceManagerTestParent.objects.select_subclasses().prefetch_related(
'normal_many_relation_parent',
).order_by('pk')
)

self.assertEqual(set(result[0].normal_many_relation_parent.all()), {self.related1, self.related2})
self.assertEqual(set(result[1].normal_many_relation_parent.all()), {self.related3, self.related4})
self.assertEqual(set(result[2].normal_many_relation_parent.all()), {self.related1, self.related2})
self.assertEqual(set(result[3].normal_many_relation_parent.all()), {self.related3, self.related4})
self.assertEqual(set(result[4].normal_many_relation_parent.all()), {self.related1, self.related2})
self.assertEqual(set(result[5].normal_many_relation_parent.all()), {self.related5, self.related1})

def test_prefetch_related_works_with_fk_in_subclass(self) -> None:
with self.assertNumQueries(2):
result = list(
InheritanceManagerTestParent.objects.select_subclasses().prefetch_related(
'inheritancemanagertestchild1__normal_relation'
).order_by('pk')
)
self.assertEqual(result[0], self.c1)
self.assertEqual(result[0].normal_relation.pk, self.related1.pk)
self.assertEqual(result[1], self.c2)
self.assertEqual(result[1].normal_relation.pk, self.related2.pk)
self.assertEqual(result[2], self.gc1)
self.assertEqual(result[2].normal_relation.pk, self.related1.pk)
self.assertEqual(result[3], self.gc2)
self.assertEqual(result[3].normal_relation.pk, self.related2.pk)
self.assertEqual(result[4], self.gc3)
self.assertEqual(result[4].normal_relation.pk, self.related1.pk)
self.assertEqual(result[5], self.gc4)
self.assertEqual(result[5].normal_relation.pk, self.related3.pk)
self.assertEqual(result[6], self.c3)

def test_prefetch_related_works_with_m2m_in_subclass(self) -> None:
with self.assertNumQueries(2):
result = list(
InheritanceManagerTestParent.objects.select_subclasses().prefetch_related(
'inheritancemanagertestchild1__normal_many_relation',
).order_by('pk')
)

self.assertEqual(set(result[0].normal_many_relation.all()), {self.related1, self.related2})
self.assertEqual(set(result[1].normal_many_relation.all()), {self.related3, self.related4})
self.assertEqual(set(result[2].normal_many_relation.all()), {self.related1, self.related2})
self.assertEqual(set(result[3].normal_many_relation.all()), {self.related3, self.related4})
self.assertEqual(set(result[4].normal_many_relation.all()), {self.related1, self.related2})
self.assertEqual(set(result[5].normal_many_relation.all()), {self.related5, self.related1})

def test_prefetch_related_works_with_m2m_to_attr(self) -> None:
with self.assertNumQueries(2):
result = list(
InheritanceManagerTestParent.objects.select_subclasses().prefetch_related(
Prefetch('inheritancemanagertestchild1__normal_many_relation', to_attr='prefetched_many_relation')
).order_by('pk')
)

self.assertEqual(set(result[0].prefetched_many_relation), {self.related1, self.related2})
self.assertEqual(set(result[1].prefetched_many_relation), {self.related3, self.related4})
self.assertEqual(set(result[2].prefetched_many_relation), {self.related1, self.related2})
self.assertEqual(set(result[3].prefetched_many_relation), {self.related3, self.related4})
self.assertEqual(set(result[4].prefetched_many_relation), {self.related1, self.related2})
self.assertEqual(set(result[5].prefetched_many_relation), {self.related5, self.related1})
self.assertFalse(hasattr(result[6], 'prefetched_many_relation'))