Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement] Add count_distinct_implementation/enable_count_distinct_rewrite_by_hll_bitmap to control count distinct's implementation (backport #52293) #52334

Merged
merged 3 commits into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions fe/fe-core/src/main/java/com/starrocks/qe/SessionVariable.java
Original file line number Diff line number Diff line change
Expand Up @@ -710,6 +710,10 @@ public static MaterializedViewRewriteMode parse(String str) {
// binary, json, compact
public static final String THRIFT_PLAN_PROTOCOL = "thrift_plan_protocol";

public static final String COUNT_DISTINCT_IMPLEMENTATION = "count_distinct_implementation";

public static final String ENABLE_COUNT_DISTINCT_REWRITE_BY_HLL_BITMAP = "enable_count_distinct_rewrite_by_hll_bitmap";

// 0 means disable interleaving, positive value sets the group size, but adaptively enable interleaving,
// negative value means force interleaving under the group size of abs(interleaving_group_size)
public static final String INTERLEAVING_GROUP_SIZE = "interleaving_group_size";
Expand Down Expand Up @@ -1525,6 +1529,15 @@ public static MaterializedViewRewriteMode parse(String str) {
@VarAttr(name = DISABLE_GENERATED_COLUMN_REWRITE, flag = VariableMgr.INVISIBLE)
private boolean disableGeneratedColumnRewrite = false;

@VarAttr(name = COUNT_DISTINCT_IMPLEMENTATION)
private String countDistinctImplementation = "default";

// By default, we always use the created MV's bitmap/hll columns to rewrite count distinct,
// but the result is then approximate and may not exactly match the original result.
// Disable this variable if an exactly matched result is required.
@VarAttr(name = ENABLE_COUNT_DISTINCT_REWRITE_BY_HLL_BITMAP)
private boolean enableCountDistinctRewriteByHllBitmap = true;

// Returns the maximum JSON subfield depth considered by the CBO when pruning subfields.
public int getCboPruneJsonSubfieldDepth() {
    return cboPruneJsonSubfieldDepth;
}
Expand Down Expand Up @@ -4046,6 +4059,22 @@ public boolean isDisableGeneratedColumnRewrite() {
return disableGeneratedColumnRewrite;
}

// Stores the raw mode name for the count-distinct implementation; the string is parsed
// into a CountDistinctImplMode by getCountDistinctImplementation().
public void setCountDistinctImplementation(String countDistinctImplementation) {
    this.countDistinctImplementation = countDistinctImplementation;
}

// Parses the stored session value into a CountDistinctImplMode. The parse is
// case-insensitive and returns null when the stored string is not a recognized mode,
// so callers must handle a null result.
public SessionVariableConstants.CountDistinctImplMode getCountDistinctImplementation() {
    return SessionVariableConstants.CountDistinctImplMode.parse(countDistinctImplementation);
}

// Whether count(distinct) may be rewritten against an MV's bitmap/hll columns.
// Per the variable's declaration comment, such rewrites can produce results that do not
// exactly match the original count distinct.
public boolean isEnableCountDistinctRewriteByHllBitmap() {
    return enableCountDistinctRewriteByHllBitmap;
}

// Enables/disables rewriting count(distinct) via an MV's bitmap/hll columns.
public void setEnableCountDistinctRewriteByHllBitmap(boolean enableCountDistinctRewriteByHllBitmap) {
    this.enableCountDistinctRewriteByHllBitmap = enableCountDistinctRewriteByHllBitmap;
}

// Serialize to thrift object
// used for rest api
public TQueryOptions toThrift() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

package com.starrocks.qe;

import org.apache.commons.lang3.EnumUtils;

public class SessionVariableConstants {

private SessionVariableConstants() {}
Expand Down Expand Up @@ -75,4 +77,15 @@ public enum AggregationStage {
THREE_STAGE,
FOUR_STAGE
}

// Supported values: default, ndv, multi_count_distinct
public enum CountDistinctImplMode {
    DEFAULT, // keeps the original count distinct implementation
    NDV, // rewrites count(distinct x) to ndv(x), a HyperLogLog-based estimate
    MULTI_COUNT_DISTINCT; // rewrites count(distinct x) to multi_distinct_count(x)

    // Canonical name of the default mode ("DEFAULT"); final so it cannot be reassigned.
    public static final String MODE_DEFAULT = DEFAULT.toString();

    /**
     * Parses a mode name case-insensitively.
     *
     * @param str the mode name; may be null
     * @return the matching mode, or null when {@code str} is null or not a valid mode name
     */
    public static CountDistinctImplMode parse(String str) {
        if (str == null) {
            return null;
        }
        for (CountDistinctImplMode mode : values()) {
            if (mode.name().equalsIgnoreCase(str)) {
                return mode;
            }
        }
        return null;
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
import com.starrocks.privilege.RolePrivilegeCollectionV2;
import com.starrocks.qe.ConnectContext;
import com.starrocks.qe.SessionVariable;
import com.starrocks.qe.SessionVariableConstants;
import com.starrocks.qe.SqlModeHelper;
import com.starrocks.server.GlobalStateMgr;
import com.starrocks.server.RunMode;
Expand Down Expand Up @@ -1024,7 +1025,7 @@ public Void visitFunctionCall(FunctionCallExpr node, Scope scope) {
node.setNondeterministicId(exprId);
}

Function fn;
Function fn = null;
String fnName = node.getFnName().getFunction();

// throw exception direct
Expand Down Expand Up @@ -1194,10 +1195,32 @@ public Void visitFunctionCall(FunctionCallExpr node, Scope scope) {
}
}
}

} else {
fn = Expr.getBuiltinFunction(fnName, argumentTypes, Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
}

if (FunctionSet.COUNT.equalsIgnoreCase(fnName) && node.isDistinct() && node.getChildren().size() == 1) {
SessionVariableConstants.CountDistinctImplMode countDistinctImplementation =
session.getSessionVariable().getCountDistinctImplementation();
if (countDistinctImplementation != null) {
switch (countDistinctImplementation) {
case NDV:
node.resetFnName("", FunctionSet.NDV);
node.getParams().setIsDistinct(false);
fn = Expr.getBuiltinFunction(FunctionSet.NDV, argumentTypes,
Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
break;
case MULTI_COUNT_DISTINCT:
node.resetFnName("", FunctionSet.MULTI_DISTINCT_COUNT);
node.getParams().setIsDistinct(false);
fn = Expr.getBuiltinFunction(FunctionSet.MULTI_DISTINCT_COUNT, argumentTypes,
Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
break;
}
}
}

if (fn == null) {
fn = AnalyzerUtils.getUdfFunction(session, node.getFnName(), argumentTypes);
}
Expand Down Expand Up @@ -2089,10 +2112,10 @@ public Void visitSlot(SlotRef node, Scope scope) {

static class ResolveSlotVisitor extends Visitor {

private java.util.function.Consumer<SlotRef> resolver;
private Consumer<SlotRef> resolver;

public ResolveSlotVisitor(AnalyzeState state, ConnectContext session,
java.util.function.Consumer<SlotRef> slotResolver) {
Consumer<SlotRef> slotResolver) {
super(state, session);
resolver = slotResolver;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,17 @@ private static void analyzeSystemVariable(SystemVariable var) {
DataCachePopulateMode.fromName(resolvedExpression.getStringValue());
}

// count_distinct_implementation
if (variable.equalsIgnoreCase(SessionVariable.COUNT_DISTINCT_IMPLEMENTATION)) {
String rewriteModeName = resolvedExpression.getStringValue();
if (!EnumUtils.isValidEnumIgnoreCase(SessionVariableConstants.CountDistinctImplMode.class, rewriteModeName)) {
String supportedList = StringUtils.join(
EnumUtils.getEnumList(SessionVariableConstants.CountDistinctImplMode.class), ",");
throw new SemanticException(String.format("Unsupported count distinct implementation mode: %s, " +
"supported list is %s", rewriteModeName, supportedList));
}
}

var.setResolvedExpression(resolvedExpression);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator;
import com.starrocks.sql.optimizer.rewrite.ReplaceColumnRefRewriter;
import com.starrocks.sql.optimizer.rule.transformation.materialization.equivalent.EquivalentShuttleContext;
import com.starrocks.sql.optimizer.rule.transformation.materialization.equivalent.IRewriteEquivalent;
import com.starrocks.sql.optimizer.rule.tree.pdagg.AggregatePushDownContext;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
Expand Down Expand Up @@ -136,6 +137,7 @@ protected OptExpression viewBasedRewrite(RewriteContext rewriteContext, OptExpre
return null;
}
}
rewriteContext.setRollup(isRollup);

// normalize mv's aggs by using query's table ref and query ec
Map<ColumnRefOperator, ScalarOperator> mvProjection =
Expand Down Expand Up @@ -180,7 +182,7 @@ protected OptExpression rewriteProjection(RewriteContext rewriteContext,
// rewrite group by + aggregate functions
for (Map.Entry<ColumnRefOperator, ScalarOperator> entry : swappedQueryColumnMap.entrySet()) {
ScalarOperator scalarOp = entry.getValue();
ScalarOperator rewritten = rewriteScalarOperator(entry.getValue(),
ScalarOperator rewritten = rewriteScalarOperator(rewriteContext, entry.getValue(),
queryExprToMvExprRewriter, rewriteContext.getOutputMapping(),
originalColumnSet, aggregateFunctionRewriter);
if (rewritten == null) {
Expand All @@ -202,7 +204,7 @@ protected OptExpression rewriteProjection(RewriteContext rewriteContext,
ScalarOperator scalarOp = entry.getValue();
ScalarOperator mapped = rewriteContext.getQueryColumnRefRewriter().rewrite(scalarOp.clone());
ScalarOperator swapped = columnRewriter.rewriteByQueryEc(mapped);
ScalarOperator rewritten = rewriteScalarOperator(swapped,
ScalarOperator rewritten = rewriteScalarOperator(rewriteContext, swapped,
queryExprToMvExprRewriter, rewriteContext.getOutputMapping(),
originalColumnSet, aggregateFunctionRewriter);
if (rewritten == null) {
Expand All @@ -215,7 +217,7 @@ protected OptExpression rewriteProjection(RewriteContext rewriteContext,
for (ColumnRefOperator groupKey : queryAggregationOperator.getGroupingKeys()) {
ScalarOperator mapped = rewriteContext.getQueryColumnRefRewriter().rewrite(groupKey.clone());
ScalarOperator swapped = columnRewriter.rewriteByQueryEc(mapped);
ScalarOperator rewritten = rewriteScalarOperator(swapped,
ScalarOperator rewritten = rewriteScalarOperator(rewriteContext, swapped,
queryExprToMvExprRewriter, rewriteContext.getOutputMapping(),
originalColumnSet, aggregateFunctionRewriter);
if (rewritten == null) {
Expand Down Expand Up @@ -243,7 +245,8 @@ protected OptExpression rewriteProjection(RewriteContext rewriteContext,
return mvOptExpr;
}

private ScalarOperator rewriteScalarOperator(ScalarOperator scalarOp,
private ScalarOperator rewriteScalarOperator(RewriteContext rewriteContext,
ScalarOperator scalarOp,
EquationRewriter equationRewriter,
Map<ColumnRefOperator, ColumnRefOperator> columnMapping,
ColumnRefSet originalColumnSet,
Expand All @@ -253,7 +256,9 @@ private ScalarOperator rewriteScalarOperator(ScalarOperator scalarOp,
}
equationRewriter.setAggregateFunctionRewriter(aggregateFunctionRewriter);
equationRewriter.setOutputMapping(columnMapping);
ScalarOperator rewritten = equationRewriter.replaceExprWithTarget(scalarOp);
Pair<ScalarOperator, EquivalentShuttleContext> result =
equationRewriter.replaceExprWithEquivalent(rewriteContext, scalarOp);
ScalarOperator rewritten = result.first;
if (rewritten == null || scalarOp == rewritten) {
return null;
}
Expand Down Expand Up @@ -370,7 +375,7 @@ private OptExpression rewriteForRollup(

// rewrite group by keys by using mv
List<ScalarOperator> newQueryGroupKeys = rewriteGroupKeys(
queryGroupingKeys, equationRewriter, rewriteContext.getOutputMapping(),
rewriteContext, queryGroupingKeys, equationRewriter, rewriteContext.getOutputMapping(),
new ColumnRefSet(rewriteContext.getQueryColumnSet()));
if (newQueryGroupKeys == null) {
OptimizerTraceUtil.logMVRewriteFailReason(mvRewriteContext.getMVName(),
Expand Down Expand Up @@ -608,14 +613,22 @@ private OptExpression createNewAggregate(
/**
* Rewrite group by keys by using MV.
*/
private List<ScalarOperator> rewriteGroupKeys(List<ScalarOperator> groupKeys,
private List<ScalarOperator> rewriteGroupKeys(RewriteContext rewriteContext,
List<ScalarOperator> groupKeys,
EquationRewriter equationRewriter,
Map<ColumnRefOperator, ColumnRefOperator> mapping,
ColumnRefSet queryColumnSet) {
List<ScalarOperator> newGroupByKeys = Lists.newArrayList();
equationRewriter.setOutputMapping(mapping);
for (ScalarOperator key : groupKeys) {
ScalarOperator newGroupByKey = equationRewriter.replaceExprWithTarget(key);
Pair<ScalarOperator, EquivalentShuttleContext> result = equationRewriter.replaceExprWithEquivalent(rewriteContext,
key, IRewriteEquivalent.RewriteEquivalentType.PREDICATE);
ScalarOperator newGroupByKey = result.first;
if (key.isVariable() && key == newGroupByKey) {
OptimizerTraceUtil.logMVRewriteFailReason(mvRewriteContext.getMVName(),
"Rewrite group by key failed: {}", key.toString());
return null;
}
if (key.isVariable() && key == newGroupByKey) {
OptimizerTraceUtil.logMVRewriteFailReason(mvRewriteContext.getMVName(),
"Rewrite group by key failed: {}", key.toString());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,27 @@ public ScalarOperator visitBinaryPredicate(BinaryPredicateOperator predicate, Vo

private ScalarOperator rewriteByEquivalent(ScalarOperator input,
IRewriteEquivalent.RewriteEquivalentType type) {
if (!shuttleContext.isUseEquivalent() || !rewriteEquivalents.containsKey(type)) {
if (!shuttleContext.isUseEquivalent()) {
return null;
}
for (RewriteEquivalent equivalent : rewriteEquivalents.get(type)) {
ScalarOperator replaced = equivalent.rewrite(shuttleContext, columnMapping, input);
if (replaced != null) {
return replaced;
if (type.isAny()) {
for (List<RewriteEquivalent> equivalents : rewriteEquivalents.values()) {
for (RewriteEquivalent equivalent : equivalents) {
ScalarOperator replaced = equivalent.rewrite(shuttleContext, columnMapping, input);
if (replaced != null) {
return replaced;
}
}
}
} else {
if (!rewriteEquivalents.containsKey(type)) {
return null;
}
for (RewriteEquivalent equivalent : rewriteEquivalents.get(type)) {
ScalarOperator replaced = equivalent.rewrite(shuttleContext, columnMapping, input);
if (replaced != null) {
return replaced;
}
}
}
return null;
Expand Down Expand Up @@ -213,20 +227,49 @@ private boolean replaceColInExpr(ScalarOperator expr, ColumnRefOperator oldCol,
}
}

private final EquivalentShuttle shuttle = new EquivalentShuttle(new EquivalentShuttleContext(null, false, true));
private final EquivalentShuttle shuttle = new EquivalentShuttle(new EquivalentShuttleContext(null,
false, true, IRewriteEquivalent.RewriteEquivalentType.AGGREGATE));

// Rewrites expr by visiting it with the shared default shuttle, substituting matched
// sub-expressions with their MV targets; returns the (possibly unchanged) operator.
protected ScalarOperator replaceExprWithTarget(ScalarOperator expr) {
    return expr.accept(shuttle, null);
}

protected Pair<ScalarOperator, EquivalentShuttleContext> replaceExprWithRollup(RewriteContext rewriteContext,
ScalarOperator expr) {
/**
* Rewrite expr with equivalent shuttle which can be more robust/powerful than `replaceExprWithTarget`.
* NOTE: This method is mainly used in Aggregate's rewrite since there are more equivalences defined in Aggregate.
*/
public Pair<ScalarOperator, EquivalentShuttleContext> replaceExprWithRollup(RewriteContext rewriteContext,
ScalarOperator expr) {
return replaceExprWithEquivalent(rewriteContext, expr, IRewriteEquivalent.RewriteEquivalentType.AGGREGATE);
}

/**
* Replace expr with equivalent shuttle with specific type.
* @param rewriteContext rewrite context
* @param expr input expr to be rewritten
* @param type equivalent type which is used for call operator rewrite to deduce rewriting strategy
* @return rewritten expr and equivalent shuttle context
*/
public Pair<ScalarOperator, EquivalentShuttleContext> replaceExprWithEquivalent(
RewriteContext rewriteContext,
ScalarOperator expr,
IRewriteEquivalent.RewriteEquivalentType type) {
boolean isRollup = rewriteContext.isRollup();
final EquivalentShuttleContext shuttleContext = new EquivalentShuttleContext(rewriteContext,
true, true);
isRollup, true, type);
final EquivalentShuttle shuttle = new EquivalentShuttle(shuttleContext);
return Pair.create(expr.accept(shuttle, null), shuttleContext);
}

/**
 * Replaces {@code expr} using an equivalent shuttle of type {@code ANY}: a call operator
 * may appear as an aggregate, a predicate, or a group-by key, so by default any registered
 * equivalent type is allowed to apply.
 */
public Pair<ScalarOperator, EquivalentShuttleContext> replaceExprWithEquivalent(RewriteContext rewriteContext,
                                                                                ScalarOperator expr) {
    return replaceExprWithEquivalent(rewriteContext, expr, IRewriteEquivalent.RewriteEquivalentType.ANY);
}

// Returns true if scalarOperator is registered as a key in this rewriter's equation map.
public boolean containsKey(ScalarOperator scalarOperator) {
    return equationMap.containsKey(scalarOperator);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -729,7 +729,8 @@ private OptExpression rewriteComplete(List<Table> queryTables,
final RewriteContext rewriteContext = new RewriteContext(
queryExpression, queryPredicateSplit, queryEc, queryRelationIdToColumns, queryColumnRefFactory,
mvRewriteContext.getQueryColumnRefRewriter(), mvExpression, mvPredicateSplit, mvRelationIdToColumns,
mvColumnRefFactory, mvColumnRefRewriter, materializationContext.getOutputMapping(), queryColumnSet);
mvColumnRefFactory, mvColumnRefRewriter, materializationContext.getOutputMapping(), queryColumnSet,
optimizerContext);
// add agg push down rewrite info
rewriteContext.setAggregatePushDownContext(mvRewriteContext.getAggregatePushDownContext());

Expand Down
Loading
Loading