Skip to content

Commit

Permalink
[Enhancement] Add count_distinct_implementation/enable_count_distinct…
Browse files Browse the repository at this point in the history
…_rewrite_by_hll_bitmap to control count distinct's implmentation (backport #52293) (#52334)

Signed-off-by: shuming.li <[email protected]>
Co-authored-by: shuming.li <[email protected]>
  • Loading branch information
mergify[bot] and LiShuMing authored Nov 11, 2024
1 parent 067d476 commit 20c61dd
Show file tree
Hide file tree
Showing 18 changed files with 588 additions and 33 deletions.
29 changes: 29 additions & 0 deletions fe/fe-core/src/main/java/com/starrocks/qe/SessionVariable.java
Original file line number Diff line number Diff line change
Expand Up @@ -710,6 +710,10 @@ public static MaterializedViewRewriteMode parse(String str) {
// binary, json, compact
public static final String THRIFT_PLAN_PROTOCOL = "thrift_plan_protocol";

public static final String COUNT_DISTINCT_IMPLEMENTATION = "count_distinct_implementation";

public static final String ENABLE_COUNT_DISTINCT_REWRITE_BY_HLL_BITMAP = "enable_count_distinct_rewrite_by_hll_bitmap";

// 0 means disable interleaving, positive value sets the group size, but adaptively enable interleaving,
// negative value means force interleaving under the group size of abs(interleaving_group_size)
public static final String INTERLEAVING_GROUP_SIZE = "interleaving_group_size";
Expand Down Expand Up @@ -1525,6 +1529,15 @@ public static MaterializedViewRewriteMode parse(String str) {
@VarAttr(name = DISABLE_GENERATED_COLUMN_REWRITE, flag = VariableMgr.INVISIBLE)
private boolean disableGeneratedColumnRewrite = false;

@VarAttr(name = COUNT_DISTINCT_IMPLEMENTATION)
private String countDistinctImplementation = "default";

// By default, we always use the created mv's bitmap/hll to rewrite count distinct, but result is not
// exactly matched with the original result.
// If we want to get the exactly matched result, we can disable this.
@VarAttr(name = ENABLE_COUNT_DISTINCT_REWRITE_BY_HLL_BITMAP)
private boolean enableCountDistinctRewriteByHllBitmap = true;

public int getCboPruneJsonSubfieldDepth() {
return cboPruneJsonSubfieldDepth;
}
Expand Down Expand Up @@ -4046,6 +4059,22 @@ public boolean isDisableGeneratedColumnRewrite() {
return disableGeneratedColumnRewrite;
}

public void setCountDistinctImplementation(String countDistinctImplementation) {
this.countDistinctImplementation = countDistinctImplementation;
}

public SessionVariableConstants.CountDistinctImplMode getCountDistinctImplementation() {
return SessionVariableConstants.CountDistinctImplMode.parse(countDistinctImplementation);
}

public boolean isEnableCountDistinctRewriteByHllBitmap() {
return enableCountDistinctRewriteByHllBitmap;
}

public void setEnableCountDistinctRewriteByHllBitmap(boolean enableCountDistinctRewriteByHllBitmap) {
this.enableCountDistinctRewriteByHllBitmap = enableCountDistinctRewriteByHllBitmap;
}

// Serialize to thrift object
// used for rest api
public TQueryOptions toThrift() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

package com.starrocks.qe;

import org.apache.commons.lang3.EnumUtils;

public class SessionVariableConstants {

private SessionVariableConstants() {}
Expand Down Expand Up @@ -75,4 +77,15 @@ public enum AggregationStage {
THREE_STAGE,
FOUR_STAGE
}

// default, ndv, rewrite_by_hll_bitmap
public enum CountDistinctImplMode {
DEFAULT, // default, keeps the original count distinct implementation
NDV, // ndv, uses HyperLogLog to estimate the count distinct
MULTI_COUNT_DISTINCT;
public static String MODE_DEFAULT = DEFAULT.toString();
public static CountDistinctImplMode parse(String str) {
return EnumUtils.getEnumIgnoreCase(CountDistinctImplMode.class, str);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
import com.starrocks.privilege.RolePrivilegeCollectionV2;
import com.starrocks.qe.ConnectContext;
import com.starrocks.qe.SessionVariable;
import com.starrocks.qe.SessionVariableConstants;
import com.starrocks.qe.SqlModeHelper;
import com.starrocks.server.GlobalStateMgr;
import com.starrocks.server.RunMode;
Expand Down Expand Up @@ -1024,7 +1025,7 @@ public Void visitFunctionCall(FunctionCallExpr node, Scope scope) {
node.setNondeterministicId(exprId);
}

Function fn;
Function fn = null;
String fnName = node.getFnName().getFunction();

// throw exception direct
Expand Down Expand Up @@ -1194,10 +1195,32 @@ public Void visitFunctionCall(FunctionCallExpr node, Scope scope) {
}
}
}

} else {
fn = Expr.getBuiltinFunction(fnName, argumentTypes, Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
}

if (FunctionSet.COUNT.equalsIgnoreCase(fnName) && node.isDistinct() && node.getChildren().size() == 1) {
SessionVariableConstants.CountDistinctImplMode countDistinctImplementation =
session.getSessionVariable().getCountDistinctImplementation();
if (countDistinctImplementation != null) {
switch (countDistinctImplementation) {
case NDV:
node.resetFnName("", FunctionSet.NDV);
node.getParams().setIsDistinct(false);
fn = Expr.getBuiltinFunction(FunctionSet.NDV, argumentTypes,
Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
break;
case MULTI_COUNT_DISTINCT:
node.resetFnName("", FunctionSet.MULTI_DISTINCT_COUNT);
node.getParams().setIsDistinct(false);
fn = Expr.getBuiltinFunction(FunctionSet.MULTI_DISTINCT_COUNT, argumentTypes,
Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
break;
}
}
}

if (fn == null) {
fn = AnalyzerUtils.getUdfFunction(session, node.getFnName(), argumentTypes);
}
Expand Down Expand Up @@ -2089,10 +2112,10 @@ public Void visitSlot(SlotRef node, Scope scope) {

static class ResolveSlotVisitor extends Visitor {

private java.util.function.Consumer<SlotRef> resolver;
private Consumer<SlotRef> resolver;

public ResolveSlotVisitor(AnalyzeState state, ConnectContext session,
java.util.function.Consumer<SlotRef> slotResolver) {
Consumer<SlotRef> slotResolver) {
super(state, session);
resolver = slotResolver;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,17 @@ private static void analyzeSystemVariable(SystemVariable var) {
DataCachePopulateMode.fromName(resolvedExpression.getStringValue());
}

// count_distinct_implementation
if (variable.equalsIgnoreCase(SessionVariable.COUNT_DISTINCT_IMPLEMENTATION)) {
String rewriteModeName = resolvedExpression.getStringValue();
if (!EnumUtils.isValidEnumIgnoreCase(SessionVariableConstants.CountDistinctImplMode.class, rewriteModeName)) {
String supportedList = StringUtils.join(
EnumUtils.getEnumList(SessionVariableConstants.CountDistinctImplMode.class), ",");
throw new SemanticException(String.format("Unsupported count distinct implementation mode: %s, " +
"supported list is %s", rewriteModeName, supportedList));
}
}

var.setResolvedExpression(resolvedExpression);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator;
import com.starrocks.sql.optimizer.rewrite.ReplaceColumnRefRewriter;
import com.starrocks.sql.optimizer.rule.transformation.materialization.equivalent.EquivalentShuttleContext;
import com.starrocks.sql.optimizer.rule.transformation.materialization.equivalent.IRewriteEquivalent;
import com.starrocks.sql.optimizer.rule.tree.pdagg.AggregatePushDownContext;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
Expand Down Expand Up @@ -136,6 +137,7 @@ protected OptExpression viewBasedRewrite(RewriteContext rewriteContext, OptExpre
return null;
}
}
rewriteContext.setRollup(isRollup);

// normalize mv's aggs by using query's table ref and query ec
Map<ColumnRefOperator, ScalarOperator> mvProjection =
Expand Down Expand Up @@ -180,7 +182,7 @@ protected OptExpression rewriteProjection(RewriteContext rewriteContext,
// rewrite group by + aggregate functions
for (Map.Entry<ColumnRefOperator, ScalarOperator> entry : swappedQueryColumnMap.entrySet()) {
ScalarOperator scalarOp = entry.getValue();
ScalarOperator rewritten = rewriteScalarOperator(entry.getValue(),
ScalarOperator rewritten = rewriteScalarOperator(rewriteContext, entry.getValue(),
queryExprToMvExprRewriter, rewriteContext.getOutputMapping(),
originalColumnSet, aggregateFunctionRewriter);
if (rewritten == null) {
Expand All @@ -202,7 +204,7 @@ protected OptExpression rewriteProjection(RewriteContext rewriteContext,
ScalarOperator scalarOp = entry.getValue();
ScalarOperator mapped = rewriteContext.getQueryColumnRefRewriter().rewrite(scalarOp.clone());
ScalarOperator swapped = columnRewriter.rewriteByQueryEc(mapped);
ScalarOperator rewritten = rewriteScalarOperator(swapped,
ScalarOperator rewritten = rewriteScalarOperator(rewriteContext, swapped,
queryExprToMvExprRewriter, rewriteContext.getOutputMapping(),
originalColumnSet, aggregateFunctionRewriter);
if (rewritten == null) {
Expand All @@ -215,7 +217,7 @@ protected OptExpression rewriteProjection(RewriteContext rewriteContext,
for (ColumnRefOperator groupKey : queryAggregationOperator.getGroupingKeys()) {
ScalarOperator mapped = rewriteContext.getQueryColumnRefRewriter().rewrite(groupKey.clone());
ScalarOperator swapped = columnRewriter.rewriteByQueryEc(mapped);
ScalarOperator rewritten = rewriteScalarOperator(swapped,
ScalarOperator rewritten = rewriteScalarOperator(rewriteContext, swapped,
queryExprToMvExprRewriter, rewriteContext.getOutputMapping(),
originalColumnSet, aggregateFunctionRewriter);
if (rewritten == null) {
Expand Down Expand Up @@ -243,7 +245,8 @@ protected OptExpression rewriteProjection(RewriteContext rewriteContext,
return mvOptExpr;
}

private ScalarOperator rewriteScalarOperator(ScalarOperator scalarOp,
private ScalarOperator rewriteScalarOperator(RewriteContext rewriteContext,
ScalarOperator scalarOp,
EquationRewriter equationRewriter,
Map<ColumnRefOperator, ColumnRefOperator> columnMapping,
ColumnRefSet originalColumnSet,
Expand All @@ -253,7 +256,9 @@ private ScalarOperator rewriteScalarOperator(ScalarOperator scalarOp,
}
equationRewriter.setAggregateFunctionRewriter(aggregateFunctionRewriter);
equationRewriter.setOutputMapping(columnMapping);
ScalarOperator rewritten = equationRewriter.replaceExprWithTarget(scalarOp);
Pair<ScalarOperator, EquivalentShuttleContext> result =
equationRewriter.replaceExprWithEquivalent(rewriteContext, scalarOp);
ScalarOperator rewritten = result.first;
if (rewritten == null || scalarOp == rewritten) {
return null;
}
Expand Down Expand Up @@ -370,7 +375,7 @@ private OptExpression rewriteForRollup(

// rewrite group by keys by using mv
List<ScalarOperator> newQueryGroupKeys = rewriteGroupKeys(
queryGroupingKeys, equationRewriter, rewriteContext.getOutputMapping(),
rewriteContext, queryGroupingKeys, equationRewriter, rewriteContext.getOutputMapping(),
new ColumnRefSet(rewriteContext.getQueryColumnSet()));
if (newQueryGroupKeys == null) {
OptimizerTraceUtil.logMVRewriteFailReason(mvRewriteContext.getMVName(),
Expand Down Expand Up @@ -608,14 +613,22 @@ private OptExpression createNewAggregate(
/**
* Rewrite group by keys by using MV.
*/
private List<ScalarOperator> rewriteGroupKeys(List<ScalarOperator> groupKeys,
private List<ScalarOperator> rewriteGroupKeys(RewriteContext rewriteContext,
List<ScalarOperator> groupKeys,
EquationRewriter equationRewriter,
Map<ColumnRefOperator, ColumnRefOperator> mapping,
ColumnRefSet queryColumnSet) {
List<ScalarOperator> newGroupByKeys = Lists.newArrayList();
equationRewriter.setOutputMapping(mapping);
for (ScalarOperator key : groupKeys) {
ScalarOperator newGroupByKey = equationRewriter.replaceExprWithTarget(key);
Pair<ScalarOperator, EquivalentShuttleContext> result = equationRewriter.replaceExprWithEquivalent(rewriteContext,
key, IRewriteEquivalent.RewriteEquivalentType.PREDICATE);
ScalarOperator newGroupByKey = result.first;
if (key.isVariable() && key == newGroupByKey) {
OptimizerTraceUtil.logMVRewriteFailReason(mvRewriteContext.getMVName(),
"Rewrite group by key failed: {}", key.toString());
return null;
}
if (key.isVariable() && key == newGroupByKey) {
OptimizerTraceUtil.logMVRewriteFailReason(mvRewriteContext.getMVName(),
"Rewrite group by key failed: {}", key.toString());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,27 @@ public ScalarOperator visitBinaryPredicate(BinaryPredicateOperator predicate, Vo

private ScalarOperator rewriteByEquivalent(ScalarOperator input,
IRewriteEquivalent.RewriteEquivalentType type) {
if (!shuttleContext.isUseEquivalent() || !rewriteEquivalents.containsKey(type)) {
if (!shuttleContext.isUseEquivalent()) {
return null;
}
for (RewriteEquivalent equivalent : rewriteEquivalents.get(type)) {
ScalarOperator replaced = equivalent.rewrite(shuttleContext, columnMapping, input);
if (replaced != null) {
return replaced;
if (type.isAny()) {
for (List<RewriteEquivalent> equivalents : rewriteEquivalents.values()) {
for (RewriteEquivalent equivalent : equivalents) {
ScalarOperator replaced = equivalent.rewrite(shuttleContext, columnMapping, input);
if (replaced != null) {
return replaced;
}
}
}
} else {
if (!rewriteEquivalents.containsKey(type)) {
return null;
}
for (RewriteEquivalent equivalent : rewriteEquivalents.get(type)) {
ScalarOperator replaced = equivalent.rewrite(shuttleContext, columnMapping, input);
if (replaced != null) {
return replaced;
}
}
}
return null;
Expand Down Expand Up @@ -213,20 +227,49 @@ private boolean replaceColInExpr(ScalarOperator expr, ColumnRefOperator oldCol,
}
}

private final EquivalentShuttle shuttle = new EquivalentShuttle(new EquivalentShuttleContext(null, false, true));
private final EquivalentShuttle shuttle = new EquivalentShuttle(new EquivalentShuttleContext(null,
false, true, IRewriteEquivalent.RewriteEquivalentType.AGGREGATE));

protected ScalarOperator replaceExprWithTarget(ScalarOperator expr) {
return expr.accept(shuttle, null);
}

protected Pair<ScalarOperator, EquivalentShuttleContext> replaceExprWithRollup(RewriteContext rewriteContext,
ScalarOperator expr) {
/**
* Rewrite expr with equivalent shuttle which can be more robust/powerful than `replaceExprWithTarget`.
* NOTE: This method is mainly used in Aggregate's rewrite since there are more equivalences defined in Aggregate.
*/
public Pair<ScalarOperator, EquivalentShuttleContext> replaceExprWithRollup(RewriteContext rewriteContext,
ScalarOperator expr) {
return replaceExprWithEquivalent(rewriteContext, expr, IRewriteEquivalent.RewriteEquivalentType.AGGREGATE);
}

/**
* Replace expr with equivalent shuttle with specific type.
* @param rewriteContext rewrite context
* @param expr input expr to be rewritten
* @param type equivalent type which is used for call operator rewrite to deduce rewriting strategy
* @return rewritten expr and equivalent shuttle context
*/
public Pair<ScalarOperator, EquivalentShuttleContext> replaceExprWithEquivalent(
RewriteContext rewriteContext,
ScalarOperator expr,
IRewriteEquivalent.RewriteEquivalentType type) {
boolean isRollup = rewriteContext.isRollup();
final EquivalentShuttleContext shuttleContext = new EquivalentShuttleContext(rewriteContext,
true, true);
isRollup, true, type);
final EquivalentShuttle shuttle = new EquivalentShuttle(shuttleContext);
return Pair.create(expr.accept(shuttle, null), shuttleContext);
}

/**
* Replace expr with equivalent shuttle, by default, we can rewrite call operator with any type of equivalent
* since call operator can be aggregate or predicate or group by keys.
*/
public Pair<ScalarOperator, EquivalentShuttleContext> replaceExprWithEquivalent(RewriteContext rewriteContext,
ScalarOperator expr) {
return replaceExprWithEquivalent(rewriteContext, expr, IRewriteEquivalent.RewriteEquivalentType.ANY);
}

public boolean containsKey(ScalarOperator scalarOperator) {
return equationMap.containsKey(scalarOperator);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -729,7 +729,8 @@ private OptExpression rewriteComplete(List<Table> queryTables,
final RewriteContext rewriteContext = new RewriteContext(
queryExpression, queryPredicateSplit, queryEc, queryRelationIdToColumns, queryColumnRefFactory,
mvRewriteContext.getQueryColumnRefRewriter(), mvExpression, mvPredicateSplit, mvRelationIdToColumns,
mvColumnRefFactory, mvColumnRefRewriter, materializationContext.getOutputMapping(), queryColumnSet);
mvColumnRefFactory, mvColumnRefRewriter, materializationContext.getOutputMapping(), queryColumnSet,
optimizerContext);
// add agg push down rewrite info
rewriteContext.setAggregatePushDownContext(mvRewriteContext.getAggregatePushDownContext());

Expand Down
Loading

0 comments on commit 20c61dd

Please sign in to comment.