diff --git a/cpp/src/rolling/detail/range_utils.cuh b/cpp/src/rolling/detail/range_utils.cuh index df5c125fecc..01fb60bac55 100644 --- a/cpp/src/rolling/detail/range_utils.cuh +++ b/cpp/src/rolling/detail/range_utils.cuh @@ -325,10 +325,74 @@ template } } } + +/** @brief Functor mapping a row index to its distance from an edge of the group reported by `groups.row_info(i)`. For `direction::PRECEDING` the result is `i - group_start + 1`; for any other direction it is `group_end - i - 1`. Neither the orderby values nor any null range are consulted. NOTE(review): the name suggests this serves unbounded windows, and `group_end` is presumably one past the group's last row; confirm both against the grouping types. */ template +struct unbounded_distance_functor { + /** @brief Stores the grouping object and the direction by value. */ unbounded_distance_functor(Grouping groups, direction direction) + : groups{groups}, direction{direction} + { + } + /* Queried once per call through row_info(i). */ Grouping const groups; + /* Selects which of the two formulas below is returned. */ direction const direction; + /** @brief Returns the distance for row `i`; device-only and noexcept. */ [[nodiscard]] __device__ size_type operator()(size_type i) const noexcept + { + auto const row_info = groups.row_info(i); + if (direction == direction::PRECEDING) { + /* Counts row i itself (the +1). */ return i - row_info.group_start + 1; + } else { + /* Does not count row i (the -1). */ return row_info.group_end - i - 1; + } + } +}; + +/** @brief Functor mapping a row index to a distance obtained by searching the orderby values for the row's own value `begin[i]`. Rows inside `[null_start, null_end)` (tested only when `Grouping::has_nulls` is true) are measured against the bounds of that null range; all other rows are measured with `thrust::lower_bound` (PRECEDING) or `thrust::upper_bound` (any other direction) over the group's non-null rows, compared with `Comp{order}`. NOTE(review): the binary searches presume the non-null rows of a group are ordered consistently with `Comp{order}`; confirm with the caller. */ template +struct current_row_distance_functor { + /** @brief Stores the grouping object, direction, sort order and orderby iterator by value. */ current_row_distance_functor(Grouping groups, + direction direction, + order order, + column_device_view::const_iterator begin) + : groups{groups}, direction{direction}, order{order}, begin{begin} + { + } + /* Queried once per call through row_info(i). */ Grouping const groups; + /* Selects the PRECEDING branch or the other branch in operator(). */ direction const direction; + /* Passed to the comparator used by the binary searches. */ order const order; + /* Iterator to the orderby values; indexed with absolute row indices. */ column_device_view::const_iterator begin; + + /** @brief Returns the distance for row `i`; device-only and noexcept. */ [[nodiscard]] __device__ size_type operator()(size_type i) const noexcept + { + using Comp = comparator_t; + auto const row_info = groups.row_info(i); + /* When Grouping::has_nulls is false the && short-circuits and the range test is never evaluated. */ bool const is_null = Grouping::has_nulls && i >= row_info.null_start && i < row_info.null_end; + if (direction == direction::PRECEDING) { + if (is_null) { + /* Distance back to the start of the null range, counting row i. */ return i - row_info.null_start + 1; + } else { + /* 1 + (i - position of the lower bound of begin[i] within [non_null_start, i)); thrust::seq selects Thrust's sequential execution policy. */ return 1 + + thrust::distance( + thrust::lower_bound( + thrust::seq, begin + row_info.non_null_start, begin + i, begin[i], Comp{order}), + begin + i); + } + } else { + if (is_null) { + /* Distance forward to the end of the null range, not counting row i. */ return row_info.null_end - i - 1; + } else { + /* (position of the upper bound of begin[i] within [i, non_null_end) - i) - 1, so row i itself is not counted. */ return thrust::distance( + begin + i, + thrust::upper_bound( + thrust::seq, begin + i, begin + row_info.non_null_end, begin[i], Comp{order})) - + 1; + } + } + } +}; + /** * @brief Functor to compute distance from a given row to the edge * of the window. 
* + * @tparam WindowTag The type of window we're computing the distance for. * @tparam Grouping Object defining how the orderby column is * grouped. * @tparam OrderbyT Type of elements in the orderby columns. @@ -351,12 +415,12 @@ template * saturating addition/subtraction. */ template -struct distance_functor { - distance_functor(Grouping groups, - DeltaT const* row_delta, - direction const direction, - order const order, - column_device_view::const_iterator begin) +struct bounded_distance_functor { + bounded_distance_functor(Grouping groups, + DeltaT const* row_delta, + direction const direction, + order const order, + column_device_view::const_iterator begin) : groups{groups}, row_delta{row_delta}, direction{direction}, order{order}, begin{begin} { } @@ -382,23 +446,14 @@ struct distance_functor { { using Comp = comparator_t; auto const row_info = groups.row_info(i); + bool const is_null = Grouping::has_nulls && i >= row_info.null_start && i < row_info.null_end; if (direction == direction::PRECEDING) { - if constexpr (cuda::std::is_same_v) { - return i - row_info.group_start + 1; - } // TODO: If the window is BOUNDED_OPEN, what does it mean for a row to fall in the null // group? Not that important because only spark allows nulls in the orderby column, and it // doesn't have BOUNDED_OPEN windows. - if (Grouping::has_nulls && i >= row_info.null_start && i < row_info.null_end) { + if (is_null) { return i - row_info.null_start + 1; - } - if constexpr (cuda::std::is_same_v) { - return 1 + - thrust::distance( - thrust::lower_bound( - thrust::seq, begin + row_info.non_null_start, begin + i, begin[i], Comp{order}), - begin + i); - } else if constexpr (!cuda::std::is_same_v) { + } else { // The preceding endpoint is computed via row_value - delta. // When delta is positive, this can only overflow towards -infinity. 
// If we did overflow towards -infinity, then the value @@ -451,23 +506,11 @@ struct distance_functor { Comp{order}), begin + i); } - } else { - CUDF_UNREACHABLE("Unexpected WindowTag"); } } else { - if constexpr (cuda::std::is_same_v) { - return row_info.group_end - i - 1; - } - if (Grouping::has_nulls && i >= row_info.null_start && i < row_info.null_end) { + if (is_null) { return row_info.null_end - i - 1; - } - if constexpr (cuda::std::is_same_v) { - return thrust::distance( - begin + i, - thrust::upper_bound( - thrust::seq, begin + i, begin + row_info.non_null_end, begin[i], Comp{order})) - - 1; - } else if constexpr (!cuda::std::is_same_v) { + } else { // The following endpoint is computed via row_value + delta. // When delta is positive, this can only overflow towards +infinity. // If we did overflow towards +infinity, then the value @@ -523,8 +566,6 @@ struct distance_functor { Comp{order})) - 1; } - } else { - CUDF_UNREACHABLE("Unexpected WindowTag"); } } } @@ -579,14 +620,31 @@ struct range_window_clamper { using DeltaT = cuda::std::remove_cv_t>; auto copy_n = [&](auto&& grouping) { using Grouping = cuda::std::decay_t; - thrust::copy_n(rmm::exec_policy_nosync(stream), - cudf::detail::make_counting_transform_iterator( - 0, - distance_functor{ - grouping, d_row_delta, direction, order, d_begin}), - orderby.size(), - result->mutable_view().begin()); + if constexpr (cuda::std::is_same_v) { + thrust::copy_n(rmm::exec_policy_nosync(stream), + cudf::detail::make_counting_transform_iterator( + 0, unbounded_distance_functor{grouping, direction}), + orderby.size(), + result->mutable_view().begin()); + } else if constexpr (cuda::std::is_same_v) { + thrust::copy_n( + rmm::exec_policy_nosync(stream), + cudf::detail::make_counting_transform_iterator( + 0, + current_row_distance_functor{grouping, direction, order, d_begin}), + orderby.size(), + result->mutable_view().begin()); + } else { + thrust::copy_n(rmm::exec_policy_nosync(stream), + 
cudf::detail::make_counting_transform_iterator( + 0, + bounded_distance_functor{ + grouping, d_row_delta, direction, order, d_begin}), + orderby.size(), + result->mutable_view().begin()); + } }; + if (grouping.has_value()) { if (orderby.has_nulls()) { copy_n(grouped_with_nulls{nulls_at_start,