39
39
namespace cutlass ::gemm::collective {
40
40
41
41
namespace {
42
+ // TODO(codeplay): generic selection methods are overcomplicated
43
+
44
+ // Generic way to pick the number of subgroups along the M dim
45
+ // If tuning for a specific case, create a SubgroupTilingMap specialization
46
+ template <typename LayoutA, class TileShape_MNK >
47
+ constexpr inline int calculate_sgs_in_M () {
48
+ constexpr int tile_M = get<0 >(TileShape_MNK{});
49
+ if constexpr (cute::is_same_v<LayoutA, cutlass::layout::RowMajor>) {
50
+ // Non-transpose load can be size 1, 2, 4, 8, 16, or 32 in the M dim (for bf16),
51
+ // but we are only supporting 8, 16 and 32 so far.
52
+ for (auto atom_m : {32 ,16 ,8 }) {
53
+ auto atoms_in_m = tile_M / atom_m;
54
+ for (auto atoms : {8 ,4 ,2 }) {
55
+ if (atoms_in_m >= atoms) {
56
+ return atoms;
57
+ }
58
+ }
59
+ }
60
+ return 1 ;
61
+ } else {
62
+ // Transpose loads are always size 16 in the M dim (for bf16).
63
+ static_assert (tile_M / 16 > 0 and tile_M % 16 == 0 , " Invalid Tile size in M dim" );
64
+ return tile_M / 16 ;
65
+ }
66
+ }
67
+
68
+ // Generic way to pick a copy atom for A
69
+ // If tuning for a specific case, create a SubgroupTilingMap specialization
42
70
template <typename LayoutA, class TileShape_MNK , int sgs_M>
43
71
inline auto pick_load_atom_for_A () {
44
72
if constexpr (cute::is_same_v<LayoutA, cutlass::layout::RowMajor>) {
@@ -60,6 +88,8 @@ inline auto pick_load_atom_for_A() {
60
88
}
61
89
}
62
90
91
+ // Generic way to pick a copy atom for B
92
+ // If tuning for a specific case, create a SubgroupTilingMap specialization
63
93
template <typename LayoutB, class TileShape_MNK , int sgs_N>
64
94
inline auto pick_load_atom_for_B () {
65
95
if constexpr (cute::is_same_v<LayoutB, cutlass::layout::RowMajor>) {
@@ -76,28 +106,6 @@ inline auto pick_load_atom_for_B() {
76
106
}
77
107
}
78
108
79
- template <typename LayoutA, class TileShape_MNK >
80
- constexpr inline int calculate_sgs_in_M () {
81
- constexpr int tile_M = get<0 >(TileShape_MNK{});
82
- if constexpr (cute::is_same_v<LayoutA, cutlass::layout::RowMajor>) {
83
- // Non-transpose load can be size 1, 2, 4, 8, 16, or 32 in the M dim (for bf16),
84
- // but we are only supporting 8, 16 and 32 so far.
85
- for (auto atom_m : {32 ,16 ,8 }) {
86
- auto atoms_in_m = tile_M / atom_m;
87
- for (auto atoms : {8 ,4 ,2 }) {
88
- if (atoms_in_m >= atoms) {
89
- return atoms;
90
- }
91
- }
92
- }
93
- return 1 ;
94
- } else {
95
- // Transpose loads are always size 16 in the M dim (for bf16).
96
- static_assert (tile_M / 16 > 0 and tile_M % 16 == 0 , " Invalid Tile size in M dim" );
97
- return tile_M / 16 ;
98
- }
99
- }
100
-
101
109
// Lookup table for subgroup layout
102
110
// This is the default case
103
111
template <typename TileShape, typename LayoutA, typename LayoutB>
@@ -115,7 +123,6 @@ struct SubgroupTilingMap {
115
123
using sgs_N = Int<std::min(tile_N/atom_N, sgs_total/sgs_M::value)>;
116
124
using GmemTiledCopyA = decltype(pick_load_atom_for_A<LayoutA, TileShape, sgs_M{}>());
117
125
using GmemTiledCopyB = decltype(pick_load_atom_for_B<LayoutB, TileShape, sgs_N{}>());
118
-
119
126
};
120
127
121
128
template <>
@@ -222,16 +229,7 @@ struct CollectiveBuilder<
222
229
XE_8x16x16_F32BF16BF16F32_TT,
223
230
XE_8x16x16_F32F16F16F32_TT>>;
224
231
225
- // We have too many subgroups, we can have at most 32, but only 8 are needed for 8x128 values (8x16 mma)
226
232
// Prepare Template arguments required of CollectiveMainLoop
227
- static constexpr auto tile_M = get<0 >(TileShape_MNK{});
228
- static constexpr auto tile_N = get<1 >(TileShape_MNK{});
229
- static constexpr auto tile_K = get<2 >(TileShape_MNK{});
230
-
231
- // number of subgroups in a dim is at most (values in a dim)/(atom size in a dim)
232
- using atom_mnk = typename MMAAtom::Shape_MNK;
233
- using max_subgroups = decltype(take<0 ,2 >(shape_div(TileShape_MNK{}, atom_mnk{}))); // M, N
234
-
235
233
using SgTilingMap = SubgroupTilingMap<TileShape_MNK, GmemLayoutATag, GmemLayoutBTag>;
236
234
using sgs_M = typename SgTilingMap::sgs_M;
237
235
using sgs_N = typename SgTilingMap::sgs_N;
@@ -262,7 +260,6 @@ struct CollectiveBuilder<
262
260
using StrideA = cutlass::gemm::TagToStrideA_t<std::conditional_t <IsGroup, GmemLayoutATag*, GmemLayoutATag>>;
263
261
using StrideB = cutlass::gemm::TagToStrideB_t<std::conditional_t <IsGroup, GmemLayoutBTag*, GmemLayoutBTag>>;
264
262
265
-
266
263
using CollectiveOp = cutlass::gemm::collective::CollectiveMma<
267
264
DispatchPolicy,
268
265
TileShape_MNK,
0 commit comments