Skip to content

Commit

Permalink
Backport to 2.8: Deprecate cub::Trait::CATEGORY|PRIMITIVE|NULL_TYPE (#…
Browse files Browse the repository at this point in the history
…3689) (#3703)

* Deprecate cub::Trait::CATEGORY|PRIMITIVE|NULL_TYPE (#3689)

* Fix FP type detection

---------

Co-authored-by: Michael Schellenberger Costa <[email protected]>
  • Loading branch information
bernhardmgruber and miscco authored Feb 6, 2025
1 parent 6b9b174 commit 23c395b
Show file tree
Hide file tree
Showing 26 changed files with 143 additions and 63 deletions.
2 changes: 2 additions & 0 deletions c2h/include/c2h/bfloat16.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -264,10 +264,12 @@ public:
};
_LIBCUDACXX_END_NAMESPACE_STD

_CCCL_SUPPRESS_DEPRECATED_PUSH
template <>
struct CUB_NS_QUALIFIER::NumericTraits<bfloat16_t>
: CUB_NS_QUALIFIER::BaseTraits<FLOATING_POINT, true, false, unsigned short, bfloat16_t>
{};
_CCCL_SUPPRESS_DEPRECATED_POP

#ifdef __GNUC__
# pragma GCC diagnostic pop
Expand Down
2 changes: 2 additions & 0 deletions c2h/include/c2h/half.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -359,10 +359,12 @@ public:
};
_LIBCUDACXX_END_NAMESPACE_STD

_CCCL_SUPPRESS_DEPRECATED_PUSH
template <>
struct CUB_NS_QUALIFIER::NumericTraits<half_t>
: CUB_NS_QUALIFIER::BaseTraits<FLOATING_POINT, true, false, unsigned short, half_t>
{};
_CCCL_SUPPRESS_DEPRECATED_POP

#ifdef __GNUC__
# pragma GCC diagnostic pop
Expand Down
2 changes: 2 additions & 0 deletions c2h/include/c2h/test_util_vec.h
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@ C2H_VEC_OVERLOAD(double, double)
/**
* Define for types
*/
_CCCL_SUPPRESS_DEPRECATED_PUSH
C2H_VEC_TRAITS_OVERLOAD(char, signed char)
C2H_VEC_TRAITS_OVERLOAD(short, short)
C2H_VEC_TRAITS_OVERLOAD(int, int)
Expand All @@ -430,5 +431,6 @@ C2H_VEC_TRAITS_OVERLOAD(ulong, unsigned long)
C2H_VEC_TRAITS_OVERLOAD(ulonglong, unsigned long long)
C2H_VEC_TRAITS_OVERLOAD(float, float)
C2H_VEC_TRAITS_OVERLOAD(double, double)
_CCCL_SUPPRESS_DEPRECATED_POP

#endif // THRUST_DEVICE_SYSTEM == THRUST_DEVICE_SYSTEM_CUDA
2 changes: 1 addition & 1 deletion cub/cub/agent/agent_reduce.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ struct AgentReduce
// pointer to a primitive type
static constexpr bool ATTEMPT_VECTORIZATION =
(VECTOR_LOAD_LENGTH > 1) && (ITEMS_PER_THREAD % VECTOR_LOAD_LENGTH == 0)
&& (::cuda::std::is_pointer<InputIteratorT>::value) && Traits<InputT>::PRIMITIVE;
&& (::cuda::std::is_pointer<InputIteratorT>::value) && is_primitive<InputT>::value;

static constexpr CacheLoadModifier LOAD_MODIFIER = AgentReducePolicy::LOAD_MODIFIER;

Expand Down
2 changes: 1 addition & 1 deletion cub/cub/agent/agent_reduce_by_key.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ struct AgentReduceByKey
// Whether or not the scan operation has a zero-valued identity value (true
// if we're performing addition on a primitive type)
static constexpr int HAS_IDENTITY_ZERO =
(std::is_same<ReductionOpT, ::cuda::std::plus<>>::value) && (Traits<AccumT>::PRIMITIVE);
(std::is_same<ReductionOpT, ::cuda::std::plus<>>::value) && (is_primitive<AccumT>::value);

// Cache-modified Input iterator wrapper type (for applying cache modifier)
// for keys Wrap the native input pointer with
Expand Down
8 changes: 4 additions & 4 deletions cub/cub/agent/single_pass_scan_operators.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -488,14 +488,14 @@ using default_no_delay_t = default_no_delay_constructor_t::delay_t;

template <class T>
using default_delay_constructor_t =
::cuda::std::_If<Traits<T>::PRIMITIVE, fixed_delay_constructor_t<350, 450>, default_no_delay_constructor_t>;
::cuda::std::_If<is_primitive<T>::value, fixed_delay_constructor_t<350, 450>, default_no_delay_constructor_t>;

template <class T>
using default_delay_t = typename default_delay_constructor_t<T>::delay_t;

template <class KeyT, class ValueT>
using default_reduce_by_key_delay_constructor_t =
::cuda::std::_If<(Traits<ValueT>::PRIMITIVE) && (sizeof(ValueT) + sizeof(KeyT) < 16),
::cuda::std::_If<is_primitive<ValueT>::value && (sizeof(ValueT) + sizeof(KeyT) < 16),
reduce_by_key_delay_constructor_t<350, 450>,
default_delay_constructor_t<KeyValuePair<KeyT, ValueT>>>;

Expand Down Expand Up @@ -547,7 +547,7 @@ struct tile_state_with_memory_order
/**
* Tile status interface.
*/
template <typename T, bool SINGLE_WORD = Traits<T>::PRIMITIVE>
template <typename T, bool SINGLE_WORD = detail::is_primitive<T>::value>
struct ScanTileState;

/**
Expand Down Expand Up @@ -952,7 +952,7 @@ struct ScanTileState<T, false>
*/
template <typename ValueT,
typename KeyT,
bool SINGLE_WORD = (Traits<ValueT>::PRIMITIVE) && (sizeof(ValueT) + sizeof(KeyT) < 16)>
bool SINGLE_WORD = detail::is_primitive<ValueT>::value && (sizeof(ValueT) + sizeof(KeyT) < 16)>
struct ReduceByKeyScanTileState;

/**
Expand Down
15 changes: 13 additions & 2 deletions cub/cub/block/radix_rank_sort_operations.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
#include <cuda/std/cstdint>
#include <cuda/std/tuple>
#include <cuda/std/type_traits>
#include <cuda/type_traits>

CUB_NAMESPACE_BEGIN

Expand All @@ -73,9 +74,14 @@ CUB_NAMESPACE_BEGIN
and only one of them is used, the sorting works correctly. For double, the
same applies, but with 64-bit patterns.
*/
template <typename KeyT, Category TypeCategory = Traits<KeyT>::CATEGORY>
template <typename KeyT, bool IsFP = ::cuda::is_floating_point<KeyT>::value>
struct BaseDigitExtractor
{
// TODO(bgruber): sanity check, remove eventually
_CCCL_SUPPRESS_DEPRECATED_PUSH
static_assert(Traits<KeyT>::CATEGORY != FLOATING_POINT, "");
_CCCL_SUPPRESS_DEPRECATED_POP

using TraitsT = Traits<KeyT>;
using UnsignedBits = typename TraitsT::UnsignedBits;

Expand All @@ -86,8 +92,13 @@ struct BaseDigitExtractor
};

template <typename KeyT>
struct BaseDigitExtractor<KeyT, FLOATING_POINT>
struct BaseDigitExtractor<KeyT, true>
{
// TODO(bgruber): sanity check, remove eventually
_CCCL_SUPPRESS_DEPRECATED_PUSH
static_assert(Traits<KeyT>::CATEGORY == FLOATING_POINT, "");
_CCCL_SUPPRESS_DEPRECATED_POP

using TraitsT = Traits<KeyT>;
using UnsignedBits = typename TraitsT::UnsignedBits;

Expand Down
15 changes: 8 additions & 7 deletions cub/cub/device/dispatch/dispatch_transform.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,6 @@
# pragma system_header
#endif // no system header

#if _CCCL_CUDACC_BELOW(11, 5)
_CCCL_NV_DIAG_SUPPRESS(186)
# include <cuda_pipeline_primitives.h>
// we cannot re-enable the warning here, because it is triggered outside the translation unit
// see also: https://godbolt.org/z/1x8b4hn3G
#endif // _CCCL_CUDACC_BELOW(11, 5)

#include <cub/detail/uninitialized_copy.cuh>
#include <cub/device/dispatch/tuning/tuning_transform.cuh>
#include <cub/util_arch.cuh>
Expand Down Expand Up @@ -866,3 +859,11 @@ struct dispatch_t<RequiresStableAddress,
} // namespace transform
} // namespace detail
CUB_NAMESPACE_END

#if _CCCL_CUDACC_BELOW(11, 5)
// we need to suppress this warning which is generated outside:
// `cuda_pipeline_helpers.h(156): error #186-D: pointless comparison of unsigned integer with zero`
_CCCL_NV_DIAG_SUPPRESS(186)
// we cannot re-enable the warning anywhere, because it is triggered outside the translation unit
// see also: https://godbolt.org/z/1x8b4hn3G
#endif // _CCCL_CUDACC_BELOW(11, 5)
2 changes: 1 addition & 1 deletion cub/cub/device/dispatch/tuning/tuning_histogram.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ enum class counter_size
template <class T>
constexpr primitive_sample is_primitive_sample()
{
return Traits<T>::PRIMITIVE ? primitive_sample::yes : primitive_sample::no;
return detail::is_primitive<T>::value ? primitive_sample::yes : primitive_sample::no;
}

template <class CounterT>
Expand Down
4 changes: 2 additions & 2 deletions cub/cub/device/dispatch/tuning/tuning_reduce_by_key.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,13 @@ enum class accum_size
template <class T>
constexpr primitive_key is_primitive_key()
{
return Traits<T>::PRIMITIVE ? primitive_key::yes : primitive_key::no;
return detail::is_primitive<T>::value ? primitive_key::yes : primitive_key::no;
}

template <class T>
constexpr primitive_accum is_primitive_accum()
{
return Traits<T>::PRIMITIVE ? primitive_accum::yes : primitive_accum::no;
return detail::is_primitive<T>::value ? primitive_accum::yes : primitive_accum::no;
}

template <class ReductionOpT>
Expand Down
4 changes: 2 additions & 2 deletions cub/cub/device/dispatch/tuning/tuning_run_length_encode.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,13 @@ enum class length_size
template <class T>
constexpr primitive_key is_primitive_key()
{
return Traits<T>::PRIMITIVE ? primitive_key::yes : primitive_key::no;
return detail::is_primitive<T>::value ? primitive_key::yes : primitive_key::no;
}

template <class T>
constexpr primitive_length is_primitive_length()
{
return Traits<T>::PRIMITIVE ? primitive_length::yes : primitive_length::no;
return detail::is_primitive<T>::value ? primitive_length::yes : primitive_length::no;
}

template <class KeyT>
Expand Down
2 changes: 1 addition & 1 deletion cub/cub/device/dispatch/tuning/tuning_scan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ enum class accum_size
template <class AccumT>
constexpr primitive_accum is_primitive_accum()
{
return Traits<AccumT>::PRIMITIVE ? primitive_accum::yes : primitive_accum::no;
return is_primitive<AccumT>::value ? primitive_accum::yes : primitive_accum::no;
}

template <class ScanOpT>
Expand Down
2 changes: 1 addition & 1 deletion cub/cub/device/dispatch/tuning/tuning_scan_by_key.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ enum class key_size
template <class AccumT>
constexpr primitive_accum is_primitive_accum()
{
return Traits<AccumT>::PRIMITIVE ? primitive_accum::yes : primitive_accum::no;
return detail::is_primitive<AccumT>::value ? primitive_accum::yes : primitive_accum::no;
}

template <class ScanOpT>
Expand Down
2 changes: 1 addition & 1 deletion cub/cub/device/dispatch/tuning/tuning_select_if.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,7 @@ struct sm100_tuning<Input, flagged::yes, keep_rejects::no, offset_size::_4, prim
template <class InputT>
constexpr primitive is_primitive()
{
return Traits<InputT>::PRIMITIVE ? primitive::yes : primitive::no;
return detail::is_primitive<InputT>::value ? primitive::yes : primitive::no;
}

template <class FlagT>
Expand Down
4 changes: 2 additions & 2 deletions cub/cub/device/dispatch/tuning/tuning_unique_by_key.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,13 @@ enum class val_size
template <class T>
constexpr primitive_key is_primitive_key()
{
return Traits<T>::PRIMITIVE ? primitive_key::yes : primitive_key::no;
return detail::is_primitive<T>::value ? primitive_key::yes : primitive_key::no;
}

template <class T>
constexpr primitive_val is_primitive_val()
{
return Traits<T>::PRIMITIVE ? primitive_val::yes : primitive_val::no;
return detail::is_primitive<T>::value ? primitive_val::yes : primitive_val::no;
}

template <class KeyT>
Expand Down
2 changes: 1 addition & 1 deletion cub/cub/thread/thread_load.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ template <typename T>
_CCCL_DEVICE _CCCL_FORCEINLINE T
ThreadLoad(const T* ptr, detail::int_constant_t<LOAD_VOLATILE> /*modifier*/, ::cuda::std::true_type /*is_pointer*/)
{
return ThreadLoadVolatilePointer(ptr, ::cuda::std::bool_constant<Traits<T>::PRIMITIVE>());
return ThreadLoadVolatilePointer(ptr, ::cuda::std::bool_constant<detail::is_primitive<T>::value>());
}

/**
Expand Down
2 changes: 1 addition & 1 deletion cub/cub/thread/thread_store.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ template <typename T>
_CCCL_DEVICE _CCCL_FORCEINLINE void
ThreadStore(T* ptr, T val, detail::int_constant_t<STORE_VOLATILE> /*modifier*/, ::cuda::std::true_type /*is_pointer*/)
{
ThreadStoreVolatilePtr(ptr, val, ::cuda::std::bool_constant<Traits<T>::PRIMITIVE>());
ThreadStoreVolatilePtr(ptr, val, ::cuda::std::bool_constant<detail::is_primitive<T>::value>());
}

/**
Expand Down
Loading

0 comments on commit 23c395b

Please sign in to comment.