Skip to content

Commit

Permalink
Refactor a bit to reduce specialisations
Browse files Browse the repository at this point in the history
  • Loading branch information
wence- committed Feb 27, 2025
1 parent e1e535e commit de42bc2
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 24 deletions.
74 changes: 56 additions & 18 deletions cpp/src/rolling/detail/range_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ struct current_row_distance_functor {
*
* See `saturating_op` for details of the implementation of saturating addition/subtraction.
*/
template <typename WindowType, typename Grouping, typename OrderbyT, typename DeltaT>
template <typename Grouping, typename OrderbyT, typename DeltaT, typename WindowType>
struct bounded_distance_functor {
static_assert(cuda::std::is_same_v<WindowType, bounded_open> ||
cuda::std::is_same_v<WindowType, bounded_closed>,
Expand Down Expand Up @@ -472,6 +472,55 @@ struct range_window_clamper {
cuda::std::is_same_v<WindowType, bounded_closed> ||
cuda::std::is_same_v<WindowType, bounded_open>,
"Invalid WindowType descriptor");
template <typename Grouping>
void expand_unbounded(Grouping grouping,
direction direction,
size_type size,
mutable_column_view& result,
rmm::cuda_stream_view stream) const
{
thrust::copy_n(rmm::exec_policy_nosync(stream),
cudf::detail::make_counting_transform_iterator(
0, unbounded_distance_functor{grouping, direction}),
size,
result.begin<size_type>());
}

template <typename Grouping, typename OrderbyT>
void expand_current_row(Grouping grouping,
column_device_view::const_iterator<OrderbyT> begin,
direction direction,
order order,
size_type size,
mutable_column_view& result,
rmm::cuda_stream_view stream) const
{
thrust::copy_n(rmm::exec_policy_nosync(stream),
cudf::detail::make_counting_transform_iterator(
0, current_row_distance_functor{grouping, direction, order, begin}),
size,
result.begin<size_type>());
}

template <typename Grouping, typename OrderbyT, typename DeltaT>
void expand_bounded(Grouping grouping,
direction direction,
order order,
column_device_view::const_iterator<OrderbyT> begin,
DeltaT const* row_delta,
size_type size,
mutable_column_view& result,
rmm::cuda_stream_view stream) const
{
thrust::copy_n(rmm::exec_policy_nosync(stream),
cudf::detail::make_counting_transform_iterator(
0,
bounded_distance_functor<Grouping, OrderbyT, DeltaT, WindowType>{
grouping, direction, order, begin, row_delta}),
size,
result.begin<size_type>());
}

/**
* @brief Compute the window bounds (possibly grouped) for an orderby column.
*
Expand Down Expand Up @@ -504,27 +553,16 @@ struct range_window_clamper {
auto d_orderby = column_device_view::create(orderby, stream);
auto d_begin = d_orderby->begin<OrderbyT>();
auto expand = [&](auto&& grouping) {
using Grouping = cuda::std::decay_t<decltype(grouping)>;
static_assert(cuda::std::is_same_v<Grouping, ungrouped> ||
cuda::std::is_same_v<Grouping, grouped> ||
cuda::std::is_same_v<Grouping, ungrouped_with_nulls> ||
cuda::std::is_same_v<Grouping, grouped_with_nulls>,
"Invalid grouping descriptor");
auto copy_n = [&](auto&& functor) {
thrust::copy_n(rmm::exec_policy_nosync(stream),
cudf::detail::make_counting_transform_iterator(0, functor),
orderby.size(),
result->mutable_view().begin<size_type>());
};
auto result_view = result->mutable_view();
if constexpr (cuda::std::is_same_v<WindowType, unbounded>) {
copy_n(unbounded_distance_functor{grouping, direction});
expand_unbounded(grouping, direction, orderby.size(), result_view, stream);
} else if constexpr (cuda::std::is_same_v<WindowType, current_row>) {
copy_n(current_row_distance_functor{grouping, direction, order, d_begin});
expand_current_row(
grouping, d_begin, direction, order, orderby.size(), result_view, stream);
} else {
auto const* d_row_delta = dynamic_cast<ScalarT const*>(row_delta)->data();
using DeltaT = cuda::std::remove_cv_t<cuda::std::remove_pointer_t<decltype(d_row_delta)>>;
copy_n(bounded_distance_functor<WindowType, Grouping, OrderbyT, DeltaT>{
grouping, direction, order, d_begin, d_row_delta});
expand_bounded(
grouping, direction, order, d_begin, d_row_delta, orderby.size(), result_view, stream);
}
};

Expand Down
12 changes: 6 additions & 6 deletions cpp/src/rolling/detail/rolling_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@ namespace detail::rolling {
* @brief Information about group bounds of the current row's group.
*/
struct range_group_info {
size_type group_start;
size_type group_end;
size_type null_start;
size_type null_end;
size_type non_null_start;
size_type non_null_end;
size_type const group_start;
size_type const group_end;
size_type const null_start;
size_type const null_end;
size_type const non_null_start;
size_type const non_null_end;
};

/**
Expand Down

0 comments on commit de42bc2

Please sign in to comment.