Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Preserve symbols required by join during partial aggregation pushdown #635

Open
wants to merge 1 commit into
base: sprint-59
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,12 @@
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static com.facebook.presto.SystemSessionProperties.isPushAggregationThroughJoin;
import static com.facebook.presto.sql.planner.plan.AggregationNode.Step.FINAL;
Expand All @@ -60,6 +60,8 @@
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.collect.Sets.intersection;
import static java.util.Objects.requireNonNull;

public class PartialAggregationPushDown
Expand Down Expand Up @@ -192,18 +194,29 @@ else if (allAggregationsOn(node.getAggregations(), child.getRight().getOutputSym

/**
 * Pushes the partial aggregation {@code node} below the join {@code child}, onto the join's left input.
 *
 * The grouping set of the pushed aggregation is computed by {@code getPushedDownGroupingSet}
 * from the left child's output symbols and the join-required symbols (see
 * {@code getJoinRequiredSymbols}) that the left child produces.
 *
 * @param node the aggregation being pushed down
 * @param child the join directly below {@code node}
 * @param context rewrite context, used here to rewrite the join's right child
 * @return the plan node produced by {@code pushPartialToJoin}
 */
private PlanNode pushPartialToLeftChild(AggregationNode node, JoinNode child, RewriteContext<Void> context)
{
// NOTE(review): this view interleaves removed and added diff lines without +/- markers.
// The next statement looks like the pre-change form (it passes the JoinNode itself); the two
// statements after it are the replacement that matches the three-set-argument signature.
List<Symbol> groupingSet = getPushedDownGroupingSet(node, child, ImmutableSet.copyOf(child.getLeft().getOutputSymbols()));
Set<Symbol> joinLeftChildSymbols = ImmutableSet.copyOf(child.getLeft().getOutputSymbols());
// only the join-required symbols that the left child actually outputs are offered to the grouping set
List<Symbol> groupingSet = getPushedDownGroupingSet(node, joinLeftChildSymbols, intersection(getJoinRequiredSymbols(child), joinLeftChildSymbols));
AggregationNode pushedAggregation = replaceAggregationSource(node, child.getLeft(), child.getCriteria(), groupingSet, context);
// presumably the third and fourth arguments are the new left and right join inputs - confirm against pushPartialToJoin
return pushPartialToJoin(pushedAggregation, child, pushedAggregation, context.rewrite(child.getRight()), child.getRight().getOutputSymbols());
}

/**
 * Pushes the partial aggregation {@code node} below the join {@code child}, onto the join's right input.
 *
 * The grouping set of the pushed aggregation is computed by {@code getPushedDownGroupingSet}
 * from the right child's output symbols and the join-required symbols (see
 * {@code getJoinRequiredSymbols}) that the right child produces.
 *
 * @param node the aggregation being pushed down
 * @param child the join directly below {@code node}
 * @param context rewrite context, used here to rewrite the join's left child
 * @return the plan node produced by {@code pushPartialToJoin}
 */
private PlanNode pushPartialToRightChild(AggregationNode node, JoinNode child, RewriteContext<Void> context)
{
// NOTE(review): this view interleaves removed and added diff lines without +/- markers.
// The next statement looks like the pre-change form (it passes the JoinNode itself); the two
// statements after it are the replacement that matches the three-set-argument signature.
List<Symbol> groupingSet = getPushedDownGroupingSet(node, child, ImmutableSet.copyOf(child.getRight().getOutputSymbols()));
Set<Symbol> joinRightChildSymbols = ImmutableSet.copyOf(child.getRight().getOutputSymbols());
// only the join-required symbols that the right child actually outputs are offered to the grouping set
List<Symbol> groupingSet = getPushedDownGroupingSet(node, joinRightChildSymbols, intersection(getJoinRequiredSymbols(child), joinRightChildSymbols));
AggregationNode pushedAggregation = replaceAggregationSource(node, child.getRight(), child.getCriteria(), groupingSet, context);
// presumably the third and fourth arguments are the new left and right join inputs - confirm against pushPartialToJoin
return pushPartialToJoin(pushedAggregation, child, context.rewrite(child.getLeft()), pushedAggregation, child.getLeft().getOutputSymbols());
}

/**
 * Collects every symbol the join itself refers to: the left side of each equi-join clause,
 * the right side of each equi-join clause, and the symbols extracted from the optional
 * join filter.
 *
 * The result keeps that insertion order (left keys, then right keys, then filter symbols),
 * with duplicates collapsed.
 *
 * @param node the join whose criteria and filter are inspected
 * @return an immutable set of the symbols referenced by the join's criteria and filter
 */
private Set<Symbol> getJoinRequiredSymbols(JoinNode node)
{
    Set<Symbol> leftKeys = node.getCriteria().stream()
            .map(EquiJoinClause::getLeft)
            .collect(toImmutableSet());
    Set<Symbol> rightKeys = node.getCriteria().stream()
            .map(EquiJoinClause::getRight)
            .collect(toImmutableSet());
    // a join without a filter contributes no extra symbols
    Set<Symbol> filterSymbols = node.getFilter()
            .map(DependencyExtractor::extractUnique)
            .orElse(ImmutableSet.of());

    ImmutableSet.Builder<Symbol> required = ImmutableSet.builder();
    required.addAll(leftKeys);
    required.addAll(rightKeys);
    required.addAll(filterSymbols);
    return required.build();
}

private PlanNode pushPartialToJoin(
AggregationNode pushedAggregation,
JoinNode child,
Expand Down Expand Up @@ -265,28 +278,21 @@ private boolean allAggregationsOn(Map<Symbol, Aggregation> aggregations, List<Sy
return outputSymbols.containsAll(inputs);
}

/**
 * Computes the grouping set to use for an aggregation that is pushed below a join.
 *
 * Starting from the aggregation's single grouping set, the symbols contained in
 * {@code availableSymbols} are kept (in their original order); the symbols of
 * {@code requiredJoinSymbols} that are not already present are then appended.
 *
 * NOTE(review): this view interleaves removed and added diff lines without +/- markers.
 * Two signatures appear below; the second (taking two symbol sets) is the one the visible
 * callers use. The {@code joinKeys} computation and the trailing {@code join.getCriteria()}
 * stream refer to the {@code join} parameter of the first signature only. Whether the
 * size/isEmpty guard survives in the new version cannot be told from here - confirm
 * against the actual change.
 *
 * @param aggregation aggregation being pushed down; must have exactly one grouping set
 *        (enforced by {@code Iterables.getOnlyElement})
 * @param availableSymbols symbols that may remain in the pushed-down grouping set
 * @param requiredJoinSymbols symbols to append to the grouping set when not already present
 * @return the mutable list of grouping symbols for the pushed-down aggregation
 */
private List<Symbol> getPushedDownGroupingSet(AggregationNode aggregation, JoinNode join, Set<Symbol> availableSymbols)
private List<Symbol> getPushedDownGroupingSet(AggregationNode aggregation, Set<Symbol> availableSymbols, Set<Symbol> requiredJoinSymbols)
{
List<Symbol> groupingSet = Iterables.getOnlyElement(aggregation.getGroupingSets());
Set<Symbol> joinKeys = Stream.concat(
join.getCriteria().stream().map(EquiJoinClause::getLeft),
join.getCriteria().stream().map(EquiJoinClause::getRight)
).collect(Collectors.toSet());

// keep symbols that are either directly from the join's child (availableSymbols) or there is
// an equality in join condition to a symbol for the join child
// keep symbols that are directly from the join's child (availableSymbols)
List<Symbol> pushedDownGroupingSet = groupingSet.stream()
.filter(symbol -> joinKeys.contains(symbol) || availableSymbols.contains(symbol))
.filter(availableSymbols::contains)
.collect(Collectors.toList());

if (pushedDownGroupingSet.size() != groupingSet.size() || pushedDownGroupingSet.isEmpty()) {
// If we dropped some symbol, we have to add all join key columns to the grouping set
Set<Symbol> existingSymbols = ImmutableSet.copyOf(pushedDownGroupingSet);
// add missing required join symbols to grouping set
// HashSet.add returns false for symbols already present, so the filter below both
// de-duplicates and records each newly appended symbol
Set<Symbol> existingSymbols = new HashSet<>(pushedDownGroupingSet);
requiredJoinSymbols.stream()
.filter(existingSymbols::add)
.forEach(pushedDownGroupingSet::add);

join.getCriteria().stream()
.filter(equiJoinClause -> !existingSymbols.contains(equiJoinClause.getLeft()) && !existingSymbols.contains(equiJoinClause.getRight()))
.forEach(joinClause -> pushedDownGroupingSet.add(joinClause.getLeft()));
}
return pushedDownGroupingSet;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1851,19 +1851,19 @@ public void testGrouping()
throws Exception
{
assertQuery(
"SELECT a, b as t, sum(c), grouping(a, b) + grouping(a) " +
"FROM (VALUES ('h', 'j', 11), ('k', 'l', 7)) AS t (a, b, c) " +
"GROUP BY GROUPING SETS ( (a), (b)) " +
"ORDER BY grouping(b) ASC",
"VALUES (NULL, 'j', 11, 3), (NULL, 'l', 7, 3), ('h', NULL, 11, 1), ('k', NULL, 7, 1)");
"SELECT a, b as t, sum(c), grouping(a, b) + grouping(a) " +
"FROM (VALUES ('h', 'j', 11), ('k', 'l', 7)) AS t (a, b, c) " +
"GROUP BY GROUPING SETS ( (a), (b)) " +
"ORDER BY grouping(b) ASC",
"VALUES (NULL, 'j', 11, 3), (NULL, 'l', 7, 3), ('h', NULL, 11, 1), ('k', NULL, 7, 1)");

assertQuery(
"SELECT a, sum(b), grouping(a) FROM (VALUES ('h', 11, 0), ('k', 7, 0)) AS t (a, b, c) GROUP BY GROUPING SETS (a)",
"VALUES ('h', 11, 0), ('k', 7, 0)");
"SELECT a, sum(b), grouping(a) FROM (VALUES ('h', 11, 0), ('k', 7, 0)) AS t (a, b, c) GROUP BY GROUPING SETS (a)",
"VALUES ('h', 11, 0), ('k', 7, 0)");

assertQuery(
"SELECT a, b, sum(c), grouping(a, b) FROM (VALUES ('h', 'j', 11), ('k', 'l', 7) ) AS t (a, b, c) GROUP BY GROUPING SETS ( (a), (b)) HAVING grouping(a, b) > 1 ",
"VALUES (NULL, 'j', 11, 2), (NULL, 'l', 7, 2)");
"SELECT a, b, sum(c), grouping(a, b) FROM (VALUES ('h', 'j', 11), ('k', 'l', 7) ) AS t (a, b, c) GROUP BY GROUPING SETS ( (a), (b)) HAVING grouping(a, b) > 1 ",
"VALUES (NULL, 'j', 11, 2), (NULL, 'l', 7, 2)");

assertQuery("SELECT a, grouping(a) * 1.0 FROM (VALUES (1) ) AS t (a) GROUP BY a",
"VALUES (1, 0.0)");
Expand All @@ -1872,7 +1872,7 @@ public void testGrouping()
"VALUES (1, 0, 0)");

assertQuery("SELECT grouping(a) FROM (VALUES ('h', 'j', 11), ('k', 'l', 7)) AS t (a, b, c) GROUP BY GROUPING SETS (a,c), c*2",
"VALUES (0), (1), (0), (1)");
"VALUES (0), (1), (0), (1)");
}

@Test
Expand Down Expand Up @@ -1907,23 +1907,23 @@ public void testGroupingInWindowFunction()
throws Exception
{
assertQuery(
"SELECT orderkey, custkey, sum(totalprice), grouping(orderkey)+grouping(custkey) as g, " +
" rank() OVER (PARTITION BY grouping(orderkey)+grouping(custkey), " +
" CASE WHEN grouping(orderkey) = 0 THEN custkey END ORDER BY orderkey ASC) as r " +
"FROM orders " +
"GROUP BY ROLLUP (orderkey, custkey) " +
"ORDER BY orderkey, custkey " +
"LIMIT 10",
"VALUES (1, 370, 172799.49, 0, 1), " +
" (1, NULL, 172799.49, 1, 1), " +
" (2, 781, 38426.09, 0, 1), " +
" (2, NULL, 38426.09, 1, 2), " +
" (3, 1234, 205654.30, 0, 1), " +
" (3, NULL, 205654.30, 1, 3), " +
" (4, 1369, 56000.91, 0, 1), " +
" (4, NULL, 56000.91, 1, 4), " +
" (5, 445, 105367.67, 0, 1), " +
" (5, NULL, 105367.67, 1, 5)");
"SELECT orderkey, custkey, sum(totalprice), grouping(orderkey)+grouping(custkey) as g, " +
" rank() OVER (PARTITION BY grouping(orderkey)+grouping(custkey), " +
" CASE WHEN grouping(orderkey) = 0 THEN custkey END ORDER BY orderkey ASC) as r " +
"FROM orders " +
"GROUP BY ROLLUP (orderkey, custkey) " +
"ORDER BY orderkey, custkey " +
"LIMIT 10",
"VALUES (1, 370, 172799.49, 0, 1), " +
" (1, NULL, 172799.49, 1, 1), " +
" (2, 781, 38426.09, 0, 1), " +
" (2, NULL, 38426.09, 1, 2), " +
" (3, 1234, 205654.30, 0, 1), " +
" (3, NULL, 205654.30, 1, 3), " +
" (4, 1369, 56000.91, 0, 1), " +
" (4, NULL, 56000.91, 1, 4), " +
" (5, 445, 105367.67, 0, 1), " +
" (5, NULL, 105367.67, 1, 5)");
}

@Test
Expand All @@ -1938,54 +1938,54 @@ public void testGroupingInTableSubquery()

// Inner query has a single GROUP BY and outer query has GROUPING SETS
assertQuery(
"SELECT orderkey, custkey, sum(agg_price) as outer_sum, grouping(orderkey, custkey), g " +
"FROM " +
" (SELECT orderkey, custkey, sum(totalprice) as agg_price, grouping(custkey, orderkey) as g " +
" FROM orders " +
" GROUP BY orderkey, custkey " +
" ORDER BY agg_price ASC " +
" LIMIT 5) as t " +
"GROUP BY GROUPING SETS ((orderkey, custkey), g) " +
"ORDER BY outer_sum",
"VALUES (35271, 334, 874.89, 0, NULL), " +
" (28647, 1351, 924.33, 0, NULL), " +
" (58145, 862, 929.03, 0, NULL), " +
" (8354, 634, 974.04, 0, NULL), " +
" (37415, 301, 986.63, 0, NULL), " +
" (NULL, NULL, 4688.92, 3, 0)");
"SELECT orderkey, custkey, sum(agg_price) as outer_sum, grouping(orderkey, custkey), g " +
"FROM " +
" (SELECT orderkey, custkey, sum(totalprice) as agg_price, grouping(custkey, orderkey) as g " +
" FROM orders " +
" GROUP BY orderkey, custkey " +
" ORDER BY agg_price ASC " +
" LIMIT 5) as t " +
"GROUP BY GROUPING SETS ((orderkey, custkey), g) " +
"ORDER BY outer_sum",
"VALUES (35271, 334, 874.89, 0, NULL), " +
" (28647, 1351, 924.33, 0, NULL), " +
" (58145, 862, 929.03, 0, NULL), " +
" (8354, 634, 974.04, 0, NULL), " +
" (37415, 301, 986.63, 0, NULL), " +
" (NULL, NULL, 4688.92, 3, 0)");

// Inner query has GROUPING SETS and outer query has GROUP BY
assertQuery(
"SELECT orderkey, custkey, g, sum(agg_price) as outer_sum, grouping(orderkey, custkey) " +
"FROM " +
" (SELECT orderkey, custkey, sum(totalprice) as agg_price, grouping(custkey, orderkey) as g " +
" FROM orders " +
" GROUP BY GROUPING SETS ((custkey), (orderkey)) " +
" ORDER BY agg_price ASC " +
" LIMIT 5) as t " +
"GROUP BY orderkey, custkey, g",
"VALUES (28647, NULL, 2, 924.33, 0), " +
" (8354, NULL, 2, 974.04, 0), " +
" (37415, NULL, 2, 986.63, 0), " +
" (58145, NULL, 2, 929.03, 0), " +
" (35271, NULL, 2, 874.89, 0)");
"SELECT orderkey, custkey, g, sum(agg_price) as outer_sum, grouping(orderkey, custkey) " +
"FROM " +
" (SELECT orderkey, custkey, sum(totalprice) as agg_price, grouping(custkey, orderkey) as g " +
" FROM orders " +
" GROUP BY GROUPING SETS ((custkey), (orderkey)) " +
" ORDER BY agg_price ASC " +
" LIMIT 5) as t " +
"GROUP BY orderkey, custkey, g",
"VALUES (28647, NULL, 2, 924.33, 0), " +
" (8354, NULL, 2, 974.04, 0), " +
" (37415, NULL, 2, 986.63, 0), " +
" (58145, NULL, 2, 929.03, 0), " +
" (35271, NULL, 2, 874.89, 0)");

// Inner query has GROUPING SETS but no grouping and outer query has a simple GROUP BY
assertQuery(
"SELECT orderkey, custkey, sum(agg_price) as outer_sum, grouping(orderkey, custkey) " +
"FROM " +
" (SELECT orderkey, custkey, sum(totalprice) as agg_price " +
" FROM orders " +
" GROUP BY GROUPING SETS ((custkey), (orderkey)) " +
" ORDER BY agg_price ASC NULLS FIRST) as t " +
"GROUP BY orderkey, custkey " +
"ORDER BY outer_sum ASC NULLS FIRST " +
"LIMIT 5",
"VALUES (35271, NULL, 874.89, 0), " +
" (28647, NULL, 924.33, 0), " +
" (58145, NULL, 929.03, 0), " +
" (8354, NULL, 974.04, 0), " +
" (37415, NULL, 986.63, 0)");
"SELECT orderkey, custkey, sum(agg_price) as outer_sum, grouping(orderkey, custkey) " +
"FROM " +
" (SELECT orderkey, custkey, sum(totalprice) as agg_price " +
" FROM orders " +
" GROUP BY GROUPING SETS ((custkey), (orderkey)) " +
" ORDER BY agg_price ASC NULLS FIRST) as t " +
"GROUP BY orderkey, custkey " +
"ORDER BY outer_sum ASC NULLS FIRST " +
"LIMIT 5",
"VALUES (35271, NULL, 874.89, 0), " +
" (28647, NULL, 924.33, 0), " +
" (58145, NULL, 929.03, 0), " +
" (8354, NULL, 974.04, 0), " +
" (37415, NULL, 986.63, 0)");
}

@Test
Expand Down Expand Up @@ -6107,11 +6107,11 @@ public void testChainedUnionsWithOrder()
/**
 * Checks ORDER BY ... LIMIT 1 applied on top of a UNION ALL of nation.regionkey and
 * nation.nationkey; the expected result is the single row 0.
 */
public void testUnionWithTopN()
{
// NOTE(review): the five query fragments below appear twice because this view shows both
// sides of a whitespace-only diff; the actual source contains them once.
assertQuery("SELECT * FROM (" +
" SELECT regionkey FROM nation " +
" UNION ALL " +
" SELECT nationkey FROM nation" +
") t(a) " +
"ORDER BY a LIMIT 1",
" SELECT regionkey FROM nation " +
" UNION ALL " +
" SELECT nationkey FROM nation" +
") t(a) " +
"ORDER BY a LIMIT 1",
"SELECT 0");
}

Expand Down Expand Up @@ -7243,8 +7243,8 @@ public void testCorrelatedScalarSubqueriesWithScalarAggregation()

//count in subquery
assertQuery("SELECT * " +
"FROM (VALUES (0),( 1), (2), (7)) as v1(c1) " +
"WHERE v1.c1 > (SELECT count(c1) from (VALUES (0),( 1), (2)) as v2(c1) WHERE v1.c1 = v2.c1)",
"FROM (VALUES (0),( 1), (2), (7)) as v1(c1) " +
"WHERE v1.c1 > (SELECT count(c1) from (VALUES (0),( 1), (2)) as v2(c1) WHERE v1.c1 = v2.c1)",
"VALUES (2), (7)");
}

Expand Down Expand Up @@ -8921,6 +8921,16 @@ public void testAggregationPushedBelowOuterJoin()
"VALUES 24");
}

/**
 * Runs a grouped join query whose join condition combines two equi-clauses with an
 * additional non-trivial equality ({@code lineitem.orderkey + 1 = orders.orderkey + 1}),
 * while grouping only on columns of {@code orders}.
 */
@Test
public void testPartialAggregationPushDown()
{
    // pushed down aggregation needs to preserve symbols required by join filter and equi-condition
    String groupedJoinQuery = "" +
            " SELECT orders.custkey AS custkey, orders.orderstatus AS orderstatus " +
            " FROM orders JOIN lineitem ON lineitem.orderkey = orders.orderkey AND orders.orderkey = lineitem.partkey AND lineitem.orderkey + 1 = orders.orderkey + 1" +
            " GROUP BY orders.custkey, orders.orderstatus";

    assertQuery(groupedJoinQuery);
}

@Test
public void testDefaultDecimalLiteralSwitch()
throws Exception
Expand Down