Skip to content

Commit

Permalink
[opt](Nereids) polish aggregate function signature matching (apache#3…
Browse files Browse the repository at this point in the history
…9352) (apache#39466)

pick from master apache#39352

use double to match string
- stddev
- stddev_samp

use largeint to match string
- group_bit_and
- group_bit_or
- group_git_xor

optimize error message
- multi_distinct_sum
  • Loading branch information
morrySnow authored Aug 16, 2024
1 parent f0da2ff commit 7fd2f96
Show file tree
Hide file tree
Showing 15 changed files with 120 additions and 87 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -760,13 +760,13 @@ private Expr compareLiteral(LiteralExpr first, LiteralExpr second) throws Analys
case EQ_FOR_NULL:
return new BoolLiteral(compareResult == 0);
case GE:
return new BoolLiteral(compareResult == 1 || compareResult == 0);
return new BoolLiteral(compareResult >= 0);
case GT:
return new BoolLiteral(compareResult == 1);
return new BoolLiteral(compareResult > 0);
case LE:
return new BoolLiteral(compareResult == -1 || compareResult == 0);
case LT:
return new BoolLiteral(compareResult == -1);
return new BoolLiteral(compareResult < 0);
case NE:
return new BoolLiteral(compareResult != 0);
default:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,12 @@ public class AvgWeighted extends AggregateFunction

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(DecimalV2Type.SYSTEM_DEFAULT, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(DecimalV2Type.SYSTEM_DEFAULT, DoubleType.INSTANCE)
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE, DoubleType.INSTANCE)
);

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@
public class BitmapAgg extends AggregateFunction
implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNotNullable {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BitmapType.INSTANCE).args(TinyIntType.INSTANCE),
FunctionSignature.ret(BitmapType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(BitmapType.INSTANCE).args(BigIntType.INSTANCE),
FunctionSignature.ret(BitmapType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(BitmapType.INSTANCE).args(BigIntType.INSTANCE)
);
FunctionSignature.ret(BitmapType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(BitmapType.INSTANCE).args(TinyIntType.INSTANCE)
);

public BitmapAgg(Expression arg0) {
super("bitmap_agg", arg0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,6 @@ public CollectList(boolean distinct, Expression arg) {
super("collect_list", distinct, arg);
}

@Override
public FunctionSignature computeSignature(FunctionSignature signature) {
signature = signature.withReturnType(ArrayType.of(getArgumentType(0)));
return super.computeSignature(signature);
}

/**
* withDistinctAndChildren.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@ public class GroupBitAnd extends NullableAggregateFunction
implements UnaryExpression, ExplicitlyCastableSignature {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(TinyIntType.INSTANCE).args(TinyIntType.INSTANCE),
FunctionSignature.ret(SmallIntType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(IntegerType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(LargeIntType.INSTANCE).args(LargeIntType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(BigIntType.INSTANCE),
FunctionSignature.ret(LargeIntType.INSTANCE).args(LargeIntType.INSTANCE)
FunctionSignature.ret(IntegerType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(SmallIntType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(TinyIntType.INSTANCE).args(TinyIntType.INSTANCE)
);

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ public class GroupBitOr extends NullableAggregateFunction
implements UnaryExpression, ExplicitlyCastableSignature {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(TinyIntType.INSTANCE).args(TinyIntType.INSTANCE),
FunctionSignature.ret(SmallIntType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(IntegerType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(LargeIntType.INSTANCE).args(LargeIntType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(BigIntType.INSTANCE),
FunctionSignature.ret(LargeIntType.INSTANCE).args(LargeIntType.INSTANCE)
FunctionSignature.ret(IntegerType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(SmallIntType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(TinyIntType.INSTANCE).args(TinyIntType.INSTANCE)
);

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ public class GroupBitXor extends NullableAggregateFunction
implements UnaryExpression, ExplicitlyCastableSignature {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(TinyIntType.INSTANCE).args(TinyIntType.INSTANCE),
FunctionSignature.ret(SmallIntType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(IntegerType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(LargeIntType.INSTANCE).args(LargeIntType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(BigIntType.INSTANCE),
FunctionSignature.ret(LargeIntType.INSTANCE).args(LargeIntType.INSTANCE)
FunctionSignature.ret(IntegerType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(SmallIntType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(TinyIntType.INSTANCE).args(TinyIntType.INSTANCE)
);

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,25 +24,16 @@
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.LargeIntType;
import org.apache.doris.nereids.types.DataType;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;

import java.util.List;

/** MultiDistinctSum */
public class MultiDistinctSum extends NullableAggregateFunction implements UnaryExpression,
ExplicitlyCastableSignature, ComputePrecisionForSum, MultiDistinction {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BigIntType.INSTANCE).varArgs(BigIntType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).varArgs(DoubleType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).varArgs(LargeIntType.INSTANCE)
);

public MultiDistinctSum(Expression arg0) {
super("multi_distinct_sum", true, false, arg0);
}
Expand All @@ -57,8 +48,10 @@ public MultiDistinctSum(boolean distinct, boolean alwaysNullable, Expression arg

@Override
public void checkLegalityBeforeTypeCoercion() {
if (child().getDataType().isDateLikeType()) {
throw new AnalysisException("Sum in multi distinct functions do not support Date/Datetime type");
DataType argType = child().getDataType();
if ((!argType.isNumericType() && !argType.isBooleanType() && !argType.isNullType())
|| argType.isOnlyMetricType()) {
throw new AnalysisException("sum requires a numeric or boolean parameter: " + this.toSql());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ public class Stddev extends NullableAggregateFunction

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE),
FunctionSignature.ret(DecimalV2Type.SYSTEM_DEFAULT).args(DecimalV2Type.SYSTEM_DEFAULT)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ public class StddevSamp extends AggregateFunction

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE),
FunctionSignature.ret(DecimalV2Type.SYSTEM_DEFAULT).args(DecimalV2Type.SYSTEM_DEFAULT)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,24 +52,25 @@ public class TopNWeighted extends AggregateFunction
implements ExplicitlyCastableSignature, PropagateNullable {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(ArrayType.of(BooleanType.INSTANCE))
.args(BooleanType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(TinyIntType.INSTANCE))
.args(TinyIntType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(SmallIntType.INSTANCE))
.args(SmallIntType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(IntegerType.INSTANCE))
.args(IntegerType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(BigIntType.INSTANCE))
.args(BigIntType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(LargeIntType.INSTANCE))
.args(LargeIntType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(FloatType.INSTANCE))
.args(FloatType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE),
// three arguments
FunctionSignature.ret(ArrayType.of(DoubleType.INSTANCE))
.args(DoubleType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(DecimalV2Type.CATALOG_DEFAULT))
.args(DecimalV2Type.CATALOG_DEFAULT, BigIntType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(LargeIntType.INSTANCE))
.args(LargeIntType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(BigIntType.INSTANCE))
.args(BigIntType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(IntegerType.INSTANCE))
.args(IntegerType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(SmallIntType.INSTANCE))
.args(SmallIntType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(TinyIntType.INSTANCE))
.args(TinyIntType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(BooleanType.INSTANCE))
.args(BooleanType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(FloatType.INSTANCE))
.args(FloatType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(DateType.INSTANCE))
.args(DateType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(DateTimeType.INSTANCE))
Expand All @@ -78,31 +79,35 @@ public class TopNWeighted extends AggregateFunction
.args(DateV2Type.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(DateTimeV2Type.SYSTEM_DEFAULT))
.args(DateTimeV2Type.SYSTEM_DEFAULT, BigIntType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(CharType.SYSTEM_DEFAULT))
.args(CharType.SYSTEM_DEFAULT, BigIntType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(StringType.INSTANCE))
.args(StringType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(BooleanType.INSTANCE))
.args(BooleanType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(TinyIntType.INSTANCE))
.args(TinyIntType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(SmallIntType.INSTANCE))
.args(SmallIntType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(IntegerType.INSTANCE))
.args(IntegerType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(BigIntType.INSTANCE))
.args(BigIntType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(LargeIntType.INSTANCE))
.args(LargeIntType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(FloatType.INSTANCE))
.args(FloatType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(VarcharType.SYSTEM_DEFAULT))
.args(VarcharType.SYSTEM_DEFAULT, BigIntType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(CharType.SYSTEM_DEFAULT))
.args(CharType.SYSTEM_DEFAULT, BigIntType.INSTANCE, IntegerType.INSTANCE),

// four arguments
FunctionSignature.ret(ArrayType.of(DoubleType.INSTANCE))
.args(DoubleType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT)
.args(DecimalV2Type.CATALOG_DEFAULT,
BigIntType.INSTANCE,
IntegerType.INSTANCE,
IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(LargeIntType.INSTANCE))
.args(LargeIntType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(BigIntType.INSTANCE))
.args(BigIntType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(IntegerType.INSTANCE))
.args(IntegerType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(SmallIntType.INSTANCE))
.args(SmallIntType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(TinyIntType.INSTANCE))
.args(TinyIntType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(BooleanType.INSTANCE))
.args(BooleanType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(FloatType.INSTANCE))
.args(FloatType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(DateType.INSTANCE))
.args(DateType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(DateTimeType.INSTANCE))
Expand All @@ -114,10 +119,12 @@ public class TopNWeighted extends AggregateFunction
BigIntType.INSTANCE,
IntegerType.INSTANCE,
IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(CharType.SYSTEM_DEFAULT))
.args(CharType.SYSTEM_DEFAULT, BigIntType.INSTANCE, IntegerType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(StringType.INSTANCE))
.args(StringType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE, IntegerType.INSTANCE)
.args(StringType.INSTANCE, BigIntType.INSTANCE, IntegerType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(VarcharType.SYSTEM_DEFAULT))
.args(VarcharType.SYSTEM_DEFAULT, BigIntType.INSTANCE, IntegerType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(ArrayType.of(CharType.SYSTEM_DEFAULT))
.args(CharType.SYSTEM_DEFAULT, BigIntType.INSTANCE, IntegerType.INSTANCE, IntegerType.INSTANCE)
);

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ public class Variance extends NullableAggregateFunction

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE),
FunctionSignature.ret(DecimalV2Type.SYSTEM_DEFAULT).args(DecimalV2Type.SYSTEM_DEFAULT)
);
Expand Down
Loading

0 comments on commit 7fd2f96

Please sign in to comment.