Skip to content

Commit bf21e0f

Browse files
committed
style fixes
1 parent 4be36f2 commit bf21e0f

File tree

2 files changed

+32
-34
lines changed

2 files changed

+32
-34
lines changed

benchmarks/pvc/gemm_configuration.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ struct GemmConfiguration<
8080
float, LayoutC,
8181
float, TileShape,
8282
TileScheduler> {
83-
using KernelScheduleType =std::conditional_t<TileScheduler == Scheduler::Gemm, cutlass::gemm::KernelPVC, cutlass::gemm::KernelPVCCooperative>;
83+
using KernelScheduleType = std::conditional_t<TileScheduler == Scheduler::Gemm,
84+
cutlass::gemm::KernelPVC, cutlass::gemm::KernelPVCCooperative>;
8485

8586

8687
static_assert(std::is_same_v<LayoutC, cutlass::layout::RowMajor>, "Column Major LayoutC unsupported in collective builder");

include/cutlass/gemm/collective/builders/xe_mma_builder.inl

Lines changed: 30 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,34 @@
3939
namespace cutlass::gemm::collective {
4040

4141
namespace {
42+
// TODO(codeplay): The generic selection methods below are overcomplicated; simplify them.
43+
44+
// Generic way to pick the number of subgroups along the M dim
45+
// If tuning for a specific case, create a SubgroupTilingMap specialization
46+
template <typename LayoutA, class TileShape_MNK>
47+
constexpr inline int calculate_sgs_in_M() {
48+
constexpr int tile_M = get<0>(TileShape_MNK{});
49+
if constexpr (cute::is_same_v<LayoutA, cutlass::layout::RowMajor>) {
50+
// Non-transpose loads can be of size 1, 2, 4, 8, 16, or 32 in the M dim (for bf16),
51+
// but only sizes 8, 16 and 32 are supported so far.
52+
for (auto atom_m : {32,16,8}) {
53+
auto atoms_in_m = tile_M / atom_m;
54+
for (auto atoms : {8,4,2}) {
55+
if (atoms_in_m >= atoms) {
56+
return atoms;
57+
}
58+
}
59+
}
60+
return 1;
61+
} else {
62+
// Transpose loads are always size 16 in the M dim (for bf16).
63+
static_assert(tile_M / 16 > 0 and tile_M % 16 == 0, "Invalid Tile size in M dim");
64+
return tile_M / 16;
65+
}
66+
}
67+
68+
// Generic way to pick a copy atom for A
69+
// If tuning for a specific case, create a SubgroupTilingMap specialization
4270
template <typename LayoutA, class TileShape_MNK, int sgs_M>
4371
inline auto pick_load_atom_for_A() {
4472
if constexpr (cute::is_same_v<LayoutA, cutlass::layout::RowMajor>) {
@@ -60,6 +88,8 @@ inline auto pick_load_atom_for_A() {
6088
}
6189
}
6290

91+
// Generic way to pick a copy atom for B
92+
// If tuning for a specific case, create a SubgroupTilingMap specialization
6393
template <typename LayoutB, class TileShape_MNK, int sgs_N>
6494
inline auto pick_load_atom_for_B() {
6595
if constexpr (cute::is_same_v<LayoutB, cutlass::layout::RowMajor>) {
@@ -76,28 +106,6 @@ inline auto pick_load_atom_for_B() {
76106
}
77107
}
78108

79-
template <typename LayoutA, class TileShape_MNK>
80-
constexpr inline int calculate_sgs_in_M() {
81-
constexpr int tile_M = get<0>(TileShape_MNK{});
82-
if constexpr (cute::is_same_v<LayoutA, cutlass::layout::RowMajor>) {
83-
// Non-transpose load can be size 1, 2, 4, 8, 16, or 32 in the M dim (for bf16),
84-
// but we are only supporting 8, 16 and 32 so far.
85-
for (auto atom_m : {32,16,8}) {
86-
auto atoms_in_m = tile_M / atom_m;
87-
for (auto atoms : {8,4,2}) {
88-
if (atoms_in_m >= atoms) {
89-
return atoms;
90-
}
91-
}
92-
}
93-
return 1;
94-
} else {
95-
// Transpose loads are always size 16 in the M dim (for bf16).
96-
static_assert(tile_M / 16 > 0 and tile_M % 16 == 0, "Invalid Tile size in M dim");
97-
return tile_M / 16;
98-
}
99-
}
100-
101109
// Lookup table for subgroup layout
102110
// This primary template is the default case, used when no specialization matches.
103111
template <typename TileShape, typename LayoutA, typename LayoutB>
@@ -115,7 +123,6 @@ struct SubgroupTilingMap {
115123
using sgs_N = Int<std::min(tile_N/atom_N, sgs_total/sgs_M::value)>;
116124
using GmemTiledCopyA = decltype(pick_load_atom_for_A<LayoutA, TileShape, sgs_M{}>());
117125
using GmemTiledCopyB = decltype(pick_load_atom_for_B<LayoutB, TileShape, sgs_N{}>());
118-
119126
};
120127

121128
template <>
@@ -222,16 +229,7 @@ struct CollectiveBuilder<
222229
XE_8x16x16_F32BF16BF16F32_TT,
223230
XE_8x16x16_F32F16F16F32_TT>>;
224231

225-
// We have too many subgroups, we can have at most 32, but only 8 are needed for 8x128 values (8x16 mma)
226232
// Prepare the template arguments required by CollectiveMainLoop
227-
static constexpr auto tile_M = get<0>(TileShape_MNK{});
228-
static constexpr auto tile_N = get<1>(TileShape_MNK{});
229-
static constexpr auto tile_K = get<2>(TileShape_MNK{});
230-
231-
// number of subgroups in a dim is at most (values in a dim)/(atom size in a dim)
232-
using atom_mnk = typename MMAAtom::Shape_MNK;
233-
using max_subgroups = decltype(take<0,2>(shape_div(TileShape_MNK{}, atom_mnk{}))); // M, N
234-
235233
using SgTilingMap = SubgroupTilingMap<TileShape_MNK, GmemLayoutATag, GmemLayoutBTag>;
236234
using sgs_M = typename SgTilingMap::sgs_M;
237235
using sgs_N = typename SgTilingMap::sgs_N;
@@ -262,7 +260,6 @@ struct CollectiveBuilder<
262260
using StrideA = cutlass::gemm::TagToStrideA_t<std::conditional_t<IsGroup, GmemLayoutATag*, GmemLayoutATag>>;
263261
using StrideB = cutlass::gemm::TagToStrideB_t<std::conditional_t<IsGroup, GmemLayoutBTag*, GmemLayoutBTag>>;
264262

265-
266263
using CollectiveOp = cutlass::gemm::collective::CollectiveMma<
267264
DispatchPolicy,
268265
TileShape_MNK,

0 commit comments

Comments
 (0)