From 3b3df5eb2a15a06451aebe7f43bb394346604e29 Mon Sep 17 00:00:00 2001 From: Peter Enescu Date: Mon, 10 Feb 2025 15:00:04 -0800 Subject: [PATCH] feat(array): Add Presto function array_top_n (#12105) Summary: Adds Presto function array_top_n as a simple function in Velox. Function uses a temporary vector to store inputted values and heap sorts them up to k values (second input to function). Updates ArrayFunction.h with struct ArrayTopNFunction and adds new tester function ArrayTopNTest.cpp Differential Revision: D68031372 --- velox/functions/prestosql/ArrayFunctions.h | 159 +++++++ .../ArrayFunctionsRegistration.cpp | 19 + .../prestosql/tests/ArrayTopNTest.cpp | 417 ++++++++++++++++++ .../functions/prestosql/tests/CMakeLists.txt | 1 + 4 files changed, 596 insertions(+) create mode 100644 velox/functions/prestosql/tests/ArrayTopNTest.cpp diff --git a/velox/functions/prestosql/ArrayFunctions.h b/velox/functions/prestosql/ArrayFunctions.h index 8cc8692d9e37..fae1c0ea1f7e 100644 --- a/velox/functions/prestosql/ArrayFunctions.h +++ b/velox/functions/prestosql/ArrayFunctions.h @@ -21,11 +21,14 @@ #include "velox/expression/PrestoCastHooks.h" #include "velox/functions/Udf.h" #include "velox/functions/lib/CheckedArithmetic.h" +#include "velox/functions/lib/ComparatorUtil.h" #include "velox/functions/prestosql/json/SIMDJsonUtil.h" #include "velox/functions/prestosql/types/JsonType.h" #include "velox/type/Conversions.h" #include "velox/type/FloatingPointUtil.h" +#include + namespace facebook::velox::functions { template @@ -729,6 +732,162 @@ inline void checkIndexArrayTrim(int64_t size, int64_t arraySize) { } } +/// This class implements the array_top_n function. +/// +/// DEFINITION: +/// array_top_n(array(T), int) -> array(T) +/// Returns the top n elements of the array in descending order. +template +struct ArrayTopNFunction { + VELOX_DEFINE_FUNCTION_TYPES(T); + + // Definition for primitives. + template + FOLLY_ALWAYS_INLINE void + call(TReturn& result, const TInput& array, int64_t n) { + VELOX_CHECK( + n >= 0, fmt::format("Parameter n: {} to ARRAY_TOP_N is negative", n)); + + // Define comparator that wraps built-in function for basic primitives or + // calls floating point handler for NaNs. + using facebook::velox::util::floating_point::NaNAwareGreaterThan; + struct SimpleComparator { + bool operator()( + const typename TInput::element_t& a, + const typename TInput::element_t& b) const { + if constexpr ( + std::is_same_v || + std::is_same_v) { + return NaNAwareGreaterThan{}(a, b); + } else { + return std::greater{}(a, b); + } + } + }; + + // Define min-heap to store the top n elements. + std::priority_queue< + typename TInput::element_t, + std::vector, + SimpleComparator> + minHeap; + + // Iterate through the array and push elements to the min-heap. + int numNull = 0; + for (const auto& item : array) { + if (item.has_value()) { + if (minHeap.size() < n) { + minHeap.push(item.value()); + } else if (!minHeap.empty()) { + if constexpr ( + std::is_same_v || + std::is_same_v) { + if (NaNAwareGreaterThan{}( + item.value(), minHeap.top())) { + minHeap.push(item.value()); + } + } else if (item.value() > minHeap.top()) { + minHeap.push(item.value()); + } + } + if (minHeap.size() > n) { + minHeap.pop(); + } + } else { + ++numNull; + } + } + + // Reverse the min-heap to get the top n elements in descending order. + std::vector reversed(minHeap.size()); + auto index = minHeap.size(); + while (!minHeap.empty()) { + reversed[--index] = minHeap.top(); + minHeap.pop(); + } + + // Copy mutated vector to result vector up to minHeap's size items. + for (const auto& item : reversed) { + result.push_back(item); + } + + // Backfill nulls if needed. + while (result.size() < n && numNull > 0) { + result.add_null(); + --numNull; + } + } + + // Generic implementation. + FOLLY_ALWAYS_INLINE void call( + out_type>>& result, + const arg_type>>& array, + const int64_t n) { + VELOX_CHECK( + n >= 0, fmt::format("Parameter n: {} to ARRAY_TOP_N is negative", n)); + + // Define comparator to compare complex types. + struct ComplexTypeComparator { + const arg_type>>& array; + ComplexTypeComparator(const arg_type>>& array) + : array(array) {} + + bool operator()(const int64_t& a, const int64_t& b) const { + static constexpr CompareFlags kFlags = { + .nullHandlingMode = + CompareFlags::NullHandlingMode::kNullAsIndeterminate}; + return array[a].value().compare(array[b].value(), kFlags).value() > 0; + } + }; + + // Iterate through the array and push elements to the min-heap. + std::priority_queue, ComplexTypeComparator> + minHeap(array); + int numNull = 0; + for (int i = 0; i < array.size(); ++i) { + if (array[i].has_value()) { + if (minHeap.size() < n) { + minHeap.push(i); + } else if (!minHeap.empty()) { + static constexpr CompareFlags kFlags = { + .nullHandlingMode = + CompareFlags::NullHandlingMode::kNullAsIndeterminate}; + if (array[i] + .value() + .compare(array[minHeap.top()].value(), kFlags) + .value() > 0) { + minHeap.push(i); + } + } + if (minHeap.size() > n) { + minHeap.pop(); + } + } else { + ++numNull; + } + } + + // Reverse the min-heap to get the top n elements in descending order. + std::vector reversed(minHeap.size()); + auto index = minHeap.size(); + while (!minHeap.empty()) { + reversed[--index] = minHeap.top(); + minHeap.pop(); + } + + // Copy mutated vector to result vector up to minHeap's size items. + for (const auto& index : reversed) { + result.push_back(array[index].value()); + } + + // Backfill nulls if needed. + while (result.size() < n && numNull > 0) { + result.add_null(); + --numNull; + } + } +}; + template struct ArrayTrimFunction { VELOX_DEFINE_FUNCTION_TYPES(T); diff --git a/velox/functions/prestosql/registration/ArrayFunctionsRegistration.cpp b/velox/functions/prestosql/registration/ArrayFunctionsRegistration.cpp index 1806421b9de4..a91df48cf1bb 100644 --- a/velox/functions/prestosql/registration/ArrayFunctionsRegistration.cpp +++ b/velox/functions/prestosql/registration/ArrayFunctionsRegistration.cpp @@ -97,6 +97,12 @@ inline void registerArrayTrimFunctions(const std::string& prefix) { {prefix + "trim_array"}); } +template +inline void registerArrayTopNFunction(const std::string& prefix) { + registerFunction, Array, int64_t>( + {prefix + "array_top_n"}); +} + template inline void registerArrayRemoveNullFunctions(const std::string& prefix) { registerFunction, Array>( @@ -241,6 +247,19 @@ void registerArrayFunctions(const std::string& prefix) { Array, int64_t>({prefix + "trim_array"}); + registerArrayTopNFunction(prefix); + registerArrayTopNFunction(prefix); + registerArrayTopNFunction(prefix); + registerArrayTopNFunction(prefix); + registerArrayTopNFunction(prefix); + registerArrayTopNFunction(prefix); + registerArrayTopNFunction(prefix); + registerArrayTopNFunction(prefix); + registerArrayTopNFunction(prefix); + registerArrayTopNFunction(prefix); + registerArrayTopNFunction(prefix); + registerArrayTopNFunction>(prefix); + registerArrayRemoveNullFunctions(prefix); registerArrayRemoveNullFunctions(prefix); registerArrayRemoveNullFunctions(prefix); diff --git a/velox/functions/prestosql/tests/ArrayTopNTest.cpp b/velox/functions/prestosql/tests/ArrayTopNTest.cpp new file mode 100644 index 000000000000..60b60207ca37 --- /dev/null +++ b/velox/functions/prestosql/tests/ArrayTopNTest.cpp @@ -0,0 +1,417 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/functions/Macros.h" +#include "velox/functions/Registerer.h" +#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" +#include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h" + +#include +#include + +using namespace facebook::velox; +using namespace facebook::velox::test; +using facebook::velox::functions::test::FunctionBaseTest; +using namespace facebook::velox::functions::test; + +namespace { + +class ArrayTopNTest : public FunctionBaseTest {}; + +TEST_F(ArrayTopNTest, happyPath) { + auto input = makeArrayVectorFromJson({ + "[1, 2, 3]", + "[4, 5, 6]", + "[7, 8, 9]", + }); + + auto expected_result = + makeArrayVectorFromJson({"[3]", "[6]", "[9]"}); + auto result = evaluate("array_top_n(c0, 1)", makeRowVector({input})); + assertEqualVectors(expected_result, result); + + expected_result = + makeArrayVectorFromJson({"[3, 2]", "[6, 5]", "[9, 8]"}); + result = evaluate("array_top_n(c0, 2)", makeRowVector({input})); + assertEqualVectors(expected_result, result); + + expected_result = + makeArrayVectorFromJson({"[3, 2, 1]", "[6, 5, 4]", "[9, 8, 7]"}); + result = evaluate("array_top_n(c0, 3)", makeRowVector({input})); + assertEqualVectors(expected_result, result); + + result = evaluate("array_top_n(c0, 5)", makeRowVector({input})); + assertEqualVectors(expected_result, result); +} + +TEST_F(ArrayTopNTest, nullHandler) { + // Test fully null array vector. + auto input = makeNullableArrayVector({ + {std::nullopt, std::nullopt}, + {std::nullopt, std::nullopt, std::nullopt}, + }); + auto expected = + makeArrayVectorFromJson({"[null, null]", "[null, null]"}); + auto result = evaluate("array_top_n(c0, 2)", makeRowVector({input})); + assertEqualVectors(expected, result); + + expected = + makeArrayVectorFromJson({"[null, null]", "[null, null, null]"}); + result = evaluate("array_top_n(c0, 3)", makeRowVector({input})); + assertEqualVectors(expected, result); + + // Test null array vector with various different top n values. + input = makeArrayVectorFromJson({ + "[1, null, 2, null, 3]", + "[4, 5, null, 6, null]", + "[null, 7, null, 8, 9]", + }); + + expected = makeArrayVectorFromJson({"[3]", "[6]", "[9]"}); + result = evaluate("array_top_n(c0, 1)", makeRowVector({input})); + assertEqualVectors(expected, result); + + expected = makeArrayVectorFromJson({"[3, 2]", "[6, 5]", "[9, 8]"}); + result = evaluate("array_top_n(c0, 2)", makeRowVector({input})); + assertEqualVectors(expected, result); + + expected = + makeArrayVectorFromJson({"[3, 2, 1]", "[6, 5, 4]", "[9, 8, 7]"}); + result = evaluate("array_top_n(c0, 3)", makeRowVector({input})); + assertEqualVectors(expected, result); + + expected = makeArrayVectorFromJson( + {"[3, 2, 1, null]", "[6, 5, 4, null]", "[9, 8, 7, null]"}); + result = evaluate("array_top_n(c0, 4)", makeRowVector({input})); + assertEqualVectors(expected, result); + assertEqualVectors(expected, result); + + expected = makeArrayVectorFromJson( + {"[3, 2, 1, null, null]", + "[6, 5, 4, null, null]", + "[9, 8, 7, null, null]"}); + result = evaluate("array_top_n(c0, 7)", makeRowVector({input})); + assertEqualVectors(expected, result); + + // Test nullable aray vector of bigints. + input = makeNullableArrayVector( + {{1, 2, std::nullopt}, + {4, 5, std::nullopt, std::nullopt}, + {7, std::nullopt, std::nullopt, std::nullopt}}); + + expected = makeArrayVectorFromJson( + {"[2, 1, null]", "[5, 4, null]", "[7, null, null]"}); + result = evaluate("array_top_n(c0, 3)", makeRowVector({input})); + assertEqualVectors(expected, result); + + // Test nullable aray vector of strings. + input = makeNullableArrayVector({ + {"abc123", "abc", std::nullopt, "abcd"}, + {std::nullopt, "x", "xyz123", "xyzzzz"}, + }); + + result = evaluate("array_top_n(c0, 3)", makeRowVector({input})); + assertEqualVectors( + makeArrayVectorFromJson( + {"[\"abcd\", \"abc123\", \"abc\"]", + "[\"xyzzzz\", \"xyz123\", \"x\"]"}), + result); + result = evaluate("array_top_n(c0, 4)", makeRowVector({input})); + assertEqualVectors( + makeArrayVectorFromJson( + {"[\"abcd\", \"abc123\", \"abc\", null]", + "[\"xyzzzz\", \"xyz123\", \"x\", null]"}), + result); +} + +TEST_F(ArrayTopNTest, constant) { + // Test constant array vector and verify per row. + vector_size_t size = 1'000; + auto data = makeArrayVector({{1, 2, 3}, {4, 5, 4, 5}, {7, 7, 7, 7}}); + + auto evaluateConstant = [&](vector_size_t row, const VectorPtr& vector) { + return evaluate( + "array_top_n(c0, 2)", + makeRowVector({BaseVector::wrapInConstant(size, row, vector)})); + }; + + auto result = evaluateConstant(0, data); + auto expected = makeConstantArray(size, {3, 2}); + assertEqualVectors(expected, result); + + result = evaluateConstant(1, data); + expected = makeConstantArray(size, {5, 5}); + assertEqualVectors(expected, result); + + result = evaluateConstant(2, data); + expected = makeConstantArray(size, {7, 7}); + assertEqualVectors(expected, result); + + data = makeArrayVector( + {{1, 2, 3, 0, 1, 2, 2}, {4, 5, 4, 5, 5, 4}, {6, 6, 6, 6, 7, 8, 9, 10}}); + + auto evaluateMore = [&](vector_size_t row, const VectorPtr& vector) { + return evaluate( + "array_top_n(c0, 3)", + makeRowVector({BaseVector::wrapInConstant(size, row, vector)})); + }; + + result = evaluateMore(0, data); + expected = makeConstantArray(size, {3, 2, 2}); + assertEqualVectors(expected, result); + + result = evaluateMore(1, data); + expected = makeConstantArray(size, {5, 5, 5}); + assertEqualVectors(expected, result); + + result = evaluateMore(2, data); + expected = makeConstantArray(size, {10, 9, 8}); + assertEqualVectors(expected, result); +} + +TEST_F(ArrayTopNTest, inlineStringArrays) { + // Test inline (short) strings. + using S = StringView; + + auto input = makeNullableArrayVector({ + {}, + {S("")}, + {std::nullopt}, + {S("a"), S("b")}, + {S("a"), std::nullopt, S("b")}, + {S("a"), S("a")}, + {S("b"), S("a"), S("b"), S("a"), S("a")}, + {std::nullopt, std::nullopt}, + {S("b"), std::nullopt, S("a"), S("a"), std::nullopt, S("b")}, + }); + + auto expected = makeNullableArrayVector({ + {}, + {S("")}, + {std::nullopt}, + {S("b"), S("a")}, + {S("b"), S("a")}, + {S("a"), S("a")}, + {S("b"), S("b")}, + {std::nullopt, std::nullopt}, + {S("b"), S("b")}, + }); + + auto result = + evaluate("array_top_n(C0, 2)", makeRowVector({input})); + assertEqualVectors(expected, result); +} + +TEST_F(ArrayTopNTest, stringArrays) { + // Test non-inline (> 12 character length) strings. + using S = StringView; + + auto input = makeNullableArrayVector({ + {S("red shiny car ahead"), S("blue clear sky above")}, + {S("blue clear sky above"), + S("yellow rose flowers"), + std::nullopt, + S("blue clear sky above"), + S("orange beautiful sunset")}, + { + S("red shiny car ahead"), + std::nullopt, + S("purple is an elegant color"), + S("red shiny car ahead"), + S("green plants make us happy"), + S("purple is an elegant color"), + std::nullopt, + S("purple is an elegant color"), + }, + }); + + auto expected = makeNullableArrayVector({ + {S("red shiny car ahead"), S("blue clear sky above")}, + {S("yellow rose flowers"), + S("orange beautiful sunset"), + S("blue clear sky above")}, + {S("red shiny car ahead"), + S("red shiny car ahead"), + S("purple is an elegant color")}, + }); + + auto result = + evaluate("array_top_n(C0, 3)", makeRowVector({input})); + assertEqualVectors(expected, result); +} + +TEST_F(ArrayTopNTest, nonContiguousRows) { + auto c0 = makeFlatVector(4, [](auto row) { return row; }); + auto c1 = makeArrayVector({ + {1, 1, 2, 3, 3}, + {1, 1, 2, 3, 4, 4}, + {1, 1, 2, 3, 4, 5, 5}, + {1, 1, 2, 3, 3, 4, 5, 6, 6}, + }); + + auto c2 = makeArrayVector({ + {0, 0, 1, 1, 2, 3, 3}, + {0, 0, 1, 1, 2, 3, 4, 4}, + {0, 0, 1, 1, 2, 3, 4, 5, 5}, + {0, 0, 1, 1, 2, 3, 4, 5, 6, 6}, + }); + + auto expected = makeArrayVector({ + {3, 3}, + {4, 4}, + {5, 5}, + {6, 6}, + }); + + auto result = evaluate( + "if(c0 % 2 = 0, array_top_n(c1, 2), array_top_n(c2, 2))", + makeRowVector({c0, c1, c2})); + assertEqualVectors(expected, result); +} + +TEST_F(ArrayTopNTest, complexInput) { + // Tests array of arrays as complex input. + auto test = [this]( + const VectorPtr& inputArrayVector, + int size, + const VectorPtr& expectedArrayVector) { + auto result = evaluate( + fmt::format("array_top_n(c0, {})", size), + makeRowVector({inputArrayVector})); + assertEqualVectors(expectedArrayVector, result); + }; + + auto seedVector = + makeArrayVector({{1, 1}, {2, 2}, {3, 3}, {4, 4}, {5, 5}}); + + // Create arrays of array vector using above seed vector. + // [[1, 1], [2, 2], [3, 3]] + // [[4, 4], [5, 5]] + const auto arrayOfArrayInput = makeArrayVector({0, 3}, seedVector); + + // [[1, 1], [2, 2]] + // [[5, 5], [4, 4]] + const auto expected = makeArrayVector( + {0, 2}, makeArrayVector({{3, 3}, {2, 2}, {5, 5}, {4, 4}})); + + test({arrayOfArrayInput}, 2, {expected}); +} + +TEST_F(ArrayTopNTest, floatingPoints) { + static const float kNaN = std::numeric_limits::quiet_NaN(); + static const float kInfinity = std::numeric_limits::infinity(); + static const float kNegativeInfinity = + -1 * std::numeric_limits::infinity(); + + auto input = makeNullableArrayVector( + {{-1, kNaN, std::nullopt}, + {-1, -2, -3, kNaN, kNegativeInfinity}, + {kInfinity, kNaN}, + {kInfinity, kNaN, kNegativeInfinity}}); + auto expected = makeNullableArrayVector( + {{kNaN, -1, std::nullopt}, + {kNaN, -1, -2}, + {kNaN, kInfinity}, + {kNaN, kInfinity, kNegativeInfinity}}); + auto result = evaluate("array_top_n(c0, 3)", makeRowVector({input})); + + assertEqualVectors(expected, result); +} + +TEST_F(ArrayTopNTest, prestoLegacyTests) { + // Legacy test cases from Presto + assertEqualVectors( + makeArrayVectorFromJson({"[1, 1, 1]"}), + evaluate( + "array_top_n(c0, 3)", + makeRowVector({makeArrayVectorFromJson({"[1, 1, 1, 1]"})}))); + assertEqualVectors( + makeArrayVectorFromJson({"[100, 5, 3]"}), + evaluate( + "array_top_n(c0, 3)", + makeRowVector( + {makeArrayVectorFromJson({"[1, 100, 2, 5, 3]"})}))); + + assertEqualVectors( + makeArrayVectorFromJson({"[100.0, 5.0, 3.0]"}), + evaluate( + "array_top_n(c0, 3)", + makeRowVector({makeArrayVectorFromJson( + {"[1.0, 100.0, 2.0, 5.0, 3.0]"})}))); + assertEqualVectors( + makeArrayVectorFromJson({"[100.0, 5.0, 3.0]"}), + evaluate( + "array_top_n(c0, 3)", + makeRowVector( + {makeArrayVectorFromJson({"[1.0, 100, 2, 5.0, 3.0]"})}))); + + assertEqualVectors( + makeArrayVectorFromJson({"[4, 1, null]"}), + evaluate( + "array_top_n(c0, 3)", + makeRowVector({makeArrayVectorFromJson({"[4, 1, null]"})}))); + + assertEqualVectors( + makeArrayVectorFromJson({"[\"z\", \"g\", \"f\", \"d\"]"}), + evaluate( + "array_top_n(c0, 4)", + makeRowVector({makeArrayVectorFromJson( + {"[\"a\", \"z\", \"d\", \"f\", \"g\", \"b\"]"})}))); + assertEqualVectors( + makeArrayVectorFromJson( + {"[\"lorem2\", \"lorem\", \"ipsum\"]"}), + evaluate( + "array_top_n(c0, 3)", + makeRowVector({makeArrayVectorFromJson( + {"[\"foo\", \"bar\", \"lorem\", \"ipsum\", \"lorem2\"]"})}))); + assertEqualVectors( + makeArrayVectorFromJson({"[\"zzz\", \"zz\", \"g\"]"}), + evaluate( + "array_top_n(c0, 3)", + makeRowVector({makeArrayVectorFromJson( + {"[\"a\", \"zzz\", \"zz\", \"b\", \"g\", \"f\"]"})}))); + assertEqualVectors( + makeArrayVectorFromJson({"[\"d\", \"a\", \"a\"]"}), + evaluate( + "array_top_n(c0, 3)", + makeRowVector({makeArrayVectorFromJson( + {"[\"a\", \"a\", \"d\", \"a\", \"a\", \"a\"]"})}))); + + assertEqualVectors( + makeArrayVectorFromJson({"[true, true, true, false]"}), + evaluate( + "array_top_n(c0, 4)", + makeRowVector({makeArrayVectorFromJson( + {"[true, true, false, true, false]"})}))); + assertEqualVectors( + makeArrayVectorFromJson({"[]"}), + evaluate( + "array_top_n(c0, 0)", + makeRowVector({makeArrayVectorFromJson({"[1, 2, 3]"})}))); + assertEqualVectors( + makeArrayVectorFromJson({"[null, null]"}), + evaluate( + "array_top_n(c0, 2)", + makeRowVector({makeArrayVectorFromJson( + {"[null, null, null]"})}))); + VELOX_ASSERT_THROW( + evaluate( + "array_top_n(c0, -1)", + makeRowVector({makeArrayVectorFromJson({"[1, 2, 3]"})})), + "Parameter n: -1 to ARRAY_TOP_N is negative"); +} + +} // namespace diff --git a/velox/functions/prestosql/tests/CMakeLists.txt b/velox/functions/prestosql/tests/CMakeLists.txt index a6261a1564f3..097abeb1934d 100644 --- a/velox/functions/prestosql/tests/CMakeLists.txt +++ b/velox/functions/prestosql/tests/CMakeLists.txt @@ -44,6 +44,7 @@ add_executable( ArrayRemoveTest.cpp ArrayShuffleTest.cpp ArraySortTest.cpp + ArrayTopNTest.cpp ArraysOverlapTest.cpp ArraySumTest.cpp ArrayTrimTest.cpp