Skip to content

Commit

Permalink
feat: Support decimal type for Spark in function
Browse files Browse the repository at this point in the history
  • Loading branch information
zhli1142015 committed Dec 25, 2024
1 parent b0a8908 commit 3703527
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 1 deletion.
2 changes: 1 addition & 1 deletion velox/docs/functions/spark/array.rst
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ Array Functions
.. spark:function:: in(value, array(E)) -> boolean
Returns true if value matches at least one of the elements of the array.
Supports BOOLEAN, REAL, DOUBLE, BIGINT, VARCHAR, TIMESTAMP, DATE input types.
Supports BOOLEAN, REAL, DOUBLE, BIGINT, VARCHAR, TIMESTAMP, DATE, DECIMAL input types.

.. spark:function:: shuffle(array(E), seed) -> array(E)
Expand Down
18 changes: 18 additions & 0 deletions velox/functions/sparksql/In.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,22 @@ void registerInFn(const std::string& prefix) {
{prefix + "in"});
}

// Registers the "in" function (under the name `prefix` + "in") for the
// signature (ShortDecimal<P1, S1>, Array<ShortDecimal<P1, S1>>) -> bool.
// The probe value and the array elements are declared with the same P1/S1
// parameters.
// NOTE(review): P1/S1 are presumably the precision/scale placeholders used by
// the function registry - confirm against their declaration.
void registerInFnForShortDecimal(const std::string& prefix) {
  const std::string functionName = prefix + "in";
  registerFunction<
      InFunctionOuter<ShortDecimal<P1, S1>>::template Inner,
      bool, // result type
      ShortDecimal<P1, S1>, // value being searched for
      Array<ShortDecimal<P1, S1>>>( // list of candidates
      {functionName});
}

// Registers the "in" function (under the name `prefix` + "in") for the
// signature (LongDecimal<P1, S1>, Array<LongDecimal<P1, S1>>) -> bool.
// The probe value and the array elements are declared with the same P1/S1
// parameters.
// NOTE(review): P1/S1 are presumably the precision/scale placeholders used by
// the function registry - confirm against their declaration.
void registerInFnForLongDecimal(const std::string& prefix) {
  const std::string functionName = prefix + "in";
  registerFunction<
      InFunctionOuter<LongDecimal<P1, S1>>::template Inner,
      bool, // result type
      LongDecimal<P1, S1>, // value being searched for
      Array<LongDecimal<P1, S1>>>( // list of candidates
      {functionName});
}

} // namespace

void registerIn(const std::string& prefix) {
Expand All @@ -143,6 +159,8 @@ void registerIn(const std::string& prefix) {
registerInFn<Varchar>(prefix);
registerInFn<Timestamp>(prefix);
registerInFn<Date>(prefix);
registerInFnForShortDecimal(prefix);
registerInFnForLongDecimal(prefix);
}

} // namespace facebook::velox::functions::sparksql
51 changes: 51 additions & 0 deletions velox/functions/sparksql/tests/InTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,57 @@ TEST_F(InTest, Bool) {
EXPECT_EQ(in<bool>(false, {false}), true);
}

// Checks the `in` test helper with int64_t values and DECIMAL types whose
// precision is at most 18 (the cases below use precisions 2, 3, 10, 17, 18).
// The third argument to the helper is the DECIMAL type used for the case.
TEST_F(InTest, shortDecimal) {
// No nulls in the list: a listed value gives true, an unlisted one false.
EXPECT_EQ(in<int64_t>(1, {1, 2}, DECIMAL(2, 1)), true);
EXPECT_EQ(in<int64_t>(2, {1, 2}, DECIMAL(10, 5)), true);
EXPECT_EQ(in<int64_t>(3, {1, 2}, DECIMAL(17, 11)), false);
// A null probe value is expected to give a null result.
EXPECT_EQ(in<int64_t>(std::nullopt, {1, 2}, DECIMAL(3, 2)), std::nullopt);
// List contains a null element: a match is still true, but a miss is
// expected to be null rather than false.
EXPECT_EQ(in<int64_t>(1, {1, std::nullopt, 2}, DECIMAL(3, 2)), true);
EXPECT_EQ(in<int64_t>(2, {1, std::nullopt, 2}, DECIMAL(3, 2)), true);
EXPECT_EQ(in<int64_t>(3, {1, std::nullopt, 2}, DECIMAL(3, 2)), std::nullopt);
// Null probe value together with a null list element: null result.
EXPECT_EQ(
in<int64_t>(std::nullopt, {1, std::nullopt, 2}, DECIMAL(3, 2)),
std::nullopt);
// Boundary constants: DecimalUtil::kShortDecimalMin and kShortDecimalMax are
// each expected to be found in a list holding both, at precision 18.
EXPECT_EQ(
in<int64_t>(
DecimalUtil::kShortDecimalMin,
{DecimalUtil::kShortDecimalMin, DecimalUtil::kShortDecimalMax},
DECIMAL(18, 9)),
true);
EXPECT_EQ(
in<int64_t>(
DecimalUtil::kShortDecimalMax,
{DecimalUtil::kShortDecimalMin, DecimalUtil::kShortDecimalMax},
DECIMAL(18, 9)),
true);
}

// Checks the `in` test helper with int128_t values and DECIMAL types whose
// precision is above 18 (the cases below use precisions 21, 23, 29, 35, 38).
// The third argument to the helper is the DECIMAL type used for the case.
TEST_F(InTest, longDecimal) {
// No nulls in the list: a listed value gives true, an unlisted one false.
EXPECT_EQ(in<int128_t>(1, {1, 2}, DECIMAL(21, 2)), true);
EXPECT_EQ(in<int128_t>(2, {1, 2}, DECIMAL(29, 10)), true);
EXPECT_EQ(in<int128_t>(3, {1, 2}, DECIMAL(35, 20)), false);
// A null probe value is expected to give a null result.
EXPECT_EQ(in<int128_t>(std::nullopt, {1, 2}, DECIMAL(23, 2)), std::nullopt);
// List contains a null element: a match is still true, but a miss is
// expected to be null rather than false.
EXPECT_EQ(in<int128_t>(1, {1, std::nullopt, 2}, DECIMAL(23, 2)), true);
EXPECT_EQ(in<int128_t>(2, {1, std::nullopt, 2}, DECIMAL(23, 2)), true);
EXPECT_EQ(
in<int128_t>(3, {1, std::nullopt, 2}, DECIMAL(23, 2)), std::nullopt);
// Null probe value together with a null list element: null result.
EXPECT_EQ(
in<int128_t>(std::nullopt, {1, std::nullopt, 2}, DECIMAL(23, 2)),
std::nullopt);
// Boundary constants: DecimalUtil::kLongDecimalMin and kLongDecimalMax are
// each expected to be found in a list holding both, at precision 38.
EXPECT_EQ(
in<int128_t>(
DecimalUtil::kLongDecimalMin,
{DecimalUtil::kLongDecimalMin, DecimalUtil::kLongDecimalMax},
DECIMAL(38, 19)),
true);
EXPECT_EQ(
in<int128_t>(
DecimalUtil::kLongDecimalMax,
{DecimalUtil::kLongDecimalMin, DecimalUtil::kLongDecimalMax},
DECIMAL(38, 19)),
true);
}

TEST_F(InTest, Const) {
const auto eval = [&](const std::string& expr) {
return evaluateOnce<bool, bool>(expr, false);
Expand Down

0 comments on commit 3703527

Please sign in to comment.