Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(array): Add Presto function array_top_n #12105

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 159 additions & 0 deletions velox/functions/prestosql/ArrayFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,14 @@
#include "velox/expression/PrestoCastHooks.h"
#include "velox/functions/Udf.h"
#include "velox/functions/lib/CheckedArithmetic.h"
#include "velox/functions/lib/ComparatorUtil.h"
#include "velox/functions/prestosql/json/SIMDJsonUtil.h"
#include "velox/functions/prestosql/types/JsonType.h"
#include "velox/type/Conversions.h"
#include "velox/type/FloatingPointUtil.h"

#include <queue>

namespace facebook::velox::functions {

template <typename TExecCtx, bool isMax>
Expand Down Expand Up @@ -729,6 +732,162 @@ inline void checkIndexArrayTrim(int64_t size, int64_t arraySize) {
}
}

/// Implements the array_top_n function.
///
/// DEFINITION:
/// array_top_n(array(T), int) -> array(T)
/// Returns the top n elements of the array in descending order.
///
/// Both overloads keep at most n non-null candidates in a min-heap whose top
/// is the smallest candidate kept so far; a new element enters a full heap
/// only when it compares greater than that top. Null elements are counted
/// and, when fewer than n non-null elements are available, appended after
/// the non-null results until either n entries have been produced or the
/// counted nulls run out. A negative n fails the VELOX_CHECK below.
template <typename T>
struct ArrayTopNFunction {
  VELOX_DEFINE_FUNCTION_TYPES(T);

  // Definition for primitives. The heap stores element values directly.
  template <typename TReturn, typename TInput>
  FOLLY_ALWAYS_INLINE void
  call(TReturn& result, const TInput& array, int64_t n) {
    // NOTE(review): a negative n comes from the caller's query; confirm
    // whether a user-error check macro is intended here instead.
    VELOX_CHECK(
        n >= 0, fmt::format("Parameter n: {} to ARRAY_TOP_N is negative", n));

    using Element = typename TInput::element_t;
    using facebook::velox::util::floating_point::NaNAwareGreaterThan;

    // "Greater than" ordering: float and double go through the NaN-aware
    // comparator, every other element type through std::greater.
    struct SimpleComparator {
      bool operator()(const Element& a, const Element& b) const {
        if constexpr (
            std::is_same_v<Element, float> ||
            std::is_same_v<Element, double>) {
          return NaNAwareGreaterThan<Element>{}(a, b);
        } else {
          return std::greater<Element>{}(a, b);
        }
      }
    };

    // n is known to be non-negative here, so the conversion is lossless and
    // the heap-size comparisons below stay unsigned-vs-unsigned.
    const auto limit = static_cast<size_t>(n);
    const SimpleComparator isGreater{};

    // With a "greater than" comparator std::priority_queue is a min-heap.
    std::priority_queue<Element, std::vector<Element>, SimpleComparator>
        minHeap;

    int numNull = 0;
    for (const auto& item : array) {
      if (!item.has_value()) {
        ++numNull;
        continue;
      }
      if (minHeap.size() < limit) {
        minHeap.push(item.value());
      } else if (!minHeap.empty() && isGreater(item.value(), minHeap.top())) {
        // The heap is full: admit the larger element, then evict the
        // smallest one so the heap is back to `limit` entries.
        minHeap.push(item.value());
        minHeap.pop();
      }
    }

    // The heap pops in ascending order; fill the buffer back to front so it
    // ends up in descending order.
    std::vector<Element> descending(minHeap.size());
    for (auto slot = descending.size(); slot > 0; --slot) {
      descending[slot - 1] = minHeap.top();
      minHeap.pop();
    }

    for (const auto& value : descending) {
      result.push_back(value);
    }

    // Backfill nulls if needed.
    while (result.size() < n && numNull > 0) {
      result.add_null();
      --numNull;
    }
  }

  // Generic implementation. The heap stores positions into `array` and
  // compares the elements those positions refer to.
  FOLLY_ALWAYS_INLINE void call(
      out_type<Array<Orderable<T1>>>& result,
      const arg_type<Array<Orderable<T1>>>& array,
      const int64_t n) {
    // NOTE(review): a negative n comes from the caller's query; confirm
    // whether a user-error check macro is intended here instead.
    VELOX_CHECK(
        n >= 0, fmt::format("Parameter n: {} to ARRAY_TOP_N is negative", n));

    // Orders two positions by the elements they refer to ("greater than").
    // The compare result is unwrapped with value(); presumably compare()
    // either yields a result or raises for this null-handling mode -
    // TODO confirm.
    struct ComplexTypeComparator {
      const arg_type<Array<Orderable<T1>>>& array;
      explicit ComplexTypeComparator(
          const arg_type<Array<Orderable<T1>>>& array)
          : array(array) {}

      bool operator()(const int64_t& a, const int64_t& b) const {
        static constexpr CompareFlags kFlags = {
            .nullHandlingMode =
                CompareFlags::NullHandlingMode::kNullAsIndeterminate};
        return array[a].value().compare(array[b].value(), kFlags).value() > 0;
      }
    };

    // n is known to be non-negative here, so the conversion is lossless and
    // the heap-size comparisons below stay unsigned-vs-unsigned.
    const auto limit = static_cast<size_t>(n);
    const ComplexTypeComparator isGreater(array);

    // With a "greater than" comparator std::priority_queue is a min-heap.
    std::priority_queue<int64_t, std::vector<int64_t>, ComplexTypeComparator>
        minHeap(isGreater);

    int numNull = 0;
    for (int i = 0; i < array.size(); ++i) {
      if (!array[i].has_value()) {
        ++numNull;
        continue;
      }
      if (minHeap.size() < limit) {
        minHeap.push(i);
      } else if (!minHeap.empty() && isGreater(i, minHeap.top())) {
        // The heap is full: admit the larger element, then evict the
        // smallest one so the heap is back to `limit` entries.
        minHeap.push(i);
        minHeap.pop();
      }
    }

    // The heap pops in ascending order; fill the buffer back to front so it
    // ends up in descending order.
    std::vector<int64_t> descending(minHeap.size());
    for (auto slot = descending.size(); slot > 0; --slot) {
      descending[slot - 1] = minHeap.top();
      minHeap.pop();
    }

    for (const auto& position : descending) {
      result.push_back(array[position].value());
    }

    // Backfill nulls if needed.
    while (result.size() < n && numNull > 0) {
      result.add_null();
      --numNull;
    }
  }
};

template <typename T>
struct ArrayTrimFunction {
VELOX_DEFINE_FUNCTION_TYPES(T);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,12 @@ inline void registerArrayTrimFunctions(const std::string& prefix) {
{prefix + "trim_array"});
}

/// Registers ArrayTopNFunction for element type T with the signature
/// (Array<T>, int64_t) -> Array<T> under the name `prefix` + "array_top_n".
template <typename T>
inline void registerArrayTopNFunction(const std::string& prefix) {
  const std::string functionName = prefix + "array_top_n";
  registerFunction<ArrayTopNFunction, Array<T>, Array<T>, int64_t>(
      {functionName});
}

template <typename T>
inline void registerArrayRemoveNullFunctions(const std::string& prefix) {
registerFunction<ArrayRemoveNullFunction, Array<T>, Array<T>>(
Expand Down Expand Up @@ -241,6 +247,19 @@ void registerArrayFunctions(const std::string& prefix) {
Array<Varchar>,
int64_t>({prefix + "trim_array"});

registerArrayTopNFunction<int8_t>(prefix);
registerArrayTopNFunction<int16_t>(prefix);
registerArrayTopNFunction<int32_t>(prefix);
registerArrayTopNFunction<int64_t>(prefix);
registerArrayTopNFunction<int128_t>(prefix);
registerArrayTopNFunction<float>(prefix);
registerArrayTopNFunction<double>(prefix);
registerArrayTopNFunction<Varchar>(prefix);
registerArrayTopNFunction<Timestamp>(prefix);
registerArrayTopNFunction<Date>(prefix);
registerArrayTopNFunction<Varbinary>(prefix);
registerArrayTopNFunction<Orderable<T1>>(prefix);

registerArrayRemoveNullFunctions<int8_t>(prefix);
registerArrayRemoveNullFunctions<int16_t>(prefix);
registerArrayRemoveNullFunctions<int32_t>(prefix);
Expand Down
Loading
Loading