Skip to content

generalize collective builder across more tile shapes #315

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Apr 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions include/cute/atom/copy_traits_xe.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -427,14 +427,16 @@ CUTE_HOST_DEVICE constexpr auto make_fragment_layout(TiledCopy &tiled_copy,
Int mma_atom_iters_in_copy_N = copy_size_N / mma_atom_size_N;
Int copy_iters_M = total_mma_atom_iters_M / mma_atom_iters_in_copy_M;
Int copy_iters_N = total_mma_atom_iters_N / mma_atom_iters_in_copy_N;

auto order = std::conditional_t<TiledCopy::is_convention_MN,
Step<Step<_0, _1>, Step<_2, _4>, Step<_3, _5>>,
Step<Step<_0, _1>, Step<_3, _5>, Step<_2, _4>>>{};

return make_ordered_layout(make_shape(mma_atom_shape_2d,
make_shape(mma_atom_iters_in_copy_M, copy_iters_M),
make_shape(mma_atom_iters_in_copy_N, copy_iters_N)),
order);
auto res = make_ordered_layout(make_shape(mma_atom_shape_2d,
make_shape(mma_atom_iters_in_copy_M, copy_iters_M),
make_shape(mma_atom_iters_in_copy_N, copy_iters_N)),
order);
static_assert(size(res) > 0, "Error in make_fragment_layout(), tile size might be smaller than copy atom");
return res;
};

// clang-format off
Expand Down Expand Up @@ -1655,8 +1657,8 @@ struct Copy_Traits_<XE_2D_U16x16x8_LD_T, args_t...>
: XE_2D_LD_Unpack<XE_2D_U16x16x8_LD_T, args_t...> {
using ThrID = Layout<_16>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape <_16,_16>,
Stride< _0, _1>>;
using SrcLayout = Layout<Shape <_16, Shape <_16, _8>>,
Stride< _0, Stride< _1,_16>>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape < _16,Shape <_16, _8>>,
Stride<_128,Stride< _1,_16>>>;
Expand Down
99 changes: 83 additions & 16 deletions include/cutlass/gemm/collective/builders/xe_mma_builder.inl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,80 @@ namespace cutlass::gemm::collective {
// Intel PVC 3 stage pipeline, using prefetch
// Also the auto builder

// Choose how many MMA atoms to tile over a (tile_m x tile_n) work-group tile.
// Greedy policy: spend a budget of at most 32 atoms total, first granting M up
// to 8 atoms, then letting N take what is left of the budget, and finally
// re-deriving M against the settled N so the M*N product never exceeds 32.
// Returns a cute shape (atoms_m, atoms_n).
template<typename MMAAtom, typename T_m, typename T_n>
constexpr auto get_num_atoms(T_m tile_m, T_n tile_n){
  using AtomShape = typename MMAAtom::Shape_MNK;
  constexpr auto atom_size_m = get<0>(AtomShape{});
  constexpr auto atom_size_n = get<1>(AtomShape{});
  // Upper bounds: how many atoms fit in the tile along each dimension.
  auto fit_m = tile_m / atom_size_m;
  auto fit_n = tile_n / atom_size_n;
  // Provisional M count, capped at 8.
  auto m_provisional = cute::min(fit_m, _8{});
  // N receives whatever remains of the 32-atom budget.
  auto atoms_n = cute::min(fit_n, _32{} / m_provisional);
  // Final M count recomputed against the final N.
  auto atoms_m = cute::min(fit_m, _32{} / atoms_n);
  return make_shape(atoms_m, atoms_n);
}

template<bool is_t, bool is_v, typename T_m, typename T_n>
constexpr auto select_copy_atom_16b(T_m tile_m, T_n tile_n){
#define RETURN_ATOM(WIDTH, HEIGHT, LETTER) \
return XE_2D_U16x##WIDTH##x##HEIGHT##_LD_##LETTER {};

if constexpr(is_t){
// tile_m and tile_n have swapped role in case of _T
static_assert(tile_n % 16 == 0 && "Invalid tile_m");
if constexpr(tile_m == 8){
RETURN_ATOM(16, 8, T)
} else if constexpr(tile_m % 16 == 0){
RETURN_ATOM(16, 16, T)
} else{
static_assert(dependent_false<T_m> && "Invalid tile_n");
}
} else if constexpr(is_v){
#define SELECT_HEIGHT_V(WIDTH) \
if constexpr(tile_n == 16){ \
RETURN_ATOM(WIDTH, 16, V) \
} else if constexpr(tile_n % 32 == 0){ \
RETURN_ATOM(WIDTH, 32, V) \
} else{ \
static_assert(dependent_false<T_n> && "Invalid tile_n"); \
}

if constexpr(tile_m == 16){
SELECT_HEIGHT_V(16)
} else if constexpr(tile_m % 32 == 0){
SELECT_HEIGHT_V(32)
} else{
static_assert(dependent_false<T_m> && "Invalid tile_m");
}
#undef SELECT_HEIGHT_V
} else{ // _N
#define SELECT_WIDTH_N(HEIGHT) \
if constexpr(tile_m == 1){ \
RETURN_ATOM(1, HEIGHT, N) \
} else if constexpr(tile_m == 2){ \
RETURN_ATOM(2, HEIGHT, N) \
} else if constexpr(tile_m == 4){ \
RETURN_ATOM(4, HEIGHT, N) \
} else if constexpr(tile_m == 8){ \
RETURN_ATOM(8, HEIGHT, N) \
} else if constexpr(tile_m == 16){ \
RETURN_ATOM(16, HEIGHT, N) \
} else if constexpr(tile_m % 32 == 0){ \
RETURN_ATOM(32, HEIGHT, N) \
} else { \
static_assert(dependent_false<T_m> && "Invalid tile_m"); \
}

if constexpr(tile_n == 16){
SELECT_WIDTH_N(16)
} else if constexpr(tile_n % 32 == 0){
SELECT_WIDTH_N(32)
} else {
static_assert(dependent_false<T_n> && "Invalid tile_n");
}
#undef SELECT_WIDTH_N
}
#undef RETURN_ATOM
}

template <
class ElementA,
class GmemLayoutATag,
Expand Down Expand Up @@ -85,9 +159,13 @@ struct CollectiveBuilder<
XE_8x16x16_F32BF16BF16F32_TT,
XE_8x16x16_F32F16F16F32_TT>>;

// Prepare Template arguments required of CollectiveMainLoop
using atoms_M = _8;
using atoms_N = _4;
static constexpr auto tile_M = get<0>(TileShape_MNK{});
static constexpr auto tile_N = get<1>(TileShape_MNK{});
static constexpr auto tile_K = get<2>(TileShape_MNK{});

static constexpr auto n_atoms = get_num_atoms<MMAAtom>(tile_M, tile_N);
using atoms_M = decltype(get<0>(n_atoms));
using atoms_N = decltype(get<1>(n_atoms));
using TiledMma =
typename TiledMMAHelper<MMAAtom,
Layout<TileShape_MNK>,
Expand All @@ -101,19 +179,8 @@ struct CollectiveBuilder<
cutlass::gemm::MainloopIntelPVCGroup<PipelineStages, KernelSchedule>,
cutlass::gemm::MainloopIntelPVC<PipelineStages, KernelSchedule>>;

static constexpr auto tile_M = get<0>(TileShape_MNK{});
static constexpr auto tile_N = get<1>(TileShape_MNK{});
static constexpr auto tile_K = get<2>(TileShape_MNK{});
using GmemTiledCopyA = std::conditional_t<cute::is_same_v<GmemLayoutATag, cutlass::layout::RowMajor>,
std::conditional_t< tile_M/atoms_M{}>=_32{} && tile_K>=_32{},
XE_2D_U16x32x32_LD_N,
XE_2D_U16x16x16_LD_N>,
XE_2D_U16x16x16_LD_T>;
using GmemTiledCopyB = std::conditional_t<cute::is_same_v<GmemLayoutBTag, cutlass::layout::RowMajor>,
std::conditional_t< tile_N/atoms_N{}>=_32{} && tile_K>=_32{},
XE_2D_U16x32x32_LD_V,
XE_2D_U16x16x16_LD_V>,
XE_2D_U16x16x16_LD_T>;
using GmemTiledCopyA = decltype(select_copy_atom_16b<cute::is_same_v<GmemLayoutATag, cutlass::layout::ColumnMajor>, false>(tile_M/atoms_M{}, tile_K));
using GmemTiledCopyB = decltype(select_copy_atom_16b<cute::is_same_v<GmemLayoutBTag, cutlass::layout::ColumnMajor>, true>(tile_K, tile_N/atoms_N{}));

// PVC pipeline does not use shared memory
using SmemLayoutAtomA = void;
Expand Down
4 changes: 2 additions & 2 deletions include/cutlass/gemm/collective/xe_mma.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,15 +199,15 @@ struct CollectiveMma<MainloopIntelPVC<Stages, Schedule>, TileShape_, ElementA_,

PRINT(tCrA);
PRINT(tArA);
PRINT(mainloop.copy_A);
PRINT(mainloop.tiled_copy_a);

print("======================= B: \n");
PRINT(tCgB);
PRINT(tBgB);

PRINT(tCrB);
PRINT(tBrB);
PRINT(mainloop.copy_B);
PRINT(mainloop.tiled_copy_b);
}
#undef PRINT
#endif
Expand Down
Loading