[fix](Nereids) fix group concat (apache#33091)
Fix the failed case in regression_test/suites/query_p0/group_concat/test_group_concat.groovy:

select
group_concat( distinct b1, '?'), group_concat( distinct b3, '?')
from
table_group_concat
group by
b2

exception:

lowestCostPlans with physicalProperties(GATHER) doesn't exist in root group

The root cause is that the '?' literal is pushed down to a slot by NormalizeAggregate; AggregateStrategies then treats that slot as a distinct parameter and generates an invalid PhysicalHashAggregate, which is rejected by ChildOutputPropertyDeriver.

I fix this bug by avoiding pushing literals down to slots in NormalizeAggregate, and by forbidding the generation of a streaming aggregate node when the group by slots are empty.
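
To illustrate the first part of the fix, here is a minimal, self-contained sketch (using hypothetical Expr/Slot/Lit types rather than the actual Nereids classes): literal arguments of a distinct aggregate are excluded from the expressions pushed down below the aggregate, so the '?' separator of group_concat(distinct b1, '?') stays inside the aggregate call instead of becoming a projected slot that AggregateStrategies would mistake for a second distinct column.

import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

// Hypothetical stand-ins for Nereids expression types, only for this sketch.
public class PushDownLiteralSketch {
    interface Expr {}
    record Slot(String name) implements Expr {}
    record Lit(String value) implements Expr {}

    // Children of a distinct aggregate call, e.g. group_concat(distinct b1, '?').
    // Literals are filtered out, so only column references are candidates for
    // push-down below the aggregate; the literal stays as an argument.
    static Set<Expr> childrenToPushDown(List<Expr> distinctAggChildren) {
        return distinctAggChildren.stream()
                .filter(child -> !(child instanceof Lit))
                .collect(Collectors.toSet());
    }

    public static void main(String[] args) {
        List<Expr> children = List.of(new Slot("b1"), new Lit("?"));
        System.out.println(childrenToPushDown(children)); // prints [Slot[name=b1]]
    }
}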

(cherry picked from commit e7d6697)
924060929 committed May 24, 2024
1 parent e51f8ba commit 33c783f
Showing 9 changed files with 92 additions and 14 deletions.
9 changes: 7 additions & 2 deletions be/src/pipeline/pipeline_fragment_context.cpp
@@ -513,7 +513,12 @@ Status PipelineFragmentContext::_build_pipelines(ExecNode* node, PipelinePtr cur
auto* agg_node = dynamic_cast<vectorized::AggregationNode*>(node);
auto new_pipe = add_pipeline();
RETURN_IF_ERROR(_build_pipelines(node->child(0), new_pipe));
- if (agg_node->is_aggregate_evaluators_empty()) {
+ if (agg_node->is_probe_expr_ctxs_empty() && node->row_desc().num_slots() == 0) {
+     return Status::InternalError("Illegal aggregate node " +
+                                  std::to_string(agg_node->id()) +
+                                  ": group by and output is empty");
+ }
+ if (agg_node->is_aggregate_evaluators_empty() && !agg_node->is_probe_expr_ctxs_empty()) {
auto data_queue = std::make_shared<DataQueue>(1);
OperatorBuilderPtr pre_agg_sink =
std::make_shared<DistinctStreamingAggSinkOperatorBuilder>(node->id(), agg_node,
@@ -524,7 +529,7 @@ Status PipelineFragmentContext::_build_pipelines(ExecNode* node, PipelinePtr cur
std::make_shared<DistinctStreamingAggSourceOperatorBuilder>(
node->id(), agg_node, data_queue);
RETURN_IF_ERROR(cur_pipe->add_operator(pre_agg_source));
- } else if (agg_node->is_streaming_preagg()) {
+ } else if (agg_node->is_streaming_preagg() && !agg_node->is_probe_expr_ctxs_empty()) {
auto data_queue = std::make_shared<DataQueue>(1);
OperatorBuilderPtr pre_agg_sink = std::make_shared<StreamingAggSinkOperatorBuilder>(
node->id(), agg_node, data_queue);
5 changes: 5 additions & 0 deletions be/src/runtime/descriptors.h
@@ -502,10 +502,12 @@ class RowDescriptor {
_has_varlen_slots(desc._has_varlen_slots) {
_num_materialized_slots = 0;
_num_null_slots = 0;
+ _num_slots = 0;
std::vector<TupleDescriptor*>::const_iterator it = desc._tuple_desc_map.begin();
for (; it != desc._tuple_desc_map.end(); ++it) {
_num_materialized_slots += (*it)->num_materialized_slots();
_num_null_slots += (*it)->num_null_slots();
+ _num_slots += (*it)->slots().size();
}
_num_null_bytes = (_num_null_slots + 7) / 8;
}
@@ -528,6 +530,8 @@

int num_null_bytes() const { return _num_null_bytes; }

+ int num_slots() const { return _num_slots; }

static const int INVALID_IDX;

// Returns INVALID_IDX if id not part of this row.
@@ -582,6 +586,7 @@
int _num_materialized_slots;
int _num_null_slots;
int _num_null_bytes;
+ int _num_slots;
};

} // namespace doris
1 change: 1 addition & 0 deletions be/src/vec/exec/vaggregation_node.h
@@ -849,6 +849,7 @@ class AggregationNode : public ::doris::ExecNode {
Status pull(doris::RuntimeState* state, vectorized::Block* output_block, bool* eos) override;
Status sink(doris::RuntimeState* state, vectorized::Block* input_block, bool eos) override;
Status do_pre_agg(vectorized::Block* input_block, vectorized::Block* output_block);
+ bool is_probe_expr_ctxs_empty() const { return _probe_expr_ctxs.empty(); }
bool is_streaming_preagg() const { return _is_streaming_preagg; }
bool is_aggregate_evaluators_empty() const { return _aggregate_evaluators.empty(); }
void _make_nullable_output_key(Block* block);
22 changes: 19 additions & 3 deletions fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java
@@ -34,6 +34,7 @@

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
+ import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;

@@ -65,7 +66,7 @@ public class Group {

// Map of cost lower bounds
// Map required plan props to cost lower bound of corresponding plan
- private final Map<PhysicalProperties, Pair<Cost, GroupExpression>> lowestCostPlans = Maps.newHashMap();
+ private final Map<PhysicalProperties, Pair<Cost, GroupExpression>> lowestCostPlans = Maps.newLinkedHashMap();

private boolean isExplored = false;

@@ -213,6 +214,12 @@ public Optional<Pair<Cost, GroupExpression>> getLowestCostPlan(PhysicalPropertie
return Optional.ofNullable(lowestCostPlans.get(physicalProperties));
}

+ public Map<PhysicalProperties, Cost> getLowestCosts() {
+     return lowestCostPlans.entrySet()
+             .stream()
+             .collect(ImmutableMap.toImmutableMap(Entry::getKey, kv -> kv.getValue().first));
+ }

public GroupExpression getBestPlan(PhysicalProperties properties) {
if (lowestCostPlans.containsKey(properties)) {
return lowestCostPlans.get(properties).second;
@@ -451,9 +458,18 @@ public String toString() {
public String treeString() {
Function<Object, String> toString = obj -> {
if (obj instanceof Group) {
return "Group[" + ((Group) obj).groupId + "]";
Group group = (Group) obj;
Map<PhysicalProperties, Cost> lowestCosts = group.getLowestCosts();
return "Group[" + group.groupId + ", lowestCosts: " + lowestCosts + "]";
} else if (obj instanceof GroupExpression) {
- return ((GroupExpression) obj).getPlan().toString();
+ GroupExpression groupExpression = (GroupExpression) obj;
+ Map<PhysicalProperties, Pair<Cost, List<PhysicalProperties>>> lowestCostTable
+         = groupExpression.getLowestCostTable();
+ Map<PhysicalProperties, PhysicalProperties> requestPropertiesMap
+         = groupExpression.getRequestPropertiesMap();
+ Cost cost = groupExpression.getCost();
+ return groupExpression.getPlan().toString() + " [cost: " + cost + ", lowestCostTable: "
+         + lowestCostTable + ", requestPropertiesMap: " + requestPropertiesMap + "]";
} else if (obj instanceof Pair) {
// print logicalExpressions or physicalExpressions
// first is name, second is group expressions
@@ -35,6 +35,7 @@
import com.google.common.base.Joiner;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
+ import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;

@@ -319,6 +320,10 @@ public void setEstOutputRowCount(double estOutputRowCount) {
this.estOutputRowCount = estOutputRowCount;
}

+ public Map<PhysicalProperties, PhysicalProperties> getRequestPropertiesMap() {
+     return ImmutableMap.copyOf(requestPropertiesMap);
+ }

@Override
public String toString() {
DecimalFormat format = new DecimalFormat("#,###.##");
@@ -103,6 +103,9 @@ public Boolean visit(Plan plan, Void context) {

@Override
public Boolean visitPhysicalHashAggregate(PhysicalHashAggregate<? extends Plan> agg, Void context) {
+ if (agg.getGroupByExpressions().isEmpty() && agg.getOutputExpressions().isEmpty()) {
+     return false;
+ }
if (!agg.getAggregateParam().canBeBanned) {
return true;
}
@@ -82,11 +82,13 @@ public static PhysicalProperties createHash(
.map(SlotReference.class::cast)
.map(SlotReference::getExprId)
.collect(Collectors.toList());
- return createHash(partitionedSlots, shuffleType);
+ return partitionedSlots.isEmpty() ? PhysicalProperties.GATHER : createHash(partitionedSlots, shuffleType);
}

public static PhysicalProperties createHash(List<ExprId> orderedShuffledColumns, ShuffleType shuffleType) {
- return new PhysicalProperties(new DistributionSpecHash(orderedShuffledColumns, shuffleType));
+ return orderedShuffledColumns.isEmpty()
+         ? PhysicalProperties.GATHER
+         : new PhysicalProperties(new DistributionSpecHash(orderedShuffledColumns, shuffleType));
}

public static PhysicalProperties createHash(DistributionSpecHash distributionSpecHash) {
@@ -156,21 +156,30 @@ private LogicalPlan normalizeAgg(LogicalAggregate<Plan> aggregate, LogicalHaving
// split non-distinct agg child as two part
// TRUE part 1: need push down itself, if it contains subqury or window expression
// FALSE part 2: need push down its input slots, if it DOES NOT contain subqury or window expression
- Map<Boolean, Set<Expression>> categorizedNoDistinctAggsChildren = aggFuncs.stream()
+ Map<Boolean, ImmutableSet<Expression>> categorizedNoDistinctAggsChildren = aggFuncs.stream()
.filter(aggFunc -> !aggFunc.isDistinct())
.flatMap(agg -> agg.children().stream())
+         // should not push down literal under aggregate
+         // e.g. group_concat(distinct xxx, ','), the ',' literal should stay in the aggregate
+         .filter(arg -> !(arg instanceof Literal))
.collect(Collectors.groupingBy(
child -> child.containsType(SubqueryExpr.class, WindowExpression.class),
-                 Collectors.toSet()));
+                 ImmutableSet.toImmutableSet()));

// split distinct agg child as two parts
// TRUE part 1: need push down itself, if it is NOT SlotReference or Literal
// FALSE part 2: need push down its input slots, if it is SlotReference or Literal
- Map<Boolean, Set<Expression>> categorizedDistinctAggsChildren = aggFuncs.stream()
-         .filter(aggFunc -> aggFunc.isDistinct()).flatMap(agg -> agg.children().stream())
-         .collect(Collectors.groupingBy(
-                 child -> !(child instanceof SlotReference || child instanceof Literal),
-                 Collectors.toSet()));
+ Map<Boolean, ImmutableSet<Expression>> categorizedDistinctAggsChildren = aggFuncs.stream()
+         .filter(aggFunc -> aggFunc.isDistinct())
+         .flatMap(agg -> agg.children().stream())
+         // should not push down literal under aggregate
+         // e.g. group_concat(distinct xxx, ','), the ',' literal should stay in the aggregate
+         .filter(arg -> !(arg instanceof Literal))
+         .collect(
+                 Collectors.groupingBy(
+                         child -> !(child instanceof SlotReference),
+                         ImmutableSet.toImmutableSet())
+         );

Set<Expression> needPushSelf = Sets.union(
categorizedNoDistinctAggsChildren.getOrDefault(true, new HashSet<>()),
@@ -70,6 +70,7 @@
import org.apache.doris.nereids.trees.plans.physical.PhysicalOlapScan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalStorageLayerAggregate;
import org.apache.doris.nereids.trees.plans.physical.PhysicalStorageLayerAggregate.PushDownAggOp;
+ import org.apache.doris.nereids.types.TinyIntType;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.TypeCoercionUtils;
import org.apache.doris.qe.ConnectContext;
@@ -1292,6 +1293,15 @@ private List<PhysicalHashAggregate<? extends Plan>> threePhaseAggregateWithDisti
.build();

List<Expression> localAggGroupBy = ImmutableList.copyOf(localAggGroupBySet);
+ boolean isGroupByEmptySelectEmpty = localAggGroupBy.isEmpty() && localAggOutput.isEmpty();

+ // it is not recommended to generate an aggregate node with empty group by and empty output,
+ // so add a null literal to the group by slots and the output
+ if (isGroupByEmptySelectEmpty) {
+     localAggGroupBy = ImmutableList.of(new NullLiteral(TinyIntType.INSTANCE));
+     localAggOutput = ImmutableList.of(new Alias(new NullLiteral(TinyIntType.INSTANCE)));
+ }

boolean maybeUsingStreamAgg = maybeUsingStreamAgg(connectContext, localAggGroupBy);
List<Expression> partitionExpressions = getHashAggregatePartitionExpressions(logicalAgg);
RequireProperties requireAny = RequireProperties.of(PhysicalProperties.ANY);
@@ -1317,6 +1327,12 @@
.addAll(nonDistinctAggFunctionToAliasPhase2.values())
.build();

+ // it is not recommended to generate an aggregate node with empty group by and empty output,
+ // so add a null literal to the group by slots and the output
+ if (isGroupByEmptySelectEmpty) {
+     globalAggOutput = ImmutableList.of(new Alias(new NullLiteral(TinyIntType.INSTANCE)));
+ }

RequireProperties requireGather = RequireProperties.of(PhysicalProperties.GATHER);
PhysicalHashAggregate<Plan> anyLocalGatherGlobalAgg = new PhysicalHashAggregate<>(
localAggGroupBy, globalAggOutput, Optional.of(partitionExpressions),
@@ -1680,6 +1696,16 @@ private List<PhysicalHashAggregate<? extends Plan>> fourPhaseAggregateWithDistin
boolean maybeUsingStreamAgg = maybeUsingStreamAgg(connectContext, localAggGroupBy);
List<Expression> partitionExpressions = getHashAggregatePartitionExpressions(logicalAgg);
RequireProperties requireAny = RequireProperties.of(PhysicalProperties.ANY);

+ boolean isGroupByEmptySelectEmpty = localAggGroupBy.isEmpty() && localAggOutput.isEmpty();

+ // it is not recommended to generate an aggregate node with empty group by and empty output,
+ // so add a null literal to the group by slots and the output
+ if (isGroupByEmptySelectEmpty) {
+     localAggGroupBy = ImmutableList.of(new NullLiteral(TinyIntType.INSTANCE));
+     localAggOutput = ImmutableList.of(new Alias(new NullLiteral(TinyIntType.INSTANCE)));
+ }

PhysicalHashAggregate<Plan> anyLocalAgg = new PhysicalHashAggregate<>(localAggGroupBy,
localAggOutput, Optional.of(partitionExpressions), inputToBufferParam,
maybeUsingStreamAgg, Optional.empty(), logicalAgg.getLogicalProperties(),
@@ -1702,6 +1728,12 @@
.addAll(nonDistinctAggFunctionToAliasPhase2.values())
.build();

+ // it is not recommended to generate an aggregate node with empty group by and empty output,
+ // so add a null literal to the group by slots and the output
+ if (isGroupByEmptySelectEmpty) {
+     globalAggOutput = ImmutableList.of(new Alias(new NullLiteral(TinyIntType.INSTANCE)));
+ }

RequireProperties requireGather = RequireProperties.of(PhysicalProperties.GATHER);

RequireProperties requireDistinctHash = RequireProperties.of(
