Skip to content

Commit

Permalink
Grouping row info is named struct
Browse files Browse the repository at this point in the history
Move impls to rolling_utils.cuh

We can also drop the null count from the return value.
  • Loading branch information
wence- committed Feb 24, 2025
1 parent fec5788 commit a3eab98
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 136 deletions.
168 changes: 43 additions & 125 deletions cpp/src/rolling/detail/range_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -49,95 +49,6 @@ namespace CUDF_EXPORT cudf {
namespace detail {
namespace rolling {

/**
 * @brief A group descriptor for an ungrouped rolling window with nulls
 *
 * The whole input is treated as one group covering rows `[0, num_rows)`, with the `null_count`
 * null rows stored contiguously at either the start or the end of that range.
 *
 * @param nulls_at_start Are the nulls at the start or end?
 * @param num_rows The number of rows to be rolled over.
 * @param null_count The number of nulls.
 *
 * @note This is used for uniformity of interface between grouped and ungrouped
 * iterator construction.
 */
struct ungrouped_with_nulls {
  bool nulls_at_start;
  cudf::size_type num_rows;
  cudf::size_type null_count;

  static constexpr bool has_nulls{true};

  /**
   * @brief Return the bounds of the group, null range, and non-null range for a row.
   *
   * @param i The row index. Unused here since every row belongs to the same single group; it is
   * accepted so that the signature matches the grouped descriptors.
   * @return Tuple of `(null_count, group_start, group_end, null_start, null_end, non_null_start,
   * non_null_end)`. Each start/end pair is a half-open row range.
   */
  [[nodiscard]] __device__ constexpr cuda::std::
    tuple<size_type, size_type, size_type, size_type, size_type, size_type, size_type>
    row_info(size_type i) const noexcept
  {
    if (nulls_at_start) {
      // Layout: [0, null_count) nulls, then [null_count, num_rows) non-nulls.
      return {null_count, 0, num_rows, 0, null_count, null_count, num_rows};
    } else {
      // Layout: [0, num_rows - null_count) non-nulls, then [num_rows - null_count, num_rows) nulls.
      // The group itself still spans [0, num_rows): its bounds do not depend on null placement.
      return {null_count,
              0,
              num_rows,
              num_rows - null_count,
              num_rows,
              0,
              num_rows - null_count};
    }
  }
};

/**
 * @brief A group descriptor for a grouped rolling window with nulls
 *
 * Within each group the null rows are stored contiguously, either at the start or at the end of
 * the group.
 *
 * @param nulls_at_start Are the nulls at the start of each group?
 * @param labels The group labels, mapping from input rows to group.
 * @param offsets The group offsets providing the endpoints of each group.
 * @param null_counts The null counts per group.
 *
 * @note This is used for uniformity of interface between grouped and ungrouped
 * iterator construction.
 */
struct grouped_with_nulls {
  bool nulls_at_start;
  // Raw pointers rather than sized views: the sizes are never needed, so carrying them would
  // only steal three registers.
  cudf::size_type const* labels;
  cudf::size_type const* offsets;
  cudf::size_type const* null_counts;

  static constexpr bool has_nulls{true};

  /**
   * @brief Return the bounds of the group, null range, and non-null range for a row.
   *
   * @param i The row index, used to look up the row's group label.
   * @return Tuple of `(null_count, group_start, group_end, null_start, null_end, non_null_start,
   * non_null_end)` for the group containing row `i`. Each start/end pair is a half-open row
   * range.
   */
  [[nodiscard]] __device__ constexpr cuda::std::
    tuple<size_type, size_type, size_type, size_type, size_type, size_type, size_type>
    row_info(size_type i) const noexcept
  {
    auto const group   = labels[i];
    auto const first   = offsets[group];
    auto const last    = offsets[group + 1];
    auto const n_nulls = null_counts[group];

    // The nulls occupy one end of [first, last); the non-null rows fill the remainder.
    auto const null_first  = nulls_at_start ? first : last - n_nulls;
    auto const null_last   = nulls_at_start ? first + n_nulls : last;
    auto const valid_first = nulls_at_start ? null_last : first;
    auto const valid_last  = nulls_at_start ? last : null_first;

    return {n_nulls, first, last, null_first, null_last, valid_first, valid_last};
  }
};

/*
* Spark requires that orderby columns with floating point type have a
* total order on floats where all NaNs compare equal to one-another,
Expand Down Expand Up @@ -445,14 +356,8 @@ struct distance_functor {
DeltaT const* row_delta,
direction const direction,
order const order,
column_device_view::const_iterator<OrderbyT> begin,
column_device_view::const_iterator<OrderbyT> end)
: groups{groups},
row_delta{row_delta},
direction{direction},
order{order},
begin{begin},
end{end}
column_device_view::const_iterator<OrderbyT> begin)
: groups{groups}, row_delta{row_delta}, direction{direction}, order{order}, begin{begin}
{
}
Grouping groups; ///< Group information to determine bounds on current row's window
Expand All @@ -464,7 +369,6 @@ struct distance_functor {
DeltaT const* row_delta; ///< Delta from current row that defines the interval endpoint. This
///< pointer is null for UNBOUNDED and CURRENT_ROW windows.
column_device_view::const_iterator<OrderbyT> begin; ///< Iterator to beginning of orderby column
column_device_view::const_iterator<OrderbyT> end; ///< Iterator to end of orderby column
direction const direction;
order const order;

Expand All @@ -476,19 +380,24 @@ struct distance_functor {
*/
[[nodiscard]] __device__ size_type operator()(size_type i) const
{
using Comp = comparator_t<OrderbyT, WindowTag>;
auto const [null_count, group_start, group_end, null_start, null_end, start, end] =
groups.row_info(i);
using Comp = comparator_t<OrderbyT, WindowTag>;
auto const row_info = groups.row_info(i);
if (direction == direction::PRECEDING) {
if constexpr (cuda::std::is_same_v<WindowTag, unbounded>) { return i - group_start + 1; }
if constexpr (cuda::std::is_same_v<WindowTag, unbounded>) {
return i - row_info.group_start + 1;
}
// TODO: If the window is BOUNDED_OPEN, what does it mean for a row to fall in the null
// group? Not that important because only spark allows nulls in the orderby column, and it
// doesn't have BOUNDED_OPEN windows.
if (Grouping::has_nulls && i >= null_start && i < null_end) { return i - null_start + 1; }
if (Grouping::has_nulls && i >= row_info.null_start && i < row_info.null_end) {
return i - row_info.null_start + 1;
}
if constexpr (cuda::std::is_same_v<WindowTag, current_row>) {
return 1 + thrust::distance(thrust::lower_bound(
thrust::seq, begin + start, begin + i, begin[i], Comp{order}),
begin + i);
return 1 +
thrust::distance(
thrust::lower_bound(
thrust::seq, begin + row_info.non_null_start, begin + i, begin[i], Comp{order}),
begin + i);
} else if constexpr (!cuda::std::is_same_v<WindowTag, unbounded>) {
// The preceding endpoint is computed via row_value - delta.
// When delta is positive, this can only overflow towards -infinity.
Expand Down Expand Up @@ -520,36 +429,43 @@ struct distance_functor {
if (*row_delta > DeltaT{0}) {
return 1 + thrust::distance(
thrust::lower_bound(thrust::seq,
begin + start,
begin + end,
begin + row_info.non_null_start,
begin + row_info.non_null_end,
value,
comparator_t<OrderbyT, bounded_closed>{order}),
begin + i);
} else {
return 1 + thrust::distance(
thrust::lower_bound(thrust::seq,
begin + start,
begin + end,
begin + row_info.non_null_start,
begin + row_info.non_null_end,
value,
comparator_t<OrderbyT, bounded_open>{order}),
begin + i);
}
} else {
return 1 +
thrust::distance(
thrust::lower_bound(thrust::seq, begin + start, begin + end, value, Comp{order}),
begin + i);
return 1 + thrust::distance(thrust::lower_bound(thrust::seq,
begin + row_info.non_null_start,
begin + row_info.non_null_end,
value,
Comp{order}),
begin + i);
}
} else {
CUDF_UNREACHABLE("Unexpected WindowTag");
}
} else {
if constexpr (cuda::std::is_same_v<WindowTag, unbounded>) { return group_end - i - 1; }
if (Grouping::has_nulls && i >= null_start && i < null_end) { return null_end - i - 1; }
if constexpr (cuda::std::is_same_v<WindowTag, unbounded>) {
return row_info.group_end - i - 1;
}
if (Grouping::has_nulls && i >= row_info.null_start && i < row_info.null_end) {
return row_info.null_end - i - 1;
}
if constexpr (cuda::std::is_same_v<WindowTag, current_row>) {
return thrust::distance(
begin + i,
thrust::upper_bound(thrust::seq, begin + i, begin + end, begin[i], Comp{order})) -
thrust::upper_bound(
thrust::seq, begin + i, begin + row_info.non_null_end, begin[i], Comp{order})) -
1;
} else if constexpr (!cuda::std::is_same_v<WindowTag, unbounded>) {
// The following endpoint is computed via row_value + delta.
Expand Down Expand Up @@ -583,25 +499,28 @@ struct distance_functor {
return thrust::distance(
begin + i,
thrust::upper_bound(thrust::seq,
begin + start,
begin + end,
begin + row_info.non_null_start,
begin + row_info.non_null_end,
value,
comparator_t<OrderbyT, bounded_closed>{order})) -
1;
} else {
return thrust::distance(
begin + i,
thrust::upper_bound(thrust::seq,
begin + start,
begin + end,
begin + row_info.non_null_start,
begin + row_info.non_null_end,
value,
comparator_t<OrderbyT, bounded_open>{order})) -
1;
}
} else {
return thrust::distance(begin + i,
thrust::upper_bound(
thrust::seq, begin + start, begin + end, value, Comp{order})) -
thrust::upper_bound(thrust::seq,
begin + row_info.non_null_start,
begin + row_info.non_null_end,
value,
Comp{order})) -
1;
}
} else {
Expand Down Expand Up @@ -656,7 +575,6 @@ struct range_window_clamper {
data_type(type_to_id<size_type>()), orderby.size(), mask_state::UNALLOCATED, stream, mr);
auto d_orderby = column_device_view::create(orderby, stream);
auto d_begin = d_orderby->begin<OrderbyT>();
auto d_end = d_orderby->end<OrderbyT>();
auto const* d_row_delta = row_delta ? dynamic_cast<ScalarT const*>(row_delta)->data() : nullptr;
using DeltaT = cuda::std::remove_cv_t<cuda::std::remove_pointer_t<decltype(d_row_delta)>>;
auto copy_n = [&](auto&& grouping) {
Expand All @@ -665,7 +583,7 @@ struct range_window_clamper {
cudf::detail::make_counting_transform_iterator(
0,
distance_functor<WindowTag, Grouping, OrderbyT, DeltaT>{
grouping, d_row_delta, direction, order, d_begin, d_end}),
grouping, d_row_delta, direction, order, d_begin}),
orderby.size(),
result->mutable_view().begin<size_type>());
};
Expand Down
Loading

0 comments on commit a3eab98

Please sign in to comment.