Skip to content

Commit

Permalink
Group queries for SlugRelatedField many serializers
Browse files Browse the repository at this point in the history
  • Loading branch information
clement-escolano committed Jan 5, 2024
1 parent d6ca95f commit b72027f
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 7 deletions.
21 changes: 16 additions & 5 deletions rest_framework/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from urllib import parse

from django.core.exceptions import ImproperlyConfigured, ObjectDoesNotExist
from django.db.models import Manager
from django.db.models import F, Manager
from django.db.models.query import QuerySet
from django.urls import NoReverseMatch, Resolver404, get_script_prefix, resolve
from django.utils.encoding import smart_str, uri_to_iri
Expand Down Expand Up @@ -458,15 +458,26 @@ def __init__(self, slug_field=None, **kwargs):
self.slug_field = slug_field
super().__init__(**kwargs)

def to_internal_value(self, data):
def to_many_internal_value(self, data):
queryset = self.get_queryset()
try:
return queryset.get(**{self.slug_field: data})
except ObjectDoesNotExist:
self.fail('does_not_exist', slug_name=self.slug_field, value=smart_str(data))
result = (
queryset
.filter(**{self.slug_field + "__in": data})
.annotate(_slug_field_value=F(self.slug_field))
.all()
)
slugs = [item._slug_field_value for item in result]
for item in data:
if item not in slugs:
self.fail('does_not_exist', slug_name=self.slug_field, value=smart_str(item))
return result
except (TypeError, ValueError):
self.fail('invalid')

def to_internal_value(self, data):
return self.to_many_internal_value([data])[0]

def to_representation(self, obj):
slug = self.slug_field
if "__" in slug:
Expand Down
6 changes: 6 additions & 0 deletions tests/test_relations_slug.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,12 @@ def test_reverse_foreign_key_create(self):
]
assert serializer.data == expected

def test_reverse_foreign_key_create_grouped_queries(self):
data = {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']}
serializer = ForeignKeyTargetSerializer(data=data)
with self.assertNumQueries(1):
assert serializer.is_valid()

def test_foreign_key_update_with_invalid_null(self):
data = {'id': 1, 'name': 'source-1', 'target': None}
instance = ForeignKeySource.objects.get(pk=1)
Expand Down
10 changes: 8 additions & 2 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def all(self):
return list(self.items)

def filter(self, **lookup):
return MockQueryset(
return MockQueryset([
item
for item in self.items
if all([
Expand All @@ -44,7 +44,13 @@ def filter(self, **lookup):
else attrgetter(key.replace('__', '.'))(item) == value
for key, value in lookup.items()
])
)
])

def annotate(self, **kwargs):
for key, value in kwargs.items():
for item in self.items:
setattr(item, key, attrgetter(value.name.replace('__', '.'))(item))
return self


class BadType:
Expand Down

0 comments on commit b72027f

Please sign in to comment.