-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Add Presto function array_top_n (#12105)
Summary: Adds Presto function array_top_n as a simple function in Velox. Function uses a temporary vector to store inputted values and heap sorts them up to k values (second input to function). Updates ArrayFunction.h with struct ArrayTopNFunction and adds new tester function ArrayTopNTest.cpp Differential Revision: D68031372
- Loading branch information
1 parent
bb745c1
commit 80e2065
Showing
4 changed files
with
332 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,263 @@ | ||
/* | ||
* Copyright (c) Facebook, Inc. and its affiliates. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
#include "velox/common/base/tests/GTestUtils.h" | ||
#include "velox/functions/Macros.h" | ||
#include "velox/functions/Registerer.h" | ||
#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" | ||
#include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h" | ||
|
||
#include <fmt/format.h> | ||
#include <cstdint> | ||
|
||
using namespace facebook::velox; | ||
using namespace facebook::velox::test; | ||
using facebook::velox::functions::test::FunctionBaseTest; | ||
using namespace facebook::velox::functions::test; | ||
|
||
namespace { | ||
|
||
class ArrayTopNTest : public FunctionBaseTest {}; | ||
|
||
TEST_F(ArrayTopNTest, jsonHappyPath) { | ||
auto input = makeArrayVectorFromJson<int32_t>({ | ||
"[1, 2, 3]", | ||
"[4, 5, 6]", | ||
"[7, 8, 9]", | ||
}); | ||
|
||
auto expected_result = | ||
makeArrayVectorFromJson<int32_t>({"[3]", "[6]", "[9]"}); | ||
auto result = evaluate("array_top_n(c0, 1)", makeRowVector({input})); | ||
assertEqualVectors(expected_result, result); | ||
|
||
expected_result = | ||
makeArrayVectorFromJson<int32_t>({"[3, 2]", "[6, 5]", "[9, 8]"}); | ||
result = evaluate("array_top_n(c0, 2)", makeRowVector({input})); | ||
assertEqualVectors(expected_result, result); | ||
|
||
expected_result = | ||
makeArrayVectorFromJson<int32_t>({"[3, 2, 1]", "[6, 5, 4]", "[9, 8, 7]"}); | ||
result = evaluate("array_top_n(c0, 3)", makeRowVector({input})); | ||
assertEqualVectors(expected_result, result); | ||
|
||
result = evaluate("array_top_n(c0, 5)", makeRowVector({input})); | ||
assertEqualVectors(expected_result, result); | ||
} | ||
|
||
TEST_F(ArrayTopNTest, nullHandler) { | ||
// Test fully null array vector. | ||
auto input = makeNullableArrayVector<int32_t>({ | ||
{std::nullopt, std::nullopt}, | ||
{std::nullopt, std::nullopt, std::nullopt}, | ||
}); | ||
auto expected = makeArrayVectorFromJson<int32_t>({"[]", "[]"}); | ||
auto result = evaluate("array_top_n(c0, 2)", makeRowVector({input})); | ||
assertEqualVectors(expected, result); | ||
|
||
// Test null array vector with various different top n values. | ||
input = makeArrayVectorFromJson<int32_t>({ | ||
"[1, null, 2, null, 3]", | ||
"[4, 5, null, 6, null]", | ||
"[null, 7, null, 8, 9]", | ||
}); | ||
|
||
expected = makeArrayVectorFromJson<int32_t>({"[3]", "[6]", "[9]"}); | ||
result = evaluate("array_top_n(c0, 1)", makeRowVector({input})); | ||
assertEqualVectors(expected, result); | ||
|
||
expected = makeArrayVectorFromJson<int32_t>({"[3, 2]", "[6, 5]", "[9, 8]"}); | ||
result = evaluate("array_top_n(c0, 2)", makeRowVector({input})); | ||
assertEqualVectors(expected, result); | ||
|
||
expected = | ||
makeArrayVectorFromJson<int32_t>({"[3, 2, 1]", "[6, 5, 4]", "[9, 8, 7]"}); | ||
result = evaluate("array_top_n(c0, 3)", makeRowVector({input})); | ||
assertEqualVectors(expected, result); | ||
|
||
expected = | ||
makeArrayVectorFromJson<int32_t>({"[3, 2, 1]", "[6, 5, 4]", "[9, 8, 7]"}); | ||
result = evaluate("array_top_n(c0, 4)", makeRowVector({input})); | ||
assertEqualVectors(expected, result); | ||
|
||
// Test nullable aray vector of bigints. | ||
input = makeNullableArrayVector<int64_t>( | ||
{{1, 2, std::nullopt}, | ||
{4, 5, std::nullopt, std::nullopt}, | ||
{7, std::nullopt, std::nullopt, std::nullopt}}); | ||
|
||
expected = makeArrayVectorFromJson<int64_t>({"[2, 1]", "[5, 4]", "[7]"}); | ||
result = evaluate("array_top_n(c0, 3)", makeRowVector({input})); | ||
assertEqualVectors(expected, result); | ||
|
||
// Test nullable aray vector of strings. | ||
input = makeNullableArrayVector<std::string>({ | ||
{"abc123", "abc", std::nullopt, "abcd"}, | ||
{std::nullopt, "x", "xyz123", "xyzzzz"}, | ||
}); | ||
expected = makeArrayVectorFromJson<std::string>( | ||
{"[\"abcd\", \"abc123\", \"abc\"]", "[\"xyzzzz\", \"xyz123\", \"x\"]"}); | ||
|
||
result = evaluate("array_top_n(c0, 3)", makeRowVector({input})); | ||
assertEqualVectors(expected, result); | ||
result = evaluate("array_top_n(c0, 4)", makeRowVector({input})); | ||
assertEqualVectors(expected, result); | ||
} | ||
|
||
TEST_F(ArrayTopNTest, constant) { | ||
// Test constant array vector and verify per row. | ||
vector_size_t size = 1'000; | ||
auto data = makeArrayVector<int64_t>({{1, 2, 3}, {4, 5, 4, 5}, {7, 7, 7, 7}}); | ||
|
||
auto evaluateConstant = [&](vector_size_t row, const VectorPtr& vector) { | ||
return evaluate( | ||
"array_top_n(c0, 2)", | ||
makeRowVector({BaseVector::wrapInConstant(size, row, vector)})); | ||
}; | ||
|
||
auto result = evaluateConstant(0, data); | ||
auto expected = makeConstantArray<int64_t>(size, {3, 2}); | ||
assertEqualVectors(expected, result); | ||
|
||
result = evaluateConstant(1, data); | ||
expected = makeConstantArray<int64_t>(size, {5, 5}); | ||
assertEqualVectors(expected, result); | ||
|
||
result = evaluateConstant(2, data); | ||
expected = makeConstantArray<int64_t>(size, {7, 7}); | ||
assertEqualVectors(expected, result); | ||
|
||
data = makeArrayVector<int64_t>( | ||
{{1, 2, 3, 0, 1, 2, 2}, {4, 5, 4, 5, 5, 4}, {6, 6, 6, 6, 7, 8, 9, 10}}); | ||
|
||
auto evaluateMore = [&](vector_size_t row, const VectorPtr& vector) { | ||
return evaluate( | ||
"array_top_n(c0, 3)", | ||
makeRowVector({BaseVector::wrapInConstant(size, row, vector)})); | ||
}; | ||
|
||
result = evaluateMore(0, data); | ||
expected = makeConstantArray<int64_t>(size, {3, 2, 2}); | ||
assertEqualVectors(expected, result); | ||
|
||
result = evaluateMore(1, data); | ||
expected = makeConstantArray<int64_t>(size, {5, 5, 5}); | ||
assertEqualVectors(expected, result); | ||
|
||
result = evaluateMore(2, data); | ||
expected = makeConstantArray<int64_t>(size, {10, 9, 8}); | ||
assertEqualVectors(expected, result); | ||
} | ||
|
||
TEST_F(ArrayTopNTest, inlineStringArrays) { | ||
// Test inline (short) strings. | ||
using S = StringView; | ||
|
||
auto input = makeNullableArrayVector<StringView>({ | ||
{}, | ||
{S("")}, | ||
{std::nullopt}, | ||
{S("a"), S("b")}, | ||
{S("a"), std::nullopt, S("b")}, | ||
{S("a"), S("a")}, | ||
{S("b"), S("a"), S("b"), S("a"), S("a")}, | ||
{std::nullopt, std::nullopt}, | ||
{S("b"), std::nullopt, S("a"), S("a"), std::nullopt, S("b")}, | ||
}); | ||
|
||
auto expected = makeNullableArrayVector<StringView>({ | ||
{}, | ||
{S("")}, | ||
{}, | ||
{S("b"), S("a")}, | ||
{S("b"), S("a")}, | ||
{S("a"), S("a")}, | ||
{S("b"), S("b")}, | ||
{}, | ||
{S("b"), S("b")}, | ||
}); | ||
|
||
auto result = | ||
evaluate<ArrayVector>("array_top_n(C0, 2)", makeRowVector({input})); | ||
assertEqualVectors(expected, result); | ||
} | ||
|
||
TEST_F(ArrayTopNTest, stringArrays) { | ||
// Test non-inline (> 12 character length) strings. | ||
using S = StringView; | ||
|
||
auto input = makeNullableArrayVector<StringView>({ | ||
{S("red shiny car ahead"), S("blue clear sky above")}, | ||
{S("blue clear sky above"), | ||
S("yellow rose flowers"), | ||
std::nullopt, | ||
S("blue clear sky above"), | ||
S("orange beautiful sunset")}, | ||
{ | ||
S("red shiny car ahead"), | ||
std::nullopt, | ||
S("purple is an elegant color"), | ||
S("red shiny car ahead"), | ||
S("green plants make us happy"), | ||
S("purple is an elegant color"), | ||
std::nullopt, | ||
S("purple is an elegant color"), | ||
}, | ||
}); | ||
|
||
auto expected = makeNullableArrayVector<StringView>({ | ||
{S("red shiny car ahead"), S("blue clear sky above")}, | ||
{S("yellow rose flowers"), | ||
S("orange beautiful sunset"), | ||
S("blue clear sky above")}, | ||
{S("red shiny car ahead"), | ||
S("red shiny car ahead"), | ||
S("purple is an elegant color")}, | ||
}); | ||
|
||
auto result = | ||
evaluate<ArrayVector>("array_top_n(C0, 3)", makeRowVector({input})); | ||
assertEqualVectors(expected, result); | ||
} | ||
|
||
TEST_F(ArrayTopNTest, nonContiguousRows) { | ||
auto c0 = makeFlatVector<int64_t>(4, [](auto row) { return row; }); | ||
auto c1 = makeArrayVector<int64_t>({ | ||
{1, 1, 2, 3, 3}, | ||
{1, 1, 2, 3, 4, 4}, | ||
{1, 1, 2, 3, 4, 5, 5}, | ||
{1, 1, 2, 3, 3, 4, 5, 6, 6}, | ||
}); | ||
|
||
auto c2 = makeArrayVector<int64_t>({ | ||
{0, 0, 1, 1, 2, 3, 3}, | ||
{0, 0, 1, 1, 2, 3, 4, 4}, | ||
{0, 0, 1, 1, 2, 3, 4, 5, 5}, | ||
{0, 0, 1, 1, 2, 3, 4, 5, 6, 6}, | ||
}); | ||
|
||
auto expected = makeArrayVector<int64_t>({ | ||
{3, 3}, | ||
{4, 4}, | ||
{5, 5}, | ||
{6, 6}, | ||
}); | ||
|
||
auto result = evaluate<ArrayVector>( | ||
"if(c0 % 2 = 0, array_top_n(c1, 2), array_top_n(c2, 2))", | ||
makeRowVector({c0, c1, c2})); | ||
assertEqualVectors(expected, result); | ||
} | ||
} // namespace |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters