Skip to content

Commit

Permalink
Add additional tests for the communicator group
Browse files Browse the repository at this point in the history
  • Loading branch information
thoasm committed Jul 10, 2024
1 parent caa373d commit 48f1ecc
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 0 deletions.
37 changes: 37 additions & 0 deletions cuda/test/components/cooperative_groups.cu
Original file line number Diff line number Diff line change
Expand Up @@ -223,4 +223,41 @@ TEST_F(CooperativeGroups, SubwarpBallot) { test(cg_subwarp_ballot); }
TEST_F(CooperativeGroups, SubwarpBallot2) { test_subwarp(cg_subwarp_ballot); }
__global__ void cg_communicator_categorization(bool*)
{
auto this_block = group::this_thread_block();
auto tiled_partition =
group::tiled_partition<config::warp_size>(this_block);
auto subwarp_partition = group::tiled_partition<subwarp_size>(this_block);
using not_group = int;
using this_block_t = decltype(this_block);
using tiled_partition_t = decltype(tiled_partition);
using subwarp_partition_t = decltype(subwarp_partition);
static_assert(!group::is_group<not_group>::value &&
group::is_group<this_block_t>::value &&
group::is_group<tiled_partition_t>::value &&
group::is_group<subwarp_partition_t>::value,
"Group check doesn't work.");
static_assert(
!group::is_synchronizable_group<not_group>::value &&
group::is_synchronizable_group<this_block_t>::value &&
group::is_synchronizable_group<tiled_partition_t>::value &&
group::is_synchronizable_group<subwarp_partition_t>::value,
"Synchronizable group check doesn't work.");
static_assert(
!group::is_communicator_group<not_group>::value &&
!group::is_communicator_group<this_block_t>::value &&
group::is_communicator_group<tiled_partition_t>::value &&
group::is_communicator_group<subwarp_partition_t>::value,
"Communicator group check doesn't work.");
}
TEST_F(CooperativeGroups, CorrectCategorization)
{
test(cg_communicator_categorization);
}
} // namespace
41 changes: 41 additions & 0 deletions dpcpp/test/components/cooperative_groups.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,47 @@ GKO_ENABLE_DEFAULT_CONFIG_CALL(cg_ballot_call, cg_ballot, default_config_list)
TEST_P(CooperativeGroups, Ballot) { test_all_subgroup(cg_ballot_call<bool*>); }


template <typename cfg>
void cg_communicator_categorization(bool* s, sycl::nd_item<3> item_ct1)
{
auto this_block = group::this_thread_block(item_ct1);
auto tiled_partition =
group::tiled_partition<cfg::subgroup_size>(this_block);

using not_group = int;
using this_block_t = decltype(this_block);
using tiled_partition_t = decltype(tiled_partition);

static_assert(!group::is_group<not_group>::value &&
group::is_group<this_block_t>::value &&
group::is_group<tiled_partition_t>::value,
"Group check doesn't work.");
static_assert(
!group::is_synchronizable_group<not_group>::value &&
group::is_synchronizable_group<this_block_t>::value &&
group::is_synchronizable_group<tiled_partition_t>::value,
"Synchronizable group check doesn't work.");
static_assert(
!group::is_communicator_group<not_group>::value &&
!group::is_communicator_group<this_block_t>::value &&
group::is_communicator_group<tiled_partition_t>::value,
"Communicator group check doesn't work.");
}

GKO_ENABLE_DEFAULT_HOST_CONFIG_TYPE(cg_communicator_categorization,
cg_communicator_categorization)
GKO_ENABLE_IMPLEMENTATION_CONFIG_SELECTION_TOTYPE(
cg_communicator_categorization, cg_communicator_categorization, DCFG_1D)
GKO_ENABLE_DEFAULT_CONFIG_CALL(cg_communicator_categorization_call,
cg_communicator_categorization,
default_config_list)

TEST_P(CooperativeGroups, CorrectCategorization)
{
test_all_subgroup(cg_communicator_categorization_call<bool*>);
}


INSTANTIATE_TEST_SUITE_P(DifferentSubgroup, CooperativeGroups,
testing::Values(4, 8, 16, 32, 64),
testing::PrintToStringParamName());
Expand Down
37 changes: 37 additions & 0 deletions hip/test/components/cooperative_groups.hip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,43 @@ TEST_F(CooperativeGroups, SubwarpBallot) { test(cg_subwarp_ballot); }
TEST_F(CooperativeGroups, SubwarpBallot2) { test_subwarp(cg_subwarp_ballot); }


__global__ void cg_communicator_categorization(bool*)
{
auto this_block = group::this_thread_block();
auto tiled_partition =
group::tiled_partition<config::warp_size>(this_block);
auto subwarp_partition = group::tiled_partition<subwarp_size>(this_block);

using not_group = int;
using this_block_t = decltype(this_block);
using tiled_partition_t = decltype(tiled_partition);
using subwarp_partition_t = decltype(subwarp_partition);

static_assert(!group::is_group<not_group>::value &&
group::is_group<this_block_t>::value &&
group::is_group<tiled_partition_t>::value &&
group::is_group<subwarp_partition_t>::value,
"Group check doesn't work.");
static_assert(
!group::is_synchronizable_group<not_group>::value &&
group::is_synchronizable_group<this_block_t>::value &&
group::is_synchronizable_group<tiled_partition_t>::value &&
group::is_synchronizable_group<subwarp_partition_t>::value,
"Synchronizable group check doesn't work.");
static_assert(
!group::is_communicator_group<not_group>::value &&
!group::is_communicator_group<this_block_t>::value &&
group::is_communicator_group<tiled_partition_t>::value &&
group::is_communicator_group<subwarp_partition_t>::value,
"Communicator group check doesn't work.");
}

TEST_F(CooperativeGroups, CorrectCategorization)
{
test(cg_communicator_categorization);
}


template <typename ValueType>
__global__ void cg_shuffle_sum(const int num, ValueType* __restrict__ value)
{
Expand Down

0 comments on commit 48f1ecc

Please sign in to comment.