Skip to content

Commit

Permalink
[feature](function) support orthogonal_bitmap_expr_calculate & orthog…
Browse files Browse the repository at this point in the history
…onal_bitmap_expr_calculate_count for nereids (#44991)

support orthogonal_bitmap_expr_calculate &
orthogonal_bitmap_expr_calculate_count for nereids
  • Loading branch information
924060929 authored Dec 4, 2024
1 parent 5b5cbf6 commit d58a972
Show file tree
Hide file tree
Showing 9 changed files with 286 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ struct AggOrthBitmapExprCalBaseData {
if (first_init) {
DCHECK(argument_size > 1);
const auto& col =
assert_cast<const ColVecData&, TypeCheckOnRelease::DISABLE>(*columns[2]);
assert_cast<const ColumnString&, TypeCheckOnRelease::DISABLE>(*columns[2]);
std::string expr = col.get_data_at(row_num).to_string();
bitmap_expr_cal.bitmap_calculation_init(expr);
first_init = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@
import org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinctSum;
import org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinctSum0;
import org.apache.doris.nereids.trees.expressions.functions.agg.Ndv;
import org.apache.doris.nereids.trees.expressions.functions.agg.OrthogonalBitmapExprCalculate;
import org.apache.doris.nereids.trees.expressions.functions.agg.OrthogonalBitmapExprCalculateCount;
import org.apache.doris.nereids.trees.expressions.functions.agg.OrthogonalBitmapIntersect;
import org.apache.doris.nereids.trees.expressions.functions.agg.OrthogonalBitmapIntersectCount;
import org.apache.doris.nereids.trees.expressions.functions.agg.OrthogonalBitmapUnionCount;
Expand Down Expand Up @@ -124,7 +126,8 @@ public class BuiltinAggregateFunctions implements FunctionHelper {
agg(HllUnion.class, "hll_raw_agg", "hll_union"),
agg(HllUnionAgg.class, "hll_union_agg"),
agg(IntersectCount.class, "intersect_count"),
agg(LinearHistogram.class, FunctionSet.LINEAR_HISTOGRAM),
agg(Kurt.class, "kurt", "kurt_pop", "kurtosis"),
agg(LinearHistogram.class, "linear_histogram"),
agg(MapAgg.class, "map_agg"),
agg(Max.class, "max"),
agg(MaxBy.class, "max_by"),
Expand All @@ -135,6 +138,8 @@ public class BuiltinAggregateFunctions implements FunctionHelper {
agg(MultiDistinctSum.class, "multi_distinct_sum"),
agg(MultiDistinctSum0.class, "multi_distinct_sum0"),
agg(Ndv.class, "approx_count_distinct", "ndv"),
agg(OrthogonalBitmapExprCalculate.class, "orthogonal_bitmap_expr_calculate"),
agg(OrthogonalBitmapExprCalculateCount.class, "orthogonal_bitmap_expr_calculate_count"),
agg(OrthogonalBitmapIntersect.class, "orthogonal_bitmap_intersect"),
agg(OrthogonalBitmapIntersectCount.class, "orthogonal_bitmap_intersect_count"),
agg(OrthogonalBitmapUnionCount.class, "orthogonal_bitmap_union_count"),
Expand All @@ -148,6 +153,7 @@ public class BuiltinAggregateFunctions implements FunctionHelper {
agg(Retention.class, "retention"),
agg(SequenceCount.class, "sequence_count"),
agg(SequenceMatch.class, "sequence_match"),
agg(Skew.class, "skew", "skew_pop", "skewness"),
agg(Stddev.class, "stddev_pop", "stddev"),
agg(StddevSamp.class, "stddev_samp"),
agg(Sum.class, "sum"),
Expand All @@ -157,9 +163,7 @@ public class BuiltinAggregateFunctions implements FunctionHelper {
agg(TopNWeighted.class, "topn_weighted"),
agg(Variance.class, "var_pop", "variance_pop", "variance"),
agg(VarianceSamp.class, "var_samp", "variance_samp"),
agg(WindowFunnel.class, "window_funnel"),
agg(Skew.class, "skew", "skew_pop", "skewness"),
agg(Kurt.class, "kurt", "kurt_pop", "kurtosis")
agg(WindowFunnel.class, "window_funnel")
);

public final Set<String> aggFuncNames = aggregateFunctions.stream()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateParam;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregatePhase;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.GroupConcat;
import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
Expand Down Expand Up @@ -294,49 +295,58 @@ public List<Rule> buildRules() {
RuleType.ONE_PHASE_AGGREGATE_WITHOUT_DISTINCT.build(
basePattern
.when(agg -> agg.getDistinctArguments().isEmpty())
.when(agg -> agg.supportAggregatePhase(AggregatePhase.ONE))
.thenApplyMulti(ctx -> onePhaseAggregateWithoutDistinct(ctx.root, ctx.connectContext))
),
RuleType.TWO_PHASE_AGGREGATE_WITHOUT_DISTINCT.build(
basePattern
.when(agg -> agg.getDistinctArguments().isEmpty())
.when(agg -> agg.supportAggregatePhase(AggregatePhase.TWO))
.thenApplyMulti(ctx -> twoPhaseAggregateWithoutDistinct(ctx.root, ctx.connectContext))
),
// RuleType.TWO_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI.build(
// basePattern
// .when(this::containsCountDistinctMultiExpr)
// .when(agg -> agg.supportAggregatePhase(AggregatePhase.TWO))
// .thenApplyMulti(ctx -> twoPhaseAggregateWithCountDistinctMulti(ctx.root, ctx.cascadesContext))
// ),
RuleType.THREE_PHASE_AGGREGATE_WITH_COUNT_DISTINCT_MULTI.build(
basePattern
.when(this::containsCountDistinctMultiExpr)
.when(agg -> agg.supportAggregatePhase(AggregatePhase.THREE))
.thenApplyMulti(ctx -> threePhaseAggregateWithCountDistinctMulti(ctx.root, ctx.cascadesContext))
),
RuleType.ONE_PHASE_AGGREGATE_SINGLE_DISTINCT_TO_MULTI.build(
basePattern
.when(agg -> agg.getDistinctArguments().size() == 1 && couldConvertToMulti(agg))
.when(agg -> agg.supportAggregatePhase(AggregatePhase.ONE))
.thenApplyMulti(ctx -> onePhaseAggregateWithMultiDistinct(ctx.root, ctx.connectContext))
),
RuleType.TWO_PHASE_AGGREGATE_SINGLE_DISTINCT_TO_MULTI.build(
basePattern
.when(agg -> agg.getDistinctArguments().size() == 1 && couldConvertToMulti(agg))
.when(agg -> agg.supportAggregatePhase(AggregatePhase.TWO))
.thenApplyMulti(ctx -> twoPhaseAggregateWithMultiDistinct(ctx.root, ctx.connectContext))
),
RuleType.TWO_PHASE_AGGREGATE_WITH_MULTI_DISTINCT.build(
basePattern
.when(agg -> agg.getDistinctArguments().size() > 1
&& !containsCountDistinctMultiExpr(agg)
&& couldConvertToMulti(agg))
.when(agg -> agg.supportAggregatePhase(AggregatePhase.TWO))
.thenApplyMulti(ctx -> twoPhaseAggregateWithMultiDistinct(ctx.root, ctx.connectContext))
),
// RuleType.TWO_PHASE_AGGREGATE_WITH_DISTINCT.build(
// basePattern
// .when(agg -> agg.getDistinctArguments().size() == 1)
// .when(agg -> agg.supportAggregatePhase(AggregatePhase.TWO))
// .thenApplyMulti(ctx -> twoPhaseAggregateWithDistinct(ctx.root, ctx.connectContext))
// ),
RuleType.THREE_PHASE_AGGREGATE_WITH_DISTINCT.build(
basePattern
.when(agg -> agg.getDistinctArguments().size() == 1)
.whenNot(agg -> agg.mustUseMultiDistinctAgg())
.whenNot(agg -> agg.mustUseMultiDistinctAgg())
.when(agg -> agg.supportAggregatePhase(AggregatePhase.THREE))
.thenApplyMulti(ctx -> threePhaseAggregateWithDistinct(ctx.root, ctx.connectContext))
),
/*
Expand All @@ -361,6 +371,7 @@ && couldConvertToMulti(agg))
.when(agg -> agg.getDistinctArguments().size() == 1)
.when(agg -> agg.getGroupByExpressions().isEmpty())
.whenNot(agg -> agg.mustUseMultiDistinctAgg())
.when(agg -> agg.supportAggregatePhase(AggregatePhase.FOUR))
.thenApplyMulti(ctx -> {
Function<List<Expression>, RequireProperties> secondPhaseRequireDistinctHash =
groupByAndDistinct -> RequireProperties.of(
Expand Down Expand Up @@ -408,6 +419,7 @@ && couldConvertToMulti(agg))
}
return couldConvertToMulti(agg);
})
.when(agg -> agg.supportAggregatePhase(AggregatePhase.FOUR))
.thenApplyMulti(ctx -> {
Function<List<Expression>, RequireProperties> secondPhaseRequireGroupByAndDistinctHash =
groupByAndDistinct -> RequireProperties.of(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,10 @@ public String toString() {
return getName() + "(" + (distinct ? "DISTINCT " : "") + args + ")";
}

public boolean supportAggregatePhase(AggregatePhase aggregatePhase) {
return true;
}

public List<Expression> getDistinctArguments() {
return distinct ? getArguments() : ImmutableList.of();
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.trees.expressions.functions.agg;

/** AggregatePhase */
public enum AggregatePhase {
ONE, TWO, THREE, FOUR
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.trees.expressions.functions.agg;

import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.functions.scalar.BitmapEmpty;
import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral;
import org.apache.doris.nereids.types.BitmapType;
import org.apache.doris.nereids.types.VarcharType;
import org.apache.doris.nereids.types.coercion.CharacterType;
import org.apache.doris.nereids.util.ExpressionUtils;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;

import java.util.List;

/** OrthogonalBitmapExprCalculate */
public class OrthogonalBitmapExprCalculate extends NotNullableAggregateFunction
implements OrthogonalBitmapFunction, ExplicitlyCastableSignature {

static final List<FunctionSignature> FUNCTION_SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BitmapType.INSTANCE)
.varArgs(BitmapType.INSTANCE, VarcharType.SYSTEM_DEFAULT, VarcharType.SYSTEM_DEFAULT)
);

/**
* constructor with 3 arguments.
*/
public OrthogonalBitmapExprCalculate(
Expression bitmap, Expression filterColumn, VarcharLiteral inputString) {
super("orthogonal_bitmap_expr_calculate", ExpressionUtils.mergeArguments(bitmap, filterColumn, inputString));
}

/**
* constructor with 3 arguments.
*/
public OrthogonalBitmapExprCalculate(boolean distinct,
Expression bitmap, Expression filterColumn, VarcharLiteral inputString) {
super("orthogonal_bitmap_expr_calculate", distinct,
ExpressionUtils.mergeArguments(bitmap, filterColumn, inputString));
}

@Override
public boolean supportAggregatePhase(AggregatePhase aggregatePhase) {
return aggregatePhase == AggregatePhase.TWO;
}

@Override
public Expression resultForEmptyInput() {
return new BitmapEmpty();
}

@Override
public OrthogonalBitmapExprCalculate withDistinctAndChildren(boolean distinct, List<Expression> children) {
Preconditions.checkArgument(children.size() == 3
&& children.get(2).getDataType() instanceof CharacterType
&& children.get(2).getDataType() instanceof VarcharType);
return new OrthogonalBitmapExprCalculate(
distinct, children.get(0), children.get(1), (VarcharLiteral) children.get(2));
}

@Override
public List<FunctionSignature> getSignatures() {
return FUNCTION_SIGNATURES;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.trees.expressions.functions.agg;

import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.functions.scalar.BitmapEmpty;
import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.BitmapType;
import org.apache.doris.nereids.types.VarcharType;
import org.apache.doris.nereids.types.coercion.CharacterType;
import org.apache.doris.nereids.util.ExpressionUtils;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;

import java.util.List;

/** OrthogonalBitmapExprCalculateCount */
public class OrthogonalBitmapExprCalculateCount extends NotNullableAggregateFunction
implements OrthogonalBitmapFunction, ExplicitlyCastableSignature {

static final List<FunctionSignature> FUNCTION_SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BigIntType.INSTANCE)
.varArgs(BitmapType.INSTANCE, VarcharType.SYSTEM_DEFAULT, VarcharType.SYSTEM_DEFAULT)
);

/**
* constructor with 3 arguments.
*/
public OrthogonalBitmapExprCalculateCount(
Expression bitmap, Expression filterColumn, VarcharLiteral inputString) {
super("orthogonal_bitmap_expr_calculate_count",
ExpressionUtils.mergeArguments(bitmap, filterColumn, inputString));
}

/**
* constructor with 3 arguments.
*/
public OrthogonalBitmapExprCalculateCount(boolean distinct,
Expression bitmap, Expression filterColumn, VarcharLiteral inputString) {
super("orthogonal_bitmap_expr_calculate_count", distinct,
ExpressionUtils.mergeArguments(bitmap, filterColumn, inputString));
}

@Override
public boolean supportAggregatePhase(AggregatePhase aggregatePhase) {
return aggregatePhase == AggregatePhase.TWO;
}

@Override
public Expression resultForEmptyInput() {
return new BitmapEmpty();
}

@Override
public OrthogonalBitmapExprCalculateCount withDistinctAndChildren(boolean distinct, List<Expression> children) {
Preconditions.checkArgument(children.size() == 3
&& children.get(2).getDataType() instanceof CharacterType
&& children.get(2).getDataType() instanceof VarcharType);
return new OrthogonalBitmapExprCalculateCount(
distinct, children.get(0), children.get(1), (VarcharLiteral) children.get(2));
}

@Override
public List<FunctionSignature> getSignatures() {
return FUNCTION_SIGNATURES;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregatePhase;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.Ndv;
import org.apache.doris.nereids.trees.plans.Plan;
Expand Down Expand Up @@ -386,4 +388,14 @@ public void computeEqualSet(DataTrait.Builder builder) {
public void computeFd(DataTrait.Builder builder) {
builder.addFuncDepsDG(child().getLogicalProperties().getTrait());
}

/** supportAggregatePhase */
public boolean supportAggregatePhase(AggregatePhase aggregatePhase) {
for (AggregateFunction aggregateFunction : getAggregateFunctions()) {
if (!aggregateFunction.supportAggregatePhase(aggregatePhase)) {
return false;
}
}
return true;
}
}
Loading

0 comments on commit d58a972

Please sign in to comment.