Skip to content

Commit bd33718

Browse files
akaariai authored and timgraham committed
Fixed #23877 -- aggregation's subquery missed target col
Aggregation over a subquery produced syntactically incorrect queries in some cases, because Django didn't ensure that the source expressions of the aggregation were present in the subquery.
1 parent c7fd9b2 commit bd33718

File tree

3 files changed

+81
-7
lines changed

3 files changed

+81
-7
lines changed

django/db/models/sql/compiler.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -580,10 +580,10 @@ def get_grouping(self, having_group_by, ordering_group_by):
580580
if isinstance(col, (list, tuple)):
581581
sql = '%s.%s' % (qn(col[0]), qn(col[1]))
582582
elif hasattr(col, 'as_sql'):
583-
self.compile(col)
583+
sql, col_params = self.compile(col)
584584
else:
585585
sql = '(%s)' % str(col)
586-
if sql not in seen:
586+
if sql not in seen or col_params:
587587
result.append(sql)
588588
params.extend(col_params)
589589
seen.add(sql)
@@ -604,6 +604,14 @@ def get_grouping(self, having_group_by, ordering_group_by):
604604
sql = '(%s)' % str(extra_select)
605605
result.append(sql)
606606
params.extend(extra_params)
607+
# Finally, add needed group by cols from annotations
608+
for annotation in self.query.annotation_select.values():
609+
cols = annotation.get_group_by_cols()
610+
for col in cols:
611+
sql = '%s.%s' % (qn(col[0]), qn(col[1]))
612+
if sql not in seen:
613+
result.append(sql)
614+
seen.add(sql)
607615

608616
return result, params
609617

django/db/models/sql/query.py

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,41 @@ def relabeled_clone(self, change_map):
313313
clone.change_aliases(change_map)
314314
return clone
315315

316+
def rewrite_cols(self, annotation, col_cnt):
    """
    Make sure every column that ``annotation`` refers to is selected by
    this (inner) query, and point ``annotation`` at those selections.

    When aggregating over an existing annotation, the source expression is
    already a Ref() and needs no work. When aggregating over a column of a
    related model, e.g. ``.aggregate(Sum('author__awards'))``, resolving
    the expression adds a join to author, but nothing guarantees that the
    awards column ends up in the SELECT clause. Such columns are therefore
    added here under generated ``__col<N>`` aliases and replaced by a Ref()
    to that alias.

    ``col_cnt`` is the number of generated aliases handed out so far.
    Returns a ``(annotation, col_cnt)`` tuple: the same annotation object
    with its source expressions replaced in place, and the updated counter.
    """
    rewritten = []
    for source in annotation.get_source_expressions():
        if isinstance(source, Ref):
            # Already a reference into the subquery (see resolve_ref()).
            replacement = source
        elif isinstance(source, Col):
            # A plain column reference: select it in this query under a
            # fresh alias and refer to that alias instead.
            col_cnt += 1
            alias = '__col%d' % col_cnt
            self.annotation_select[alias] = source
            self.append_annotation_mask([alias])
            replacement = Ref(alias, source)
        else:
            # Not a direct database reference, but its own source
            # expressions may contain Cols, so descend into it.
            replacement, col_cnt = self.rewrite_cols(source, col_cnt)
        rewritten.append(replacement)
    annotation.set_source_expressions(rewritten)
    return annotation, col_cnt
350+
316351
def get_aggregation(self, using, added_aggregate_names):
317352
"""
318353
Returns the dictionary with the values of the existing aggregations.
@@ -350,11 +385,11 @@ def get_aggregation(self, using, added_aggregate_names):
350385
relabels[None] = 'subquery'
351386
# Remove any aggregates marked for reduction from the subquery
352387
# and move them to the outer AggregateQuery.
353-
for alias, annotation in inner_query.annotation_select.items():
354-
if annotation.is_summary:
355-
# The annotation is already referring the subquery alias, so we
356-
# just need to move the annotation to the outer query.
357-
outer_query.annotations[alias] = annotation.relabeled_clone(relabels)
388+
col_cnt = 0
389+
for alias, expression in inner_query.annotation_select.items():
390+
if expression.is_summary:
391+
expression, col_cnt = inner_query.rewrite_cols(expression, col_cnt)
392+
outer_query.annotations[alias] = expression.relabeled_clone(relabels)
358393
del inner_query.annotation_select[alias]
359394
try:
360395
outer_query.add_subquery(inner_query, using)
@@ -1495,6 +1530,10 @@ def resolve_ref(self, name, allow_joins, reuse, summarize):
14951530
raise FieldError("Joined field references are not permitted in this query")
14961531
if name in self.annotations:
14971532
if summarize:
1533+
# Summarize currently means we are doing an aggregate() query
1534+
# which is executed as a wrapped subquery if any of the
1535+
# aggregate() elements reference an existing annotation. In
1536+
# that case we need to return a Ref to the subquery's annotation.
14981537
return Ref(name, self.annotation_select[name])
14991538
else:
15001539
return self.annotation_select[name]

tests/aggregation_regress/tests.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1168,3 +1168,30 @@ def test_existing_join_not_promoted(self):
11681168
def test_non_nullable_fk_not_promoted(self):
    """
    Annotating Book with a Count over ``contact__name`` must generate SQL
    that contains an INNER JOIN.
    """
    annotated = Book.objects.annotate(Count('contact__name'))
    generated_sql = str(annotated.query)
    self.assertIn(' INNER JOIN ', generated_sql)
1171+
1172+
1173+
class AggregationOnRelationTest(TestCase):
    """
    Tests for annotating and aggregating over a field of a related model
    on a queryset that may already carry an aggregate annotation.
    """

    def setUp(self):
        self.a = Author.objects.create(name='Anssi', age=33)
        self.p = Publisher.objects.create(name='Manning', num_awards=3)
        self._create_book()

    def _create_book(self):
        # Every book in these tests is identical and points at the single
        # author and publisher created in setUp().
        Book.objects.create(
            isbn='asdf', name='Foo', pages=10, rating=0.1, price="0.0",
            contact=self.a, publisher=self.p,
            pubdate=datetime.date.today(),
        )

    def _sum_publisher_awards(self):
        # aggregate() over the related publisher's num_awards on a queryset
        # that already has an Avg annotation.
        totals = Book.objects.annotate(avg_price=Avg('price')).aggregate(
            publisher_awards=Sum('publisher__num_awards')
        )
        return totals['publisher_awards']

    def test_annotate_on_relation(self):
        books = Book.objects.annotate(
            avg_price=Avg('price'),
            publisher_name=F('publisher__name'),
        )
        self.assertEqual(books[0].avg_price, 0.0)
        self.assertEqual(books[0].publisher_name, "Manning")

    def test_aggregate_on_relation(self):
        # A query with an existing annotation aggregation on a relation
        # should succeed.
        self.assertEqual(self._sum_publisher_awards(), 3)
        # A second book on the same publisher counts its awards again.
        self._create_book()
        self.assertEqual(self._sum_publisher_awards(), 6)

0 commit comments

Comments
 (0)