Skip to content

Commit

Permalink
feat: Add Presto function array_top_n (facebookincubator#12105)
Browse files Browse the repository at this point in the history
Summary:

Adds Presto function array_top_n as a simple function in Velox. Function uses a temporary vector to store inputted values and heap sorts them up to k values (second input to function).

Updates ArrayFunction.h with struct ArrayTopNFunction and adds new tester function ArrayTopNTest.cpp

Differential Revision: D68031372
  • Loading branch information
peterenescu authored and facebook-github-bot committed Jan 24, 2025
1 parent 419de77 commit c86c6b8
Show file tree
Hide file tree
Showing 4 changed files with 332 additions and 0 deletions.
50 changes: 50 additions & 0 deletions velox/functions/prestosql/ArrayFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,14 @@
#include "velox/expression/PrestoCastHooks.h"
#include "velox/functions/Udf.h"
#include "velox/functions/lib/CheckedArithmetic.h"
#include "velox/functions/lib/ComparatorUtil.h"
#include "velox/functions/prestosql/json/SIMDJsonUtil.h"
#include "velox/functions/prestosql/types/JsonType.h"
#include "velox/type/Conversions.h"
#include "velox/type/FloatingPointUtil.h"

#include <queue>

namespace facebook::velox::functions {

template <typename TExecCtx, bool isMax>
Expand Down Expand Up @@ -729,6 +732,53 @@ inline void checkIndexArrayTrim(int64_t size, int64_t arraySize) {
}
}

template <typename T>
struct ArrayTopNFunction {
VELOX_DEFINE_FUNCTION_TYPES(T);

// Definition for primitives.
template <typename TReturn, typename TInput>
FOLLY_ALWAYS_INLINE bool
call(TReturn& result, const TInput& array, int64_t n) {
// If n is invalid, exit early.
if (n <= 0) {
return false;
}

// Define min-heap to store the top n elements.
std::priority_queue<
typename TInput::element_t,
std::vector<typename TInput::element_t>,
std::greater<>>
minHeap;

// Iterate through the array and push elements to the min-heap.
for (const auto& item : array) {
if (item.has_value()) {
minHeap.push(item.value());
if (minHeap.size() > n) {
minHeap.pop();
}
}
}

// Reverse the min-heap to get the top n elements in descending order.
std::vector<typename TInput::element_t> reversed;
while (!minHeap.empty()) {
reversed.push_back(minHeap.top());
minHeap.pop();
}
std::reverse(reversed.begin(), reversed.end());

// Copy mutated vector to result vector up to minHeap's size items.
for (const auto& item : reversed) {
result.push_back(item);
}

return true;
}
};

template <typename T>
struct ArrayTrimFunction {
VELOX_DEFINE_FUNCTION_TYPES(T);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,12 @@ inline void registerArrayTrimFunctions(const std::string& prefix) {
{prefix + "trim_array"});
}

template <typename T>
inline void registerArrayTopNFunction(const std::string& prefix) {
registerFunction<ArrayTopNFunction, Array<T>, Array<T>, int64_t>(
{prefix + "array_top_n"});
}

template <typename T>
inline void registerArrayRemoveNullFunctions(const std::string& prefix) {
registerFunction<ArrayRemoveNullFunction, Array<T>, Array<T>>(
Expand Down Expand Up @@ -241,6 +247,18 @@ void registerArrayFunctions(const std::string& prefix) {
Array<Varchar>,
int64_t>({prefix + "trim_array"});

registerArrayTopNFunction<int8_t>(prefix);
registerArrayTopNFunction<int16_t>(prefix);
registerArrayTopNFunction<int32_t>(prefix);
registerArrayTopNFunction<int64_t>(prefix);
registerArrayTopNFunction<int128_t>(prefix);
registerArrayTopNFunction<float>(prefix);
registerArrayTopNFunction<double>(prefix);
registerArrayTopNFunction<Varchar>(prefix);
registerArrayTopNFunction<Timestamp>(prefix);
registerArrayTopNFunction<Date>(prefix);
registerArrayTopNFunction<Varbinary>(prefix);

registerArrayRemoveNullFunctions<int8_t>(prefix);
registerArrayRemoveNullFunctions<int16_t>(prefix);
registerArrayRemoveNullFunctions<int32_t>(prefix);
Expand Down
263 changes: 263 additions & 0 deletions velox/functions/prestosql/tests/ArrayTopNTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,263 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "velox/common/base/tests/GTestUtils.h"
#include "velox/functions/Macros.h"
#include "velox/functions/Registerer.h"
#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h"
#include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h"

#include <fmt/format.h>
#include <cstdint>

using namespace facebook::velox;
using namespace facebook::velox::test;
using facebook::velox::functions::test::FunctionBaseTest;
using namespace facebook::velox::functions::test;

namespace {

class ArrayTopNTest : public FunctionBaseTest {};

TEST_F(ArrayTopNTest, jsonHappyPath) {
auto input = makeArrayVectorFromJson<int32_t>({
"[1, 2, 3]",
"[4, 5, 6]",
"[7, 8, 9]",
});

auto expected_result =
makeArrayVectorFromJson<int32_t>({"[3]", "[6]", "[9]"});
auto result = evaluate("array_top_n(c0, 1)", makeRowVector({input}));
assertEqualVectors(expected_result, result);

expected_result =
makeArrayVectorFromJson<int32_t>({"[3, 2]", "[6, 5]", "[9, 8]"});
result = evaluate("array_top_n(c0, 2)", makeRowVector({input}));
assertEqualVectors(expected_result, result);

expected_result =
makeArrayVectorFromJson<int32_t>({"[3, 2, 1]", "[6, 5, 4]", "[9, 8, 7]"});
result = evaluate("array_top_n(c0, 3)", makeRowVector({input}));
assertEqualVectors(expected_result, result);

result = evaluate("array_top_n(c0, 5)", makeRowVector({input}));
assertEqualVectors(expected_result, result);
}

TEST_F(ArrayTopNTest, nullHandler) {
// Test fully null array vector.
auto input = makeNullableArrayVector<int32_t>({
{std::nullopt, std::nullopt},
{std::nullopt, std::nullopt, std::nullopt},
});
auto expected = makeArrayVectorFromJson<int32_t>({"[]", "[]"});
auto result = evaluate("array_top_n(c0, 2)", makeRowVector({input}));
assertEqualVectors(expected, result);

// Test null array vector with various different top n values.
input = makeArrayVectorFromJson<int32_t>({
"[1, null, 2, null, 3]",
"[4, 5, null, 6, null]",
"[null, 7, null, 8, 9]",
});

expected = makeArrayVectorFromJson<int32_t>({"[3]", "[6]", "[9]"});
result = evaluate("array_top_n(c0, 1)", makeRowVector({input}));
assertEqualVectors(expected, result);

expected = makeArrayVectorFromJson<int32_t>({"[3, 2]", "[6, 5]", "[9, 8]"});
result = evaluate("array_top_n(c0, 2)", makeRowVector({input}));
assertEqualVectors(expected, result);

expected =
makeArrayVectorFromJson<int32_t>({"[3, 2, 1]", "[6, 5, 4]", "[9, 8, 7]"});
result = evaluate("array_top_n(c0, 3)", makeRowVector({input}));
assertEqualVectors(expected, result);

expected =
makeArrayVectorFromJson<int32_t>({"[3, 2, 1]", "[6, 5, 4]", "[9, 8, 7]"});
result = evaluate("array_top_n(c0, 4)", makeRowVector({input}));
assertEqualVectors(expected, result);

// Test nullable aray vector of bigints.
input = makeNullableArrayVector<int64_t>(
{{1, 2, std::nullopt},
{4, 5, std::nullopt, std::nullopt},
{7, std::nullopt, std::nullopt, std::nullopt}});

expected = makeArrayVectorFromJson<int64_t>({"[2, 1]", "[5, 4]", "[7]"});
result = evaluate("array_top_n(c0, 3)", makeRowVector({input}));
assertEqualVectors(expected, result);

// Test nullable aray vector of strings.
input = makeNullableArrayVector<std::string>({
{"abc123", "abc", std::nullopt, "abcd"},
{std::nullopt, "x", "xyz123", "xyzzzz"},
});
expected = makeArrayVectorFromJson<std::string>(
{"[\"abcd\", \"abc123\", \"abc\"]", "[\"xyzzzz\", \"xyz123\", \"x\"]"});

result = evaluate("array_top_n(c0, 3)", makeRowVector({input}));
assertEqualVectors(expected, result);
result = evaluate("array_top_n(c0, 4)", makeRowVector({input}));
assertEqualVectors(expected, result);
}

TEST_F(ArrayTopNTest, constant) {
// Test constant array vector and verify per row.
vector_size_t size = 1'000;
auto data = makeArrayVector<int64_t>({{1, 2, 3}, {4, 5, 4, 5}, {7, 7, 7, 7}});

auto evaluateConstant = [&](vector_size_t row, const VectorPtr& vector) {
return evaluate(
"array_top_n(c0, 2)",
makeRowVector({BaseVector::wrapInConstant(size, row, vector)}));
};

auto result = evaluateConstant(0, data);
auto expected = makeConstantArray<int64_t>(size, {3, 2});
assertEqualVectors(expected, result);

result = evaluateConstant(1, data);
expected = makeConstantArray<int64_t>(size, {5, 5});
assertEqualVectors(expected, result);

result = evaluateConstant(2, data);
expected = makeConstantArray<int64_t>(size, {7, 7});
assertEqualVectors(expected, result);

data = makeArrayVector<int64_t>(
{{1, 2, 3, 0, 1, 2, 2}, {4, 5, 4, 5, 5, 4}, {6, 6, 6, 6, 7, 8, 9, 10}});

auto evaluateMore = [&](vector_size_t row, const VectorPtr& vector) {
return evaluate(
"array_top_n(c0, 3)",
makeRowVector({BaseVector::wrapInConstant(size, row, vector)}));
};

result = evaluateMore(0, data);
expected = makeConstantArray<int64_t>(size, {3, 2, 2});
assertEqualVectors(expected, result);

result = evaluateMore(1, data);
expected = makeConstantArray<int64_t>(size, {5, 5, 5});
assertEqualVectors(expected, result);

result = evaluateMore(2, data);
expected = makeConstantArray<int64_t>(size, {10, 9, 8});
assertEqualVectors(expected, result);
}

TEST_F(ArrayTopNTest, inlineStringArrays) {
// Test inline (short) strings.
using S = StringView;

auto input = makeNullableArrayVector<StringView>({
{},
{S("")},
{std::nullopt},
{S("a"), S("b")},
{S("a"), std::nullopt, S("b")},
{S("a"), S("a")},
{S("b"), S("a"), S("b"), S("a"), S("a")},
{std::nullopt, std::nullopt},
{S("b"), std::nullopt, S("a"), S("a"), std::nullopt, S("b")},
});

auto expected = makeNullableArrayVector<StringView>({
{},
{S("")},
{},
{S("b"), S("a")},
{S("b"), S("a")},
{S("a"), S("a")},
{S("b"), S("b")},
{},
{S("b"), S("b")},
});

auto result =
evaluate<ArrayVector>("array_top_n(C0, 2)", makeRowVector({input}));
assertEqualVectors(expected, result);
}

TEST_F(ArrayTopNTest, stringArrays) {
// Test non-inline (> 12 character length) strings.
using S = StringView;

auto input = makeNullableArrayVector<StringView>({
{S("red shiny car ahead"), S("blue clear sky above")},
{S("blue clear sky above"),
S("yellow rose flowers"),
std::nullopt,
S("blue clear sky above"),
S("orange beautiful sunset")},
{
S("red shiny car ahead"),
std::nullopt,
S("purple is an elegant color"),
S("red shiny car ahead"),
S("green plants make us happy"),
S("purple is an elegant color"),
std::nullopt,
S("purple is an elegant color"),
},
});

auto expected = makeNullableArrayVector<StringView>({
{S("red shiny car ahead"), S("blue clear sky above")},
{S("yellow rose flowers"),
S("orange beautiful sunset"),
S("blue clear sky above")},
{S("red shiny car ahead"),
S("red shiny car ahead"),
S("purple is an elegant color")},
});

auto result =
evaluate<ArrayVector>("array_top_n(C0, 3)", makeRowVector({input}));
assertEqualVectors(expected, result);
}

TEST_F(ArrayTopNTest, nonContiguousRows) {
auto c0 = makeFlatVector<int64_t>(4, [](auto row) { return row; });
auto c1 = makeArrayVector<int64_t>({
{1, 1, 2, 3, 3},
{1, 1, 2, 3, 4, 4},
{1, 1, 2, 3, 4, 5, 5},
{1, 1, 2, 3, 3, 4, 5, 6, 6},
});

auto c2 = makeArrayVector<int64_t>({
{0, 0, 1, 1, 2, 3, 3},
{0, 0, 1, 1, 2, 3, 4, 4},
{0, 0, 1, 1, 2, 3, 4, 5, 5},
{0, 0, 1, 1, 2, 3, 4, 5, 6, 6},
});

auto expected = makeArrayVector<int64_t>({
{3, 3},
{4, 4},
{5, 5},
{6, 6},
});

auto result = evaluate<ArrayVector>(
"if(c0 % 2 = 0, array_top_n(c1, 2), array_top_n(c2, 2))",
makeRowVector({c0, c1, c2}));
assertEqualVectors(expected, result);
}
} // namespace
1 change: 1 addition & 0 deletions velox/functions/prestosql/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ add_executable(
ArrayRemoveTest.cpp
ArrayShuffleTest.cpp
ArraySortTest.cpp
ArrayTopNTest.cpp
ArraysOverlapTest.cpp
ArraySumTest.cpp
ArrayTrimTest.cpp
Expand Down

0 comments on commit c86c6b8

Please sign in to comment.