From 206e7cb7287b2f0cbd5d74a2cadee355470c449d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20Gr=C3=BCtzmacher?= Date: Wed, 10 Jul 2024 17:52:20 +0200 Subject: [PATCH] Add additional tests for the communicator group --- cuda/test/components/cooperative_groups.cu | 37 +++++++++++++++++ .../test/components/cooperative_groups.dp.cpp | 40 +++++++++++++++++++ .../components/cooperative_groups.hip.cpp | 36 +++++++++++++++++ 3 files changed, 113 insertions(+) diff --git a/cuda/test/components/cooperative_groups.cu b/cuda/test/components/cooperative_groups.cu index df3cef86bb8..077b0121fbd 100644 --- a/cuda/test/components/cooperative_groups.cu +++ b/cuda/test/components/cooperative_groups.cu @@ -223,4 +223,41 @@ TEST_F(CooperativeGroups, SubwarpBallot) { test(cg_subwarp_ballot); } TEST_F(CooperativeGroups, SubwarpBallot2) { test_subwarp(cg_subwarp_ballot); } +__global__ void cg_communicator_categorization(bool*) +{ + auto this_block = group::this_thread_block(); + auto tiled_partition = + group::tiled_partition(this_block); + auto subwarp_partition = group::tiled_partition(this_block); + + using not_group = int; + using this_block_t = decltype(this_block); + using tiled_partition_t = decltype(tiled_partition); + using subwarp_partition_t = decltype(subwarp_partition); + + static_assert(!group::is_group::value && + group::is_group::value && + group::is_group::value && + group::is_group::value, + "Group check doesn't work."); + static_assert( + !group::is_synchronizable_group::value && + group::is_synchronizable_group::value && + group::is_synchronizable_group::value && + group::is_synchronizable_group::value, + "Synchronizable group check doesn't work."); + static_assert( + !group::is_communicator_group::value && + !group::is_communicator_group::value && + group::is_communicator_group::value && + group::is_communicator_group::value, + "Communicator group check doesn't work."); +} + +TEST_F(CooperativeGroups, CorrectCategorization) +{ + test(cg_communicator_categorization); +} + + } // namespace diff --git a/dpcpp/test/components/cooperative_groups.dp.cpp b/dpcpp/test/components/cooperative_groups.dp.cpp index 27e14b62d2d..5e7ed495504 100644 --- a/dpcpp/test/components/cooperative_groups.dp.cpp +++ b/dpcpp/test/components/cooperative_groups.dp.cpp @@ -198,6 +198,46 @@ GKO_ENABLE_DEFAULT_CONFIG_CALL(cg_ballot_call, cg_ballot, default_config_list) TEST_P(CooperativeGroups, Ballot) { test_all_subgroup(cg_ballot_call); } +template +void cg_communicator_categorization(bool* s, sycl::nd_item<3> item_ct1) +{ + auto this_block = group::this_thread_block(item_ct1); + auto tiled_partition = + group::tiled_partition(this_block); + + using this_block_t = decltype(this_block); + using tiled_partition_t = decltype(tiled_partition); + + static_assert(!group::is_group::value && + group::is_group::value && + group::is_group::value, + "Group check doesn't work."); + static_assert( + !group::is_synchronizable_group::value && + group::is_synchronizable_group::value && + group::is_synchronizable_group::value, + "Synchronizable group check doesn't work."); + static_assert( + !group::is_communicator_group::value && + !group::is_communicator_group::value && + group::is_communicator_group::value, + "Communicator group check doesn't work."); +} + +GKO_ENABLE_DEFAULT_HOST_CONFIG_TYPE(cg_communicator_categorization, + cg_communicator_categorization) +GKO_ENABLE_IMPLEMENTATION_CONFIG_SELECTION_TOTYPE( + cg_communicator_categorization, cg_communicator_categorization, DCFG_1D) +GKO_ENABLE_DEFAULT_CONFIG_CALL(cg_communicator_categorization_call, + cg_communicator_categorization, + default_config_list) + +TEST_P(CooperativeGroups, CorrectCategorization) +{ + test_all_subgroup(cg_communicator_categorization_call); +} + + INSTANTIATE_TEST_SUITE_P(DifferentSubgroup, CooperativeGroups, testing::Values(4, 8, 16, 32, 64), testing::PrintToStringParamName()); diff --git a/hip/test/components/cooperative_groups.hip.cpp b/hip/test/components/cooperative_groups.hip.cpp index 06a104a8879..c36965acbf8 100644 --- a/hip/test/components/cooperative_groups.hip.cpp +++ b/hip/test/components/cooperative_groups.hip.cpp @@ -242,6 +242,42 @@ TEST_F(CooperativeGroups, SubwarpBallot) { test(cg_subwarp_ballot); } TEST_F(CooperativeGroups, SubwarpBallot2) { test_subwarp(cg_subwarp_ballot); } +__global__ void cg_communicator_categorization(bool*) +{ + auto this_block = group::this_thread_block(); + auto tiled_partition = + group::tiled_partition(this_block); + auto subwarp_partition = group::tiled_partition(this_block); + + using this_block_t = decltype(this_block); + using tiled_partition_t = decltype(tiled_partition); + using subwarp_partition_t = decltype(subwarp_partition); + + static_assert(!group::is_group::value && + group::is_group::value && + group::is_group::value && + group::is_group::value, + "Group check doesn't work."); + static_assert( + !group::is_synchronizable_group::value && + group::is_synchronizable_group::value && + group::is_synchronizable_group::value && + group::is_synchronizable_group::value, + "Synchronizable group check doesn't work."); + static_assert( + !group::is_communicator_group::value && + !group::is_communicator_group::value && + group::is_communicator_group::value && + group::is_communicator_group::value, + "Communicator group check doesn't work."); +} + +TEST_F(CooperativeGroups, CorrectCategorization) +{ + test(cg_communicator_categorization); +} + + template __global__ void cg_shuffle_sum(const int num, ValueType* __restrict__ value) {