Skip to content

Commit

Permalink
Separate distance functors by window type
Browse files Browse the repository at this point in the history
  • Loading branch information
wence- committed Feb 24, 2025
1 parent a3eab98 commit fb1af0f
Showing 1 changed file with 99 additions and 41 deletions.
140 changes: 99 additions & 41 deletions cpp/src/rolling/detail/range_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -325,10 +325,74 @@ template <typename T, typename V>
}
}
}

template <typename Grouping>
struct unbounded_distance_functor {
unbounded_distance_functor(Grouping groups, direction direction)
: groups{groups}, direction{direction}
{
}
Grouping const groups;
direction const direction;
[[nodiscard]] __device__ size_type operator()(size_type i) const noexcept
{
auto const row_info = groups.row_info(i);
if (direction == direction::PRECEDING) {
return i - row_info.group_start + 1;
} else {
return row_info.group_end - i - 1;
}
}
};

template <typename Grouping, typename OrderbyT>
struct current_row_distance_functor {
current_row_distance_functor(Grouping groups,
direction direction,
order order,
column_device_view::const_iterator<OrderbyT> begin)
: groups{groups}, direction{direction}, order{order}, begin{begin}
{
}
Grouping const groups;
direction const direction;
order const order;
column_device_view::const_iterator<OrderbyT> begin;

[[nodiscard]] __device__ size_type operator()(size_type i) const noexcept
{
using Comp = comparator_t<OrderbyT, current_row>;
auto const row_info = groups.row_info(i);
bool const is_null = Grouping::has_nulls && i >= row_info.null_start && i < row_info.null_end;
if (direction == direction::PRECEDING) {
if (is_null) {
return i - row_info.null_start + 1;
} else {
return 1 +
thrust::distance(
thrust::lower_bound(
thrust::seq, begin + row_info.non_null_start, begin + i, begin[i], Comp{order}),
begin + i);
}
} else {
if (is_null) {
return row_info.null_end - i - 1;
} else {
return thrust::distance(
begin + i,
thrust::upper_bound(
thrust::seq, begin + i, begin + row_info.non_null_end, begin[i], Comp{order})) -
1;
}
}
}
};

/**
* @brief Functor to compute distance from a given row to the edge
* of the window.
*
* @tparam WindowTag The type of window we're computing the distance for.
* @tparam Grouping Object defining how the orderby column is
* grouped.
* @tparam OrderbyT Type of elements in the orderby columns.
Expand All @@ -351,12 +415,12 @@ template <typename T, typename V>
* saturating addition/subtraction.
*/
template <typename WindowTag, typename Grouping, typename OrderbyT, typename DeltaT>
struct distance_functor {
distance_functor(Grouping groups,
DeltaT const* row_delta,
direction const direction,
order const order,
column_device_view::const_iterator<OrderbyT> begin)
struct bounded_distance_functor {
bounded_distance_functor(Grouping groups,
DeltaT const* row_delta,
direction const direction,
order const order,
column_device_view::const_iterator<OrderbyT> begin)
: groups{groups}, row_delta{row_delta}, direction{direction}, order{order}, begin{begin}
{
}
Expand All @@ -382,23 +446,14 @@ struct distance_functor {
{
using Comp = comparator_t<OrderbyT, WindowTag>;
auto const row_info = groups.row_info(i);
bool const is_null = Grouping::has_nulls && i >= row_info.null_start && i < row_info.null_end;
if (direction == direction::PRECEDING) {
if constexpr (cuda::std::is_same_v<WindowTag, unbounded>) {
return i - row_info.group_start + 1;
}
// TODO: If the window is BOUNDED_OPEN, what does it mean for a row to fall in the null
// group? Not that important because only spark allows nulls in the orderby column, and it
// doesn't have BOUNDED_OPEN windows.
if (Grouping::has_nulls && i >= row_info.null_start && i < row_info.null_end) {
if (is_null) {
return i - row_info.null_start + 1;
}
if constexpr (cuda::std::is_same_v<WindowTag, current_row>) {
return 1 +
thrust::distance(
thrust::lower_bound(
thrust::seq, begin + row_info.non_null_start, begin + i, begin[i], Comp{order}),
begin + i);
} else if constexpr (!cuda::std::is_same_v<WindowTag, unbounded>) {
} else {
// The preceding endpoint is computed via row_value - delta.
// When delta is positive, this can only overflow towards -infinity.
// If we did overflow towards -infinity, then the value
Expand Down Expand Up @@ -451,23 +506,11 @@ struct distance_functor {
Comp{order}),
begin + i);
}
} else {
CUDF_UNREACHABLE("Unexpected WindowTag");
}
} else {
if constexpr (cuda::std::is_same_v<WindowTag, unbounded>) {
return row_info.group_end - i - 1;
}
if (Grouping::has_nulls && i >= row_info.null_start && i < row_info.null_end) {
if (is_null) {
return row_info.null_end - i - 1;
}
if constexpr (cuda::std::is_same_v<WindowTag, current_row>) {
return thrust::distance(
begin + i,
thrust::upper_bound(
thrust::seq, begin + i, begin + row_info.non_null_end, begin[i], Comp{order})) -
1;
} else if constexpr (!cuda::std::is_same_v<WindowTag, unbounded>) {
} else {
// The following endpoint is computed via row_value + delta.
// When delta is positive, this can only overflow towards +infinity.
// If we did overflow towards +infinity, then the value
Expand Down Expand Up @@ -523,8 +566,6 @@ struct distance_functor {
Comp{order})) -
1;
}
} else {
CUDF_UNREACHABLE("Unexpected WindowTag");
}
}
}
Expand Down Expand Up @@ -579,14 +620,31 @@ struct range_window_clamper {
using DeltaT = cuda::std::remove_cv_t<cuda::std::remove_pointer_t<decltype(d_row_delta)>>;
auto copy_n = [&](auto&& grouping) {
using Grouping = cuda::std::decay_t<decltype(grouping)>;
thrust::copy_n(rmm::exec_policy_nosync(stream),
cudf::detail::make_counting_transform_iterator(
0,
distance_functor<WindowTag, Grouping, OrderbyT, DeltaT>{
grouping, d_row_delta, direction, order, d_begin}),
orderby.size(),
result->mutable_view().begin<size_type>());
if constexpr (cuda::std::is_same_v<WindowTag, unbounded>) {
thrust::copy_n(rmm::exec_policy_nosync(stream),
cudf::detail::make_counting_transform_iterator(
0, unbounded_distance_functor<Grouping>{grouping, direction}),
orderby.size(),
result->mutable_view().begin<size_type>());
} else if constexpr (cuda::std::is_same_v<WindowTag, current_row>) {
thrust::copy_n(
rmm::exec_policy_nosync(stream),
cudf::detail::make_counting_transform_iterator(
0,
current_row_distance_functor<Grouping, OrderbyT>{grouping, direction, order, d_begin}),
orderby.size(),
result->mutable_view().begin<size_type>());
} else {
thrust::copy_n(rmm::exec_policy_nosync(stream),
cudf::detail::make_counting_transform_iterator(
0,
bounded_distance_functor<WindowTag, Grouping, OrderbyT, DeltaT>{
grouping, d_row_delta, direction, order, d_begin}),
orderby.size(),
result->mutable_view().begin<size_type>());
}
};

if (grouping.has_value()) {
if (orderby.has_nulls()) {
copy_n(grouped_with_nulls{nulls_at_start,
Expand Down

0 comments on commit fb1af0f

Please sign in to comment.