Skip to content

Commit

Permalink
[fix](Nereids) fix group_concat(distinct) failed (apache#31873)
Browse files Browse the repository at this point in the history
(cherry picked from commit 7ed0263)
  • Loading branch information
924060929 committed Mar 7, 2024
1 parent b2ccaab commit 4d26241
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1108,7 +1108,7 @@ inputToBufferParam, maybeUsingStreamAgg(connectContext, logicalAgg),
* <p>
* single node aggregate:
* <p>
* PhysicalHashAggregate(groupBy=[name], output=[name, count(distinct(id)], mode=BUFFER_TO_RESULT)
* PhysicalHashAggregate(groupBy=[name], output=[name, count(distinct(id))], mode=BUFFER_TO_RESULT)
* |
* PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=INPUT_TO_BUFFER)
* |
Expand All @@ -1118,12 +1118,10 @@ inputToBufferParam, maybeUsingStreamAgg(connectContext, logicalAgg),
* <p>
* distribute node aggregate:
* <p>
* PhysicalHashAggregate(groupBy=[name], output=[name, count(distinct(id)], mode=BUFFER_TO_RESULT)
* PhysicalHashAggregate(groupBy=[name], output=[name, count(distinct(id))], mode=BUFFER_TO_RESULT)
* |
* PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=INPUT_TO_BUFFER)
* |
* PhysicalDistribute(distributionSpec=HASH(name))
* |
* LogicalOlapScan(table=tbl, **if distribute by name**)
*
*/
Expand Down Expand Up @@ -1175,8 +1173,9 @@ private List<PhysicalHashAggregate<? extends Plan>> twoPhaseAggregateWithDistinc
if (outputChild instanceof AggregateFunction) {
AggregateFunction aggregateFunction = (AggregateFunction) outputChild;
if (aggregateFunction.isDistinct()) {
Set<Expression> aggChild = Sets.newHashSet(aggregateFunction.children());
Preconditions.checkArgument(aggChild.size() == 1,
Set<Expression> aggChild = Sets.newLinkedHashSet(aggregateFunction.children());
Preconditions.checkArgument(aggChild.size() == 1
|| aggregateFunction.getDistinctArguments().size() == 1,
"cannot process more than one child in aggregate distinct function: "
+ aggregateFunction);
AggregateFunction nonDistinct = aggregateFunction
Expand Down Expand Up @@ -1236,7 +1235,7 @@ private List<PhysicalHashAggregate<? extends Plan>> twoPhaseAggregateWithDistinc
* after:
* single node aggregate:
* <p>
* PhysicalHashAggregate(groupBy=[name], output=[name, count(distinct(id)], mode=BUFFER_TO_RESULT)
* PhysicalHashAggregate(groupBy=[name], output=[name, count(distinct(id))], mode=BUFFER_TO_RESULT)
* |
* PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=BUFFER_TO_BUFFER)
* |
Expand All @@ -1248,7 +1247,7 @@ private List<PhysicalHashAggregate<? extends Plan>> twoPhaseAggregateWithDistinc
* <p>
* distribute node aggregate:
* <p>
* PhysicalHashAggregate(groupBy=[name], output=[name, count(distinct(id)], mode=BUFFER_TO_RESULT)
* PhysicalHashAggregate(groupBy=[name], output=[name, count(distinct(id))], mode=BUFFER_TO_RESULT)
* |
* PhysicalHashAggregate(groupBy=[name, id], output=[name, id], mode=BUFFER_TO_BUFFER)
* |
Expand Down Expand Up @@ -1331,14 +1330,14 @@ private List<PhysicalHashAggregate<? extends Plan>> threePhaseAggregateWithDisti
if (expr instanceof AggregateFunction) {
AggregateFunction aggregateFunction = (AggregateFunction) expr;
if (aggregateFunction.isDistinct()) {
Set<Expression> aggChild = Sets.newHashSet(aggregateFunction.children());
Preconditions.checkArgument(aggChild.size() == 1,
Set<Expression> aggChild = Sets.newLinkedHashSet(aggregateFunction.children());
Preconditions.checkArgument(aggChild.size() == 1
|| aggregateFunction.getDistinctArguments().size() == 1,
"cannot process more than one child in aggregate distinct function: "
+ aggregateFunction);
AggregateFunction nonDistinct = aggregateFunction
.withDistinctAndChildren(false, ImmutableList.copyOf(aggChild));
return new AggregateExpression(nonDistinct,
bufferToResultParam, aggregateFunction.child(0));
return new AggregateExpression(nonDistinct, bufferToResultParam, aggregateFunction);
} else {
Alias alias = nonDistinctAggFunctionToAliasPhase2.get(expr);
return new AggregateExpression(aggregateFunction,
Expand Down Expand Up @@ -1727,8 +1726,9 @@ private List<PhysicalHashAggregate<? extends Plan>> fourPhaseAggregateWithDistin
if (expr instanceof AggregateFunction) {
AggregateFunction aggregateFunction = (AggregateFunction) expr;
if (aggregateFunction.isDistinct()) {
Set<Expression> aggChild = Sets.newHashSet(aggregateFunction.children());
Preconditions.checkArgument(aggChild.size() == 1,
Set<Expression> aggChild = Sets.newLinkedHashSet(aggregateFunction.children());
Preconditions.checkArgument(aggChild.size() == 1
|| aggregateFunction.getDistinctArguments().size() == 1,
"cannot process more than one child in aggregate distinct function: "
+ aggregateFunction);
AggregateFunction nonDistinct = aggregateFunction
Expand Down Expand Up @@ -1767,8 +1767,9 @@ private List<PhysicalHashAggregate<? extends Plan>> fourPhaseAggregateWithDistin
if (expr instanceof AggregateFunction) {
AggregateFunction aggregateFunction = (AggregateFunction) expr;
if (aggregateFunction.isDistinct()) {
Set<Expression> aggChild = Sets.newHashSet(aggregateFunction.children());
Preconditions.checkArgument(aggChild.size() == 1,
Set<Expression> aggChild = Sets.newLinkedHashSet(aggregateFunction.children());
Preconditions.checkArgument(aggChild.size() == 1
|| aggregateFunction.getDistinctArguments().size() == 1,
"cannot process more than one child in aggregate distinct function: "
+ aggregateFunction);
AggregateFunction nonDistinct = aggregateFunction
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,4 +124,7 @@ public String toString() {
return getName() + "(" + (distinct ? "DISTINCT " : "") + args + ")";
}

public List<Expression> getDistinctArguments() {
return distinct ? getArguments() : ImmutableList.of();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,15 @@ public boolean nullable() {
.anyMatch(expression -> !(expression instanceof OrderExpression) && expression.nullable());
}

@Override
public List<Expression> getDistinctArguments() {
if (distinct) {
return ImmutableList.of(getArgument(0));
} else {
return ImmutableList.of();
}
}

@Override
public void checkLegalityBeforeTypeCoercion() {
DataType typeOrArg0 = getArgumentType(0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ default Set<AggregateFunction> getAggregateFunctions() {
default Set<Expression> getDistinctArguments() {
return getAggregateFunctions().stream()
.filter(AggregateFunction::isDistinct)
.flatMap(aggregateExpression -> aggregateExpression.getArguments().stream())
.flatMap(aggregateFunction -> aggregateFunction.getDistinctArguments().stream())
.collect(ImmutableSet.toImmutableSet());
}
}
8 changes: 8 additions & 0 deletions regression-test/data/nereids_syntax_p0/group_concat.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !group_by_distinct --
1 \N
2 a
3 b
4 c
5 \N

45 changes: 43 additions & 2 deletions regression-test/suites/nereids_syntax_p0/group_concat.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,47 @@ suite("group_concat") {
sql "select group_concat(cast(number as string), NULL) from numbers('number'='10')"
result([[null]])
}



def testGroupByDistinct = {
sql "drop table if exists test_group_concat_distinct_tbl1"
sql """create table test_group_concat_distinct_tbl1(
tbl1_id1 int
) distributed by hash(tbl1_id1)
properties('replication_num'='1')
"""

sql "insert into test_group_concat_distinct_tbl1 values(1), (2), (3), (4), (5)"


sql "drop table if exists test_group_concat_distinct_tbl2"
sql """create table test_group_concat_distinct_tbl2(
tbl2_id1 int,
tbl2_id2 int,
) distributed by hash(tbl2_id1)
properties('replication_num'='1')
"""
sql "insert into test_group_concat_distinct_tbl2 values(1, 11), (2, 22), (3, 33), (4, 44)"


sql "drop table if exists test_group_concat_distinct_tbl3"
sql """create table test_group_concat_distinct_tbl3(
tbl3_id2 int,
tbl3_name varchar(255)
) distributed by hash(tbl3_id2)
properties('replication_num'='1')
"""
sql "insert into test_group_concat_distinct_tbl3 values(22, 'a'), (33, 'b'), (44, 'c')"

sql "sync"

order_qt_group_by_distinct """
SELECT
tbl1.tbl1_id1,
group_concat(DISTINCT tbl3.tbl3_name, ',') AS `names`
FROM test_group_concat_distinct_tbl1 tbl1
LEFT OUTER JOIN test_group_concat_distinct_tbl2 tbl2 ON tbl2.tbl2_id1 = tbl1.tbl1_id1
LEFT OUTER JOIN test_group_concat_distinct_tbl3 tbl3 ON tbl3.tbl3_id2 = tbl2.tbl2_id2
GROUP BY tbl1.tbl1_id1
"""
}()
}

0 comments on commit 4d26241

Please sign in to comment.