Skip to content

Commit

Permalink
[CALCITE-4838] Add RoundingMode in RelDataTypeSystem to specify the r…
Browse files Browse the repository at this point in the history
…ounding behavior
  • Loading branch information
NobiGo committed Aug 26, 2024
1 parent 99a0df1 commit 6c7685c
Show file tree
Hide file tree
Showing 11 changed files with 626 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -456,45 +456,51 @@ private Expression getConvertExpression(
&& scale != RelDataType.SCALE_NOT_SPECIFIED) {
if (sourceType.getFamily() == SqlTypeFamily.CHARACTER) {
return Expressions.call(
BuiltInMethod.CHAR_DECIMAL_CAST.method,
BuiltInMethod.CHAR_DECIMAL_CAST_ROUNDING_MODE.method,
operand,
Expressions.constant(precision),
Expressions.constant(scale));
Expressions.constant(scale),
Expressions.constant(typeFactory.getTypeSystem().roundingMode()));
} else if (sourceType.getFamily() == SqlTypeFamily.INTERVAL_DAY_TIME) {
return Expressions.call(
BuiltInMethod.SHORT_INTERVAL_DECIMAL_CAST.method,
BuiltInMethod.SHORT_INTERVAL_DECIMAL_CAST_ROUNDING_MODE.method,
operand,
Expressions.constant(precision),
Expressions.constant(scale),
Expressions.constant(sourceType.getSqlTypeName().getEndUnit().multiplier));
Expressions.constant(sourceType.getSqlTypeName().getEndUnit().multiplier),
Expressions.constant(typeFactory.getTypeSystem().roundingMode()));
} else if (sourceType.getFamily() == SqlTypeFamily.INTERVAL_YEAR_MONTH) {
return Expressions.call(
BuiltInMethod.LONG_INTERVAL_DECIMAL_CAST.method,
BuiltInMethod.LONG_INTERVAL_DECIMAL_CAST_ROUNDING_MODE.method,
operand,
Expressions.constant(precision),
Expressions.constant(scale),
Expressions.constant(sourceType.getSqlTypeName().getEndUnit().multiplier));
Expressions.constant(sourceType.getSqlTypeName().getEndUnit().multiplier),
Expressions.constant(typeFactory.getTypeSystem().roundingMode()));
} else if (sourceType.getSqlTypeName() == SqlTypeName.DECIMAL) {
// Cast from DECIMAL to DECIMAL, may adjust scale and precision.
return Expressions.call(
BuiltInMethod.DECIMAL_DECIMAL_CAST.method,
BuiltInMethod.DECIMAL_DECIMAL_CAST_ROUNDING_MODE.method,
operand,
Expressions.constant(precision),
Expressions.constant(scale));
Expressions.constant(scale),
Expressions.constant(typeFactory.getTypeSystem().roundingMode()));
} else if (SqlTypeName.INT_TYPES.contains(sourceType.getSqlTypeName())) {
// Cast from INTEGER to DECIMAL, check for overflow
return Expressions.call(
BuiltInMethod.INTEGER_DECIMAL_CAST.method,
BuiltInMethod.INTEGER_DECIMAL_CAST_ROUNDING_MODE.method,
operand,
Expressions.constant(precision),
Expressions.constant(scale));
Expressions.constant(scale),
Expressions.constant(typeFactory.getTypeSystem().roundingMode()));
} else if (SqlTypeName.APPROX_TYPES.contains(sourceType.getSqlTypeName())) {
// Cast from FLOAT/DOUBLE to DECIMAL
return Expressions.call(
BuiltInMethod.FP_DECIMAL_CAST.method,
BuiltInMethod.FP_DECIMAL_CAST_ROUNDING_MODE.method,
operand,
Expressions.constant(precision),
Expressions.constant(scale));
Expressions.constant(scale),
Expressions.constant(typeFactory.getTypeSystem().roundingMode()));
}
}
return defaultExpression.get();
Expand All @@ -505,9 +511,9 @@ private Expression getConvertExpression(
case SMALLINT: {
if (SqlTypeName.NUMERIC_TYPES.contains(sourceType.getSqlTypeName())) {
return Expressions.call(
BuiltInMethod.INTEGER_CAST.method,
BuiltInMethod.INTEGER_CAST_ROUNDING_MODE.method,
Expressions.constant(Primitive.of(typeFactory.getJavaClass(targetType))),
operand);
operand, Expressions.constant(typeFactory.getTypeSystem().roundingMode()));
}
return defaultExpression.get();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

import org.checkerframework.checker.nullness.qual.Nullable;

import java.math.RoundingMode;

/** Implementation of {@link org.apache.calcite.rel.type.RelDataTypeSystem}
* that sends all methods to an underlying object. */
public class DelegatingTypeSystem implements RelDataTypeSystem {
Expand Down Expand Up @@ -50,6 +52,10 @@ protected DelegatingTypeSystem(RelDataTypeSystem typeSystem) {
return typeSystem.getMaxNumericPrecision();
}

@Override public RoundingMode roundingMode() {
return typeSystem.roundingMode();
}

@Override public @Nullable String getLiteral(SqlTypeName typeName, boolean isPrefix) {
return typeSystem.getLiteral(typeName, isPrefix);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

import org.checkerframework.checker.nullness.qual.Nullable;

import java.math.RoundingMode;

/**
* Type system.
*
Expand Down Expand Up @@ -60,6 +62,9 @@ public interface RelDataTypeSystem {
/** Returns the maximum precision of a NUMERIC or DECIMAL type. */
int getMaxNumericPrecision();

/** Returns the rounding behavior for numerical operations capable of discarding precision. */
RoundingMode roundingMode();

/** Returns the LITERAL string for the type, either PREFIX/SUFFIX. */
@Nullable String getLiteral(SqlTypeName typeName, boolean isPrefix);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

import org.checkerframework.checker.nullness.qual.Nullable;

import java.math.RoundingMode;

/** Default implementation of
* {@link org.apache.calcite.rel.type.RelDataTypeSystem},
* providing parameters from the SQL standard.
Expand Down Expand Up @@ -160,6 +162,10 @@ public abstract class RelDataTypeSystemImpl implements RelDataTypeSystem {
return 19;
}

@Override public RoundingMode roundingMode() {
return RoundingMode.DOWN;
}

@Override public @Nullable String getLiteral(SqlTypeName typeName, boolean isPrefix) {
switch (typeName) {
case VARBINARY:
Expand Down
21 changes: 17 additions & 4 deletions core/src/main/java/org/apache/calcite/util/BuiltInMethod.java
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.nio.charset.Charset;
import java.sql.ResultSet;
import java.sql.Time;
Expand Down Expand Up @@ -159,10 +160,16 @@ public enum BuiltInMethod {
AS_QUERYABLE(Enumerable.class, "asQueryable"),
ABSTRACT_ENUMERABLE_CTOR(AbstractEnumerable.class),
CHAR_DECIMAL_CAST(Primitive.class, "charToDecimalCast", String.class, int.class, int.class),
CHAR_DECIMAL_CAST_ROUNDING_MODE(Primitive.class, "charToDecimalCast",
String.class, int.class, int.class, RoundingMode.class),
SHORT_INTERVAL_DECIMAL_CAST(Primitive.class, "shortIntervalToDecimalCast",
Long.class, int.class, int.class, BigDecimal.class),
SHORT_INTERVAL_DECIMAL_CAST_ROUNDING_MODE(Primitive.class, "shortIntervalToDecimalCast",
Long.class, int.class, int.class, BigDecimal.class, RoundingMode.class),
LONG_INTERVAL_DECIMAL_CAST(Primitive.class, "longIntervalToDecimalCast",
Integer.class, int.class, int.class, BigDecimal.class),
LONG_INTERVAL_DECIMAL_CAST_ROUNDING_MODE(Primitive.class, "longIntervalToDecimalCast",
Integer.class, int.class, int.class, BigDecimal.class, RoundingMode.class),
INTO(ExtendedEnumerable.class, "into", Collection.class),
REMOVE_ALL(ExtendedEnumerable.class, "removeAll", Collection.class),
SCHEMA_GET_SUB_SCHEMA(Schema.class, "getSubSchema", String.class),
Expand Down Expand Up @@ -310,11 +317,17 @@ public enum BuiltInMethod {
AS_LIST(Primitive.class, "asList", Object.class),
DECIMAL_DECIMAL_CAST(Primitive.class, "decimalDecimalCast",
BigDecimal.class, int.class, int.class),
INTEGER_DECIMAL_CAST(Primitive.class, "integerDecimalCast",
Number.class, int.class, int.class),
FP_DECIMAL_CAST(Primitive.class, "fpDecimalCast",
Number.class, int.class, int.class),
DECIMAL_DECIMAL_CAST_ROUNDING_MODE(Primitive.class, "decimalDecimalCast",
BigDecimal.class, int.class, int.class, RoundingMode.class),
INTEGER_DECIMAL_CAST(Primitive.class, "integerDecimalCast", Number.class, int.class, int.class),
INTEGER_DECIMAL_CAST_ROUNDING_MODE(Primitive.class, "integerDecimalCast",
Number.class, int.class, int.class, RoundingMode.class),
FP_DECIMAL_CAST(Primitive.class, "fpDecimalCast", Number.class, int.class, int.class),
FP_DECIMAL_CAST_ROUNDING_MODE(Primitive.class, "fpDecimalCast",
Number.class, int.class, int.class, RoundingMode.class),
INTEGER_CAST(Primitive.class, "integerCast", Primitive.class, Object.class),
INTEGER_CAST_ROUNDING_MODE(Primitive.class, "integerCast",
Primitive.class, Object.class, RoundingMode.class),
MEMORY_GET0(MemoryFactory.Memory.class, "get"),
MEMORY_GET1(MemoryFactory.Memory.class, "get", int.class),
ENUMERATOR_CURRENT(Enumerator.class, "current"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,15 @@ public static void main(String[] args) throws Exception {
.with(CalciteConnectionProperty.FUN, SqlLibrary.CALCITE.fun)
.with(CalciteAssert.Config.SCOTT)
.connect();
case "scott-rounding-half-up":
return CalciteAssert.that()
.with(CalciteConnectionProperty.PARSER_FACTORY,
ExtensionDdlExecutor.class.getName() + "#PARSER_FACTORY")
.with(CalciteConnectionProperty.FUN, SqlLibrary.CALCITE.fun)
.with(CalciteConnectionProperty.TYPE_SYSTEM,
CustomRelDataTypeSystem.class.getName() + "#ROUNDING_MODE_HALF_UP")
.with(CalciteAssert.Config.SCOTT)
.connect();
case "scott-lenient":
// Same as "scott", but uses LENIENT conformance.
// TODO: add a way to change conformance without defining a new
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to you under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.calcite.test;

import org.apache.calcite.rel.type.RelDataTypeSystem;
import org.apache.calcite.rel.type.RelDataTypeSystemImpl;

import java.math.RoundingMode;

/**
* Custom type system only for Quidem test.
*
* <p> Specify the rounding behaviour. In the default implementation,
* the rounding mode is {@link RoundingMode#DOWN}, but here is {@link RoundingMode#HALF_UP}
*
* <p>The default implementation is {@link #DEFAULT}.
*/

public class CustomRelDataTypeSystem extends RelDataTypeSystemImpl {

public static final RelDataTypeSystem ROUNDING_MODE_HALF_UP = new CustomRelDataTypeSystem();

@Override public RoundingMode roundingMode() {
return RoundingMode.HALF_UP;
}
}
Loading

0 comments on commit 6c7685c

Please sign in to comment.