Skip to content

Commit

Permalink
support tidb truncate function (#9842)
Browse files Browse the repository at this point in the history
close #9846

Signed-off-by: guo-shaoge <[email protected]>

Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com>
  • Loading branch information
guo-shaoge and ti-chi-bot[bot] authored Feb 11, 2025
1 parent 607d850 commit 4ffbd35
Show file tree
Hide file tree
Showing 7 changed files with 1,031 additions and 97 deletions.
8 changes: 4 additions & 4 deletions dbms/src/Flash/Coprocessor/DAGUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -309,10 +309,10 @@ const std::unordered_map<tipb::ScalarFuncSig, String> scalar_func_map({
{tipb::ScalarFuncSig::Radians, "radians"},
{tipb::ScalarFuncSig::Sin, "sin"},
{tipb::ScalarFuncSig::Tan, "tan"},
{tipb::ScalarFuncSig::TruncateInt, "trunc"},
{tipb::ScalarFuncSig::TruncateReal, "trunc"},
//{tipb::ScalarFuncSig::TruncateDecimal, "cast"},
{tipb::ScalarFuncSig::TruncateUint, "trunc"},
{tipb::ScalarFuncSig::TruncateInt, "tidbTruncateWithFrac"},
{tipb::ScalarFuncSig::TruncateReal, "tidbTruncateWithFrac"},
{tipb::ScalarFuncSig::TruncateDecimal, "tidbTruncateWithFrac"},
{tipb::ScalarFuncSig::TruncateUint, "tidbTruncateWithFrac"},

{tipb::ScalarFuncSig::LogicalAnd, "and"},
{tipb::ScalarFuncSig::LogicalOr, "or"},
Expand Down
1 change: 1 addition & 0 deletions dbms/src/Functions/FunctionsRound.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ void registerFunctionsRound(FunctionFactory & factory)
factory.registerFunction<FunctionTrunc>("truncate", FunctionFactory::CaseInsensitive);

factory.registerFunction<FunctionTiDBRoundWithFrac>();
factory.registerFunction<FunctionTiDBTruncateWithFrac>();
}

} // namespace DB
183 changes: 146 additions & 37 deletions dbms/src/Functions/FunctionsRound.h
Original file line number Diff line number Diff line change
Expand Up @@ -965,7 +965,7 @@ struct ConstPowOf10
static_assert(!overflow, "Computation overflows");
};

template <typename InputType, typename OutputType>
template <typename InputType, typename OutputType, bool is_tidb_truncate>
struct TiDBFloatingRound
{
static_assert(std::is_floating_point_v<InputType>);
Expand All @@ -974,8 +974,8 @@ struct TiDBFloatingRound

static OutputType eval(InputType input, FracType frac)
{
// modified from <https://github.com/pingcap/tidb/blob/26237b35f857c2388eab46f9ee3b351687143681/types/helper.go#L33-L48>.

// ported from https://github.com/pingcap/tidb/blob/26237b35f857c2388eab46f9ee3b351687143681/types/helper.go#L33-L48 and
// https://github.com/pingcap/tidb/blob/26237b35f857c2388eab46f9ee3b351687143681/types/helper.go#L50-L61.
auto value = static_cast<OutputType>(input);
auto base = 1.0;

Expand All @@ -997,9 +997,16 @@ struct TiDBFloatingRound
value = scaled_value;
}

// floating-point environment is thread-local, so `fesetround` is thread-safe.
std::fesetround(FE_TONEAREST);
value = std::nearbyint(value);
if constexpr (is_tidb_truncate)
{
value = std::trunc(value);
}
else
{
// floating-point environment is thread-local, so `fesetround` is thread-safe.
std::fesetround(FE_TONEAREST);
value = std::nearbyint(value);
}

if (frac != 0)
{
Expand All @@ -1018,7 +1025,7 @@ struct TiDBFloatingRound
}
};

template <typename InputType, typename OutputType>
template <typename InputType, typename OutputType, bool is_tidb_truncate>
struct TiDBIntegerRound
{
static_assert(is_integer_v<InputType>);
Expand Down Expand Up @@ -1061,7 +1068,7 @@ struct TiDBIntegerRound
}
}

static OutputType eval(InputType input, FracType frac)
static OutputType evalRound(InputType input, FracType frac)
{
auto value = static_cast<OutputType>(input);

Expand Down Expand Up @@ -1107,6 +1114,33 @@ struct TiDBIntegerRound
return castBack((input < 0), absolute_value);
}
}

static OutputType evalTruncate(InputType input, FracType frac)
{
// ported from https://github.com/pingcap/tidb/blob/807b8923c0181d89d4ea8e4195f9d27d299298a7/pkg/expression/builtin_math.go#L2196-L2219
const auto value = static_cast<OutputType>(input);
if (frac >= 0)
return value;
else if (frac <= -max_digits)
return 0;
else
{
// To make sure static_cast<OutputType>(Pow::result[-frac]) will not overflow.
assert(Pow::result[-frac] < std::numeric_limits<OutputType>::max());
const auto base = static_cast<OutputType>(Pow::result[-frac]);

const auto remainder = value % base;
return value - remainder;
}
}

static OutputType eval(InputType input, FracType frac)
{
if constexpr (is_tidb_truncate)
return evalTruncate(input, frac);
else
return evalRound(input, frac);
}
};

struct TiDBDecimalRoundInfo
Expand All @@ -1127,7 +1161,7 @@ struct TiDBDecimalRoundInfo
{}
};

template <typename InputType, typename OutputType>
template <typename InputType, typename OutputType, bool is_tidb_truncate>
struct TiDBDecimalRound
{
static_assert(IsDecimal<InputType>);
Expand Down Expand Up @@ -1158,10 +1192,13 @@ struct TiDBDecimalRound
auto remainder = absolute_value % base;

absolute_value -= remainder;
if (remainder >= base / 2)
if constexpr (!is_tidb_truncate)
{
// round up.
absolute_value += base;
if (remainder >= base / 2)
{
// round up.
absolute_value += base;
}
}
}

Expand Down Expand Up @@ -1198,14 +1235,21 @@ struct TiDBDecimalRound

struct TiDBRoundPrecisionInferer
{
static std::tuple<PrecType, ScaleType> infer(PrecType prec, ScaleType scale, FracType frac, bool is_const_frac)
static std::tuple<PrecType, ScaleType> infer(
PrecType prec,
ScaleType scale,
FracType frac,
bool is_const_frac,
bool is_tidb_truncate)
{
assert(prec >= scale);
PrecType int_prec = prec - scale;
ScaleType new_scale = scale;

// +1 for possible overflow, e.g. round(99999.9) => 100000
ScaleType int_prec_increment = 1;
if (is_tidb_truncate)
int_prec_increment = 0;

if (is_const_frac)
{
Expand All @@ -1219,6 +1263,14 @@ struct TiDBRoundPrecisionInferer
}

PrecType new_prec = std::min(decimal_max_prec, int_prec + int_prec_increment + new_scale);
if (new_prec == 0)
{
// new_prec can be zero when the prec is eq to scale and frac is le to zero for truncate:
// select truncate(0.22, 0) from t_col_decimal_2_2;
// Not possible for round, because int_prec_increment is 1 for round.
RUNTIME_CHECK(is_tidb_truncate && is_const_frac && frac <= 0 && prec == scale);
new_prec = 1;
}
return std::make_tuple(new_prec, new_scale);
}
};
Expand All @@ -1239,7 +1291,8 @@ template <
typename OutputType,
typename InputColumn,
typename FracColumn,
typename OutputColumn>
typename OutputColumn,
bool is_tidb_truncate>
struct TiDBRound
{
static void apply(const TiDBRoundArguments & args)
Expand Down Expand Up @@ -1273,11 +1326,14 @@ struct TiDBRound
auto frac_data = frac_column->template getValue<FracType>();

if constexpr (std::is_floating_point_v<InputType>)
output_data[0] = TiDBFloatingRound<InputType, OutputType>::eval(input_data, frac_data);
output_data[0]
= TiDBFloatingRound<InputType, OutputType, is_tidb_truncate>::eval(input_data, frac_data);
else if constexpr (IsDecimal<InputType>)
output_data[0] = TiDBDecimalRound<InputType, OutputType>::eval(input_data, frac_data, info);
output_data[0]
= TiDBDecimalRound<InputType, OutputType, is_tidb_truncate>::eval(input_data, frac_data, info);
else
output_data[0] = TiDBIntegerRound<InputType, OutputType>::eval(input_data, frac_data);
output_data[0]
= TiDBIntegerRound<InputType, OutputType, is_tidb_truncate>::eval(input_data, frac_data);
}
else
{
Expand All @@ -1287,11 +1343,17 @@ struct TiDBRound
for (size_t i = 0; i < size; ++i)
{
if constexpr (std::is_floating_point_v<InputType>)
output_data[i] = TiDBFloatingRound<InputType, OutputType>::eval(input_data, frac_data[i]);
output_data[i] = TiDBFloatingRound<InputType, OutputType, is_tidb_truncate>::eval(
input_data,
frac_data[i]);
else if constexpr (IsDecimal<InputType>)
output_data[i] = TiDBDecimalRound<InputType, OutputType>::eval(input_data, frac_data[i], info);
output_data[i] = TiDBDecimalRound<InputType, OutputType, is_tidb_truncate>::eval(
input_data,
frac_data[i],
info);
else
output_data[i] = TiDBIntegerRound<InputType, OutputType>::eval(input_data, frac_data[i]);
output_data[i]
= TiDBIntegerRound<InputType, OutputType, is_tidb_truncate>::eval(input_data, frac_data[i]);
}
}
}
Expand All @@ -1305,11 +1367,17 @@ struct TiDBRound
for (size_t i = 0; i < size; ++i)
{
if constexpr (std::is_floating_point_v<InputType>)
output_data[i] = TiDBFloatingRound<InputType, OutputType>::eval(input_data[i], frac_data);
output_data[i] = TiDBFloatingRound<InputType, OutputType, is_tidb_truncate>::eval(
input_data[i],
frac_data);
else if constexpr (IsDecimal<InputType>)
output_data[i] = TiDBDecimalRound<InputType, OutputType>::eval(input_data[i], frac_data, info);
output_data[i] = TiDBDecimalRound<InputType, OutputType, is_tidb_truncate>::eval(
input_data[i],
frac_data,
info);
else
output_data[i] = TiDBIntegerRound<InputType, OutputType>::eval(input_data[i], frac_data);
output_data[i]
= TiDBIntegerRound<InputType, OutputType, is_tidb_truncate>::eval(input_data[i], frac_data);
}
}
else
Expand All @@ -1320,27 +1388,34 @@ struct TiDBRound
for (size_t i = 0; i < size; ++i)
{
if constexpr (std::is_floating_point_v<InputType>)
output_data[i] = TiDBFloatingRound<InputType, OutputType>::eval(input_data[i], frac_data[i]);
output_data[i] = TiDBFloatingRound<InputType, OutputType, is_tidb_truncate>::eval(
input_data[i],
frac_data[i]);
else if constexpr (IsDecimal<InputType>)
output_data[i]
= TiDBDecimalRound<InputType, OutputType>::eval(input_data[i], frac_data[i], info);
output_data[i] = TiDBDecimalRound<InputType, OutputType, is_tidb_truncate>::eval(
input_data[i],
frac_data[i],
info);
else
output_data[i] = TiDBIntegerRound<InputType, OutputType>::eval(input_data[i], frac_data[i]);
output_data[i] = TiDBIntegerRound<InputType, OutputType, is_tidb_truncate>::eval(
input_data[i],
frac_data[i]);
}
}
}
}
};

/**
* round(x, d) for TiDB.
* round(x, d) and truncate(x, d) for TiDB.
*/
class FunctionTiDBRoundWithFrac : public IFunction
template <typename Name, bool is_tidb_truncate>
class FunctionTiDBRoundImpl : public IFunction
{
public:
static constexpr auto name = "tidbRoundWithFrac";
static constexpr auto name = Name::name;

static FunctionPtr create(const Context &) { return std::make_shared<FunctionTiDBRoundWithFrac>(); }
static FunctionPtr create(const Context &) { return std::make_shared<FunctionTiDBRoundImpl>(); }

String getName() const override { return name; }
size_t getNumberOfArguments() const override { return 2; }
Expand All @@ -1355,7 +1430,6 @@ class FunctionTiDBRoundWithFrac : public IFunction
// non-const frac column can generate different return types. Plese see TiDBRoundPrecisionInferer for details.
bool useDefaultImplementationForConstants() const override { return false; }

private:
static FracType getFracFromConstColumn(const ColumnConst * column)
{
using UnsignedFrac = make_unsigned_t<FracType>;
Expand Down Expand Up @@ -1383,6 +1457,7 @@ class FunctionTiDBRoundWithFrac : public IFunction
}
}

private:
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
{
checkArguments(arguments);
Expand Down Expand Up @@ -1430,7 +1505,8 @@ class FunctionTiDBRoundWithFrac : public IFunction
else
is_const_frac = false;

auto [new_prec, new_scale] = TiDBRoundPrecisionInferer::infer(prec, scale, frac, is_const_frac);
auto [new_prec, new_scale]
= TiDBRoundPrecisionInferer::infer(prec, scale, frac, is_const_frac, is_tidb_truncate);
return createDecimal(new_prec, new_scale);
}
}
Expand Down Expand Up @@ -1566,16 +1642,44 @@ class FunctionTiDBRoundWithFrac : public IFunction
if (args.input_column->isColumnConst())
{
if (args.frac_column->isColumnConst())
TiDBRound<InputType, FracType, OutputType, ColumnConst, ColumnConst, OutputColumn>::apply(args);
TiDBRound<
InputType,
FracType,
OutputType,
ColumnConst,
ColumnConst,
OutputColumn,
is_tidb_truncate>::apply(args);
else
TiDBRound<InputType, FracType, OutputType, ColumnConst, FracColumn, OutputColumn>::apply(args);
TiDBRound<
InputType,
FracType,
OutputType,
ColumnConst,
FracColumn,
OutputColumn,
is_tidb_truncate>::apply(args);
}
else
{
if (args.frac_column->isColumnConst())
TiDBRound<InputType, FracType, OutputType, InputColumn, ColumnConst, OutputColumn>::apply(args);
TiDBRound<
InputType,
FracType,
OutputType,
InputColumn,
ColumnConst,
OutputColumn,
is_tidb_truncate>::apply(args);
else
TiDBRound<InputType, FracType, OutputType, InputColumn, FracColumn, OutputColumn>::apply(args);
TiDBRound<
InputType,
FracType,
OutputType,
InputColumn,
FracColumn,
OutputColumn,
is_tidb_truncate>::apply(args);
}

return true;
Expand Down Expand Up @@ -1622,6 +1726,9 @@ struct NameRoundDecimalToInt { static constexpr auto name = "roundDecimalToInt";
struct NameCeilDecimalToInt { static constexpr auto name = "ceilDecimalToInt"; };
struct NameFloorDecimalToInt { static constexpr auto name = "floorDecimalToInt"; };
struct NameTruncDecimalToInt { static constexpr auto name = "truncDecimalToInt"; };

struct NameTiDBRoundWithFrac { static constexpr auto name = "tidbRoundWithFrac"; };
struct NameTiDBTruncateWithFrac { static constexpr auto name = "tidbTruncateWithFrac"; };
// clang-format on

using FunctionRoundToExp2 = FunctionUnaryArithmetic<RoundToExp2Impl, NameRoundToExp2, false>;
Expand All @@ -1638,6 +1745,8 @@ using FunctionCeilDecimalToInt = FunctionRoundingDecimalToInt<NameCeilDecimalToI
using FunctionFloorDecimalToInt = FunctionRoundingDecimalToInt<NameFloorDecimalToInt, RoundingMode::Floor>;
using FunctionTruncDecimalToInt = FunctionRoundingDecimalToInt<NameTruncDecimalToInt, RoundingMode::Trunc>;

using FunctionTiDBRoundWithFrac = FunctionTiDBRoundImpl<NameTiDBRoundWithFrac, /*is_tidb_truncate=*/false>;
using FunctionTiDBTruncateWithFrac = FunctionTiDBRoundImpl<NameTiDBTruncateWithFrac, /*is_tidb_truncate=*/true>;

struct PositiveMonotonicity
{
Expand Down
4 changes: 2 additions & 2 deletions dbms/src/Functions/FunctionsString.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5893,9 +5893,9 @@ class FormatImpl : public IFunction
const TiDBDecimalRoundInfo & info [[maybe_unused]])
{
if constexpr (IsDecimal<T>)
return TiDBDecimalRound<T, T>::eval(number, max_num_decimals, info);
return TiDBDecimalRound<T, T, /*is_tidb_truncate=*/false>::eval(number, max_num_decimals, info);
else if constexpr (std::is_floating_point_v<T>)
return TiDBFloatingRound<T, Float64>::eval(number, max_num_decimals);
return TiDBFloatingRound<T, Float64, /*is_tidb_truncate=*/false>::eval(number, max_num_decimals);
else
{
static_assert(std::is_integral_v<T>);
Expand Down
Loading

0 comments on commit 4ffbd35

Please sign in to comment.