Skip to content

Commit

Permalink
[enhancement](Nereids) support 4 phases distinct aggregate with full …
Browse files Browse the repository at this point in the history
…distribution (apache#35871)

The original implementation of the 4-phase distinct aggregate only supports the pattern that contains no `group by` and only one distinct aggregate function.

for example:
```sql
select count(distinct sex), sum(age)
from student
```

This PR complements the 4-phase distinct aggregate with full distribution, to avoid data skew on the `group by` keys.

for example
```sql
select sex, sum(distinct age)
from student
group by sex;
```
The `sex` column contains only two distinct values, `male` and `female`, while the table stores millions of rows.
Shuffling by `sex` alone causes data skew, and many instances end up processing no rows at all.

The 4-phase aggregate first shuffles by `sex, age` to deduplicate rows, so more instances can perform the distinct step in parallel. The plan shape looks like this:
```

PhysicalAggregate(groupBy=[sex], output=[sex, sum(partial_sum(age))], mode=BUFFER_TO_RESULT)
                                        |
                         PhysicalDistribute(columns=[sex])
                                        |
PhysicalAggregate(groupBy=[sex], output=[sex, partial_sum(age)], mode=INPUT_TO_BUFFER)
                                        |
    PhysicalAggregate(groupBy=[sex, age], output=[sex, age], mode=BUFFER_TO_BUFFER)
                                        |
                         PhysicalDistribute(columns=[sex, age])   # more columns to shuffle avoid data skew
                                        |
PhysicalAggregate(groupBy=[sex, age], output=[sex, age], mode=INPUT_TO_BUFFER)
                                        |
                          PhysicalOlapScan(name=student)
```

(cherry picked from commit 03f1cbd)
  • Loading branch information
924060929 committed Jun 7, 2024
1 parent f751ca4 commit cb1c156
Show file tree
Hide file tree
Showing 7 changed files with 144 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -113,16 +113,6 @@ public Boolean visitPhysicalHashAggregate(PhysicalHashAggregate<? extends Plan>
// this means one stage gather agg, usually bad pattern
return false;
}
// forbid three or four stage distinct agg inter by distribute
if (agg.getAggMode() == AggMode.BUFFER_TO_BUFFER && children.get(0).getPlan() instanceof PhysicalDistribute) {
// if distinct without group by key, we prefer three or four stage distinct agg
// because the second phase of multi-distinct only have one instance, and it is slow generally.
if (agg.getGroupByExpressions().size() == 1
&& agg.getOutputExpressions().size() == 1) {
return true;
}
return false;
}

// forbid TWO_PHASE_AGGREGATE_WITH_DISTINCT after shuffle
// TODO: this is forbid good plan after cte reuse by mistake
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,7 @@ public enum RuleType {
TWO_PHASE_AGGREGATE_WITH_MULTI_DISTINCT(RuleTypeClass.IMPLEMENTATION),
THREE_PHASE_AGGREGATE_WITH_DISTINCT(RuleTypeClass.IMPLEMENTATION),
FOUR_PHASE_AGGREGATE_WITH_DISTINCT(RuleTypeClass.IMPLEMENTATION),
FOUR_PHASE_AGGREGATE_WITH_DISTINCT_WITH_FULL_DISTRIBUTE(RuleTypeClass.IMPLEMENTATION),
LOGICAL_UNION_TO_PHYSICAL_UNION(RuleTypeClass.IMPLEMENTATION),
LOGICAL_EXCEPT_TO_PHYSICAL_EXCEPT(RuleTypeClass.IMPLEMENTATION),
LOGICAL_INTERSECT_TO_PHYSICAL_INTERSECT(RuleTypeClass.IMPLEMENTATION),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
import org.apache.doris.nereids.util.TypeCoercionUtils;
import org.apache.doris.qe.ConnectContext;

import com.google.common.base.Function;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
Expand Down Expand Up @@ -293,15 +294,89 @@ && couldConvertToMulti(agg))
// .thenApplyMulti(ctx -> twoPhaseAggregateWithDistinct(ctx.root, ctx.connectContext))
// ),
RuleType.THREE_PHASE_AGGREGATE_WITH_DISTINCT.build(
basePattern
.when(agg -> agg.getDistinctArguments().size() == 1)
.thenApplyMulti(ctx -> threePhaseAggregateWithDistinct(ctx.root, ctx.connectContext))
basePattern
.when(agg -> agg.getDistinctArguments().size() == 1)
.thenApplyMulti(ctx -> threePhaseAggregateWithDistinct(ctx.root, ctx.connectContext))
),
/*
* sql:
* select count(distinct name), sum(age) from student;
* <p>
* 4 phase plan
* DISTINCT_GLOBAL(BUFFER_TO_RESULT, groupBy(),
* output[count(partial_count(name)), sum(partial_sum(partial_sum(age)))],
* GATHER)
* +--DISTINCT_LOCAL(INPUT_TO_BUFFER, groupBy(),
* output(partial_count(name), partial_sum(partial_sum(age))),
* hash distribute by name)
* +--GLOBAL(BUFFER_TO_BUFFER, groupBy(name),
* output(name, partial_sum(age)),
* hash_distribute by name)
* +--LOCAL(INPUT_TO_BUFFER, groupBy(name), output(name, partial_sum(age)))
* +--scan(name, age)
*/
RuleType.FOUR_PHASE_AGGREGATE_WITH_DISTINCT.build(
basePattern
.when(agg -> agg.getDistinctArguments().size() == 1)
.when(agg -> agg.getGroupByExpressions().isEmpty())
.thenApplyMulti(ctx -> fourPhaseAggregateWithDistinct(ctx.root, ctx.connectContext))
basePattern
.when(agg -> agg.getDistinctArguments().size() == 1)
.when(agg -> agg.getGroupByExpressions().isEmpty())
.thenApplyMulti(ctx -> {
Function<List<Expression>, RequireProperties> secondPhaseRequireDistinctHash =
groupByAndDistinct -> RequireProperties.of(
PhysicalProperties.createHash(
ctx.root.getDistinctArguments(), ShuffleType.REQUIRE
)
);
Function<LogicalAggregate<? extends Plan>, RequireProperties> fourPhaseRequireGather =
agg -> RequireProperties.of(PhysicalProperties.GATHER);
return fourPhaseAggregateWithDistinct(
ctx.root, ctx.connectContext,
secondPhaseRequireDistinctHash, fourPhaseRequireGather
);
})
),
/*
* sql:
* select age, count(distinct name) from student group by age;
* <p>
* 4 phase plan
* DISTINCT_GLOBAL(BUFFER_TO_RESULT, groupBy(age),
* output[age, sum(partial_count(name))],
* hash distribute by age)
* +--DISTINCT_LOCAL(INPUT_TO_BUFFER, groupBy(age),
* output(age, partial_count(name)),
* hash distribute by age, name)
* +--GLOBAL(BUFFER_TO_BUFFER, groupBy(age, name),
* output(age, name),
* hash_distribute by age, name)
* +--LOCAL(INPUT_TO_BUFFER, groupBy(age, name), output(age, name))
* +--scan(age, name)
*/
RuleType.FOUR_PHASE_AGGREGATE_WITH_DISTINCT_WITH_FULL_DISTRIBUTE.build(
basePattern
.when(agg -> agg.everyDistinctArgumentNumIsOne() && !agg.getGroupByExpressions().isEmpty())
.when(agg ->
ImmutableSet.builder()
.addAll(agg.getGroupByExpressions())
.addAll(agg.getDistinctArguments())
.build().size() > agg.getGroupByExpressions().size()
)
.thenApplyMulti(ctx -> {
Function<List<Expression>, RequireProperties> secondPhaseRequireGroupByAndDistinctHash =
groupByAndDistinct -> RequireProperties.of(
PhysicalProperties.createHash(groupByAndDistinct, ShuffleType.REQUIRE)
);

Function<LogicalAggregate<? extends Plan>, RequireProperties> fourPhaseRequireGroupByHash =
agg -> RequireProperties.of(
PhysicalProperties.createHash(
agg.getGroupByExpressions(), ShuffleType.REQUIRE
)
);
return fourPhaseAggregateWithDistinct(
ctx.root, ctx.connectContext,
secondPhaseRequireGroupByAndDistinctHash, fourPhaseRequireGroupByHash
);
})
)
);
}
Expand Down Expand Up @@ -1649,19 +1724,10 @@ private boolean enablePushDownNoGroupAgg() {
return connectContext == null || connectContext.getSessionVariable().enablePushDownNoGroupAgg();
}

/**
* sql:
* select count(distinct name), sum(age) from student;
* <p>
* 4 phase plan
* DISTINCT_GLOBAL, BUFFER_TO_RESULT groupBy(), output[count(name), sum(age#5)], [GATHER]
* +--DISTINCT_LOCAL, INPUT_TO_BUFFER, groupBy()), output(count(name), partial_sum(age)), hash distribute by name
* +--GLOBAL, BUFFER_TO_BUFFER, groupBy(name), output(name, partial_sum(age)), hash_distribute by name
* +--LOCAL, INPUT_TO_BUFFER, groupBy(name), output(name, partial_sum(age))
* +--scan(name, age)
*/
private List<PhysicalHashAggregate<? extends Plan>> fourPhaseAggregateWithDistinct(
LogicalAggregate<? extends Plan> logicalAgg, ConnectContext connectContext) {
LogicalAggregate<? extends Plan> logicalAgg, ConnectContext connectContext,
Function<List<Expression>, RequireProperties> secondPhaseRequireSupplier,
Function<LogicalAggregate<? extends Plan>, RequireProperties> fourPhaseRequireSupplier) {
boolean couldBanned = couldConvertToMulti(logicalAgg);

Set<AggregateFunction> aggregateFunctions = logicalAgg.getAggregateFunctions();
Expand Down Expand Up @@ -1734,16 +1800,13 @@ private List<PhysicalHashAggregate<? extends Plan>> fourPhaseAggregateWithDistin
globalAggOutput = ImmutableList.of(new Alias(new NullLiteral(TinyIntType.INSTANCE)));
}

RequireProperties requireGather = RequireProperties.of(PhysicalProperties.GATHER);

RequireProperties requireDistinctHash = RequireProperties.of(
PhysicalProperties.createHash(logicalAgg.getDistinctArguments(), ShuffleType.REQUIRE));
RequireProperties secondPhaseRequire = secondPhaseRequireSupplier.apply(localAggGroupBy);

//phase 2
PhysicalHashAggregate<? extends Plan> anyLocalHashGlobalAgg = new PhysicalHashAggregate<>(
localAggGroupBy, globalAggOutput, Optional.of(ImmutableList.copyOf(logicalAgg.getDistinctArguments())),
bufferToBufferParam, false, logicalAgg.getLogicalProperties(),
requireDistinctHash, anyLocalAgg);
secondPhaseRequire, anyLocalAgg);

// phase 3
AggregateParam distinctLocalParam = new AggregateParam(
Expand Down Expand Up @@ -1787,7 +1850,7 @@ private List<PhysicalHashAggregate<? extends Plan>> fourPhaseAggregateWithDistin
PhysicalHashAggregate<? extends Plan> distinctLocal = new PhysicalHashAggregate<>(
logicalAgg.getGroupByExpressions(), localDistinctOutput, Optional.empty(),
distinctLocalParam, false, logicalAgg.getLogicalProperties(),
requireDistinctHash, anyLocalHashGlobalAgg);
secondPhaseRequire, anyLocalHashGlobalAgg);

//phase 4
AggregateParam distinctGlobalParam = new AggregateParam(
Expand All @@ -1801,7 +1864,7 @@ private List<PhysicalHashAggregate<? extends Plan>> fourPhaseAggregateWithDistin
if (aggregateFunction.isDistinct()) {
Set<Expression> aggChild = Sets.newLinkedHashSet(aggregateFunction.children());
Preconditions.checkArgument(aggChild.size() == 1
|| aggregateFunction.getDistinctArguments().size() == 1,
|| aggregateFunction.getDistinctArguments().size() == 1,
"cannot process more than one child in aggregate distinct function: "
+ aggregateFunction);
AggregateFunction nonDistinct = aggregateFunction
Expand All @@ -1821,10 +1884,12 @@ private List<PhysicalHashAggregate<? extends Plan>> fourPhaseAggregateWithDistin
});
globalDistinctOutput.add(outputExprPhase4);
}

RequireProperties fourPhaseRequire = fourPhaseRequireSupplier.apply(logicalAgg);
PhysicalHashAggregate<? extends Plan> distinctGlobal = new PhysicalHashAggregate<>(
logicalAgg.getGroupByExpressions(), globalDistinctOutput, Optional.empty(),
distinctGlobalParam, false, logicalAgg.getLogicalProperties(),
requireGather, distinctLocal);
fourPhaseRequire, distinctLocal);

return ImmutableList.<PhysicalHashAggregate<? extends Plan>>builder()
.add(distinctGlobal)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

import java.util.List;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;

/**
* Common interface for logical/physical Aggregate.
Expand Down Expand Up @@ -68,4 +69,25 @@ default Set<Expression> getDistinctArguments() {
}
return distinctArguments.build();
}

/**
 * Checks whether this aggregate has at least one distinct aggregate function
 * and every distinct aggregate function has exactly one distinct argument.
 *
 * <p>Each output expression is searched with {@code anyMatch} for an
 * {@link AggregateFunction} that is distinct and whose
 * {@code getDistinctArguments()} size is not one.
 *
 * @return {@code false} as soon as a distinct aggregate function with a distinct
 *         argument count other than one is found; otherwise {@code true} only if
 *         at least one distinct aggregate function was seen
 */
default boolean everyDistinctArgumentNumIsOne() {
    // The lambda below can only capture effectively final locals, so a mutable
    // holder records whether any distinct aggregate function was encountered.
    AtomicBoolean sawDistinctFunction = new AtomicBoolean(false);
    for (NamedExpression output : getOutputExpressions()) {
        boolean hasViolation = output.anyMatch(node -> {
            if (!(node instanceof AggregateFunction)) {
                return false;
            }
            AggregateFunction function = (AggregateFunction) node;
            if (!function.isDistinct()) {
                return false;
            }
            sawDistinctFunction.set(true);
            return function.getDistinctArguments().size() != 1;
        });
        if (hasViolation) {
            return false;
        }
    }
    return sawDistinctFunction.get();
}
}
4 changes: 4 additions & 0 deletions regression-test/data/nereids_p0/aggregate/aggregate.out
Original file line number Diff line number Diff line change
Expand Up @@ -685,3 +685,7 @@ TESTING AGAIN

-- !having_with_limit --
7 -32767.0

-- !four_phase_full_distribute --
hello 1 1
world 1 1
22 changes: 22 additions & 0 deletions regression-test/suites/nereids_p0/aggregate/aggregate.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -347,4 +347,26 @@ suite("aggregate") {
sql "insert into table_10_undef_partitions2_keys3_properties4_distributed_by5(pk,col_bigint_undef_signed,col_varchar_10__undef_signed,col_varchar_64__undef_signed) values (0,111,'from','t'),(1,null,'h','out'),(2,3814,'get','q'),(3,5166561111626303305,'s','right'),(4,2688963514917402600,'b','hey'),(5,-5065987944147755706,'p','mean'),(6,31061,'v','d'),(7,122,'the','t'),(8,-2882446,'going','a'),(9,-43,'y','a');"

sql "SELECT MIN( `pk` ) FROM table_10_undef_partitions2_keys3_properties4_distributed_by5 WHERE ( col_varchar_64__undef_signed LIKE CONCAT ('come' , '%' ) OR col_varchar_10__undef_signed IN ( 'could' , 'was' , 'that' ) ) OR ( `pk` IS NULL OR ( `pk` <> 186 ) ) AND ( `pk` IS NOT NULL OR `pk` BETWEEN 255 AND -99 + 8 ) AND ( ( `pk` != 6 ) OR `pk` IS NULL );"

sql "drop table if exists test_four_phase_full_distribute"
sql """CREATE TABLE `test_four_phase_full_distribute` (
`id` INT NULL,
`age` INT NULL,
`name` VARCHAR(65533) NULL
) ENGINE=OLAP
DUPLICATE KEY(`id`)
COMMENT 'OLAP'
DISTRIBUTED BY HASH(`id`) BUCKETS 10
PROPERTIES (
"replication_allocation" = "tag.location.default: 1"
);"""

sql "insert into test_four_phase_full_distribute values(1, 21, 'hello'), (2, 22, 'world')"
sql " sync "
order_qt_four_phase_full_distribute """select
/*+SET_VAR(disable_nereids_rules='TWO_PHASE_AGGREGATE_SINGLE_DISTINCT_TO_MULTI,TWO_PHASE_AGGREGATE_WITH_MULTI_DISTINCT,THREE_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI,THREE_PHASE_AGGREGATE_WITH_DISTINCT,FOUR_PHASE_AGGREGATE_WITH_DISTINCT')*/
name, count(distinct name), count(distinct age)
from test_four_phase_full_distribute
group by name
"""
}
4 changes: 3 additions & 1 deletion regression-test/suites/nereids_syntax_p0/agg_4_phase.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,6 @@ suite("agg_4_phase") {
qt_4phase (test_sql)

sql """select GROUP_CONCAT(distinct name, " ") from agg_4_phase_tbl;"""
}

sql """select /*+SET_VAR(disable_nereids_rules='TWO_PHASE_AGGREGATE_SINGLE_DISTINCT_TO_MULTI,THREE_PHASE_AGGREGATE_WITH_DISTINCT,FOUR_PHASE_AGGREGATE_WITH_DISTINCT')*/ GROUP_CONCAT(distinct name, " ") from agg_4_phase_tbl group by gender;"""
}

0 comments on commit cb1c156

Please sign in to comment.