diff --git a/include/cute/atom/copy_traits_xe.hpp b/include/cute/atom/copy_traits_xe.hpp index e92d184210..b59e2461eb 100644 --- a/include/cute/atom/copy_traits_xe.hpp +++ b/include/cute/atom/copy_traits_xe.hpp @@ -427,14 +427,16 @@ CUTE_HOST_DEVICE constexpr auto make_fragment_layout(TiledCopy &tiled_copy, Int mma_atom_iters_in_copy_N = copy_size_N / mma_atom_size_N; Int copy_iters_M = total_mma_atom_iters_M / mma_atom_iters_in_copy_M; Int copy_iters_N = total_mma_atom_iters_N / mma_atom_iters_in_copy_N; + auto order = std::conditional_t, Step<_2, _4>, Step<_3, _5>>, Step, Step<_3, _5>, Step<_2, _4>>>{}; - - return make_ordered_layout(make_shape(mma_atom_shape_2d, - make_shape(mma_atom_iters_in_copy_M, copy_iters_M), - make_shape(mma_atom_iters_in_copy_N, copy_iters_N)), - order); + auto res = make_ordered_layout(make_shape(mma_atom_shape_2d, + make_shape(mma_atom_iters_in_copy_M, copy_iters_M), + make_shape(mma_atom_iters_in_copy_N, copy_iters_N)), + order); + static_assert(size(res) > 0, "Error in make_fragment_layout(), tile size might be smaller than copy atom"); + return res; }; // clang-format off @@ -1655,8 +1657,8 @@ struct Copy_Traits_ : XE_2D_LD_Unpack { using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, - Stride< _0, _1>>; + using SrcLayout = Layout>, + Stride< _0, Stride< _1,_16>>>; // Map from (dst-thr,dst-val) to bit using DstLayout = Layout>, Stride<_128,Stride< _1,_16>>>; diff --git a/include/cutlass/gemm/collective/builders/xe_mma_builder.inl b/include/cutlass/gemm/collective/builders/xe_mma_builder.inl index 5acb3939b4..aa81aeb6d0 100644 --- a/include/cutlass/gemm/collective/builders/xe_mma_builder.inl +++ b/include/cutlass/gemm/collective/builders/xe_mma_builder.inl @@ -41,6 +41,80 @@ namespace cutlass::gemm::collective { // Intel PVC 3 stage pipeline, using prefetch // Also the auto builder +template +constexpr auto get_num_atoms(T_m tile_m, T_n tile_n){ + constexpr auto atom_m = get<0>(typename MMAAtom::Shape_MNK{}); + constexpr auto atom_n = get<1>(typename MMAAtom::Shape_MNK{}); + // try to create the biggest number of atoms possible, up to 32, trying to fit the most, up to 8 in m dimension + auto atoms_m_tmp = cute::min(tile_m / atom_m, _8{}); // at most 8 + auto atoms_n = cute::min(tile_n / atom_n, _32{} / atoms_m_tmp); // at most however many are not in m out of 32 + auto atoms_m = cute::min(tile_m / atom_m, _32{} / atoms_n); // at most however many are not in n out of 32 + return make_shape(atoms_m, atoms_n); +} + +template +constexpr auto select_copy_atom_16b(T_m tile_m, T_n tile_n){ + #define RETURN_ATOM(WIDTH, HEIGHT, LETTER) \ + return XE_2D_U16x##WIDTH##x##HEIGHT##_LD_##LETTER {}; + + if constexpr(is_t){ + // tile_m and tile_n have swapped role in case of _T + static_assert(tile_n % 16 == 0 && "Invalid tile_m"); + if constexpr(tile_m == 8){ + RETURN_ATOM(16, 8, T) + } else if constexpr(tile_m % 16 == 0){ + RETURN_ATOM(16, 16, T) + } else{ + static_assert(dependent_false && "Invalid tile_n"); + } + } else if constexpr(is_v){ + #define SELECT_HEIGHT_V(WIDTH) \ + if constexpr(tile_n == 16){ \ + RETURN_ATOM(WIDTH, 16, V) \ + } else if constexpr(tile_n % 32 == 0){ \ + RETURN_ATOM(WIDTH, 32, V) \ + } else{ \ + static_assert(dependent_false && "Invalid tile_n"); \ + } + + if constexpr(tile_m == 16){ + SELECT_HEIGHT_V(16) + } else if constexpr(tile_m % 32 == 0){ + SELECT_HEIGHT_V(32) + } else{ + static_assert(dependent_false && "Invalid tile_m"); + } + #undef SELECT_HEIGHT_V + } else{ // _N + #define SELECT_WIDTH_N(HEIGHT) \ + if constexpr(tile_m == 1){ \ + RETURN_ATOM(1, HEIGHT, N) \ + } else if constexpr(tile_m == 2){ \ + RETURN_ATOM(2, HEIGHT, N) \ + } else if constexpr(tile_m == 4){ \ + RETURN_ATOM(4, HEIGHT, N) \ + } else if constexpr(tile_m == 8){ \ + RETURN_ATOM(8, HEIGHT, N) \ + } else if constexpr(tile_m == 16){ \ + RETURN_ATOM(16, HEIGHT, N) \ + } else if constexpr(tile_m % 32 == 0){ \ + RETURN_ATOM(32, HEIGHT, N) \ + } else { \ + static_assert(dependent_false && "Invalid tile_m"); \ + } + + if constexpr(tile_n == 16){ + SELECT_WIDTH_N(16) + } else if constexpr(tile_n % 32 == 0){ + SELECT_WIDTH_N(32) + } else { + static_assert(dependent_false && "Invalid tile_n"); + } + #undef SELECT_WIDTH_N + } + #undef RETURN_ATOM +} + template < class ElementA, class GmemLayoutATag, @@ -85,9 +159,13 @@ struct CollectiveBuilder< XE_8x16x16_F32BF16BF16F32_TT, XE_8x16x16_F32F16F16F32_TT>>; - // Prepare Template arguments required of CollectiveMainLoop - using atoms_M = _8; - using atoms_N = _4; + static constexpr auto tile_M = get<0>(TileShape_MNK{}); + static constexpr auto tile_N = get<1>(TileShape_MNK{}); + static constexpr auto tile_K = get<2>(TileShape_MNK{}); + + static constexpr auto n_atoms = get_num_atoms(tile_M, tile_N); + using atoms_M = decltype(get<0>(n_atoms)); + using atoms_N = decltype(get<1>(n_atoms)); using TiledMma = typename TiledMMAHelper, @@ -101,19 +179,8 @@ struct CollectiveBuilder< cutlass::gemm::MainloopIntelPVCGroup, cutlass::gemm::MainloopIntelPVC>; - static constexpr auto tile_M = get<0>(TileShape_MNK{}); - static constexpr auto tile_N = get<1>(TileShape_MNK{}); - static constexpr auto tile_K = get<2>(TileShape_MNK{}); - using GmemTiledCopyA = std::conditional_t, - std::conditional_t< tile_M/atoms_M{}>=_32{} && tile_K>=_32{}, - XE_2D_U16x32x32_LD_N, - XE_2D_U16x16x16_LD_N>, - XE_2D_U16x16x16_LD_T>; - using GmemTiledCopyB = std::conditional_t, - std::conditional_t< tile_N/atoms_N{}>=_32{} && tile_K>=_32{}, - XE_2D_U16x32x32_LD_V, - XE_2D_U16x16x16_LD_V>, - XE_2D_U16x16x16_LD_T>; + using GmemTiledCopyA = decltype(select_copy_atom_16b, false>(tile_M/atoms_M{}, tile_K)); + using GmemTiledCopyB = decltype(select_copy_atom_16b, true>(tile_K, tile_N/atoms_N{})); // PVC pipeline does not use shared memory using SmemLayoutAtomA = void; diff --git a/include/cutlass/gemm/collective/xe_mma.hpp b/include/cutlass/gemm/collective/xe_mma.hpp index 42725fe1a5..d237e34093 100644 --- a/include/cutlass/gemm/collective/xe_mma.hpp +++ b/include/cutlass/gemm/collective/xe_mma.hpp @@ -199,7 +199,7 @@ struct CollectiveMma, TileShape_, ElementA_, PRINT(tCrA); PRINT(tArA); - PRINT(mainloop.copy_A); + PRINT(mainloop.tiled_copy_a); print("======================= B: \n"); PRINT(tCgB); @@ -207,7 +207,7 @@ struct CollectiveMma, TileShape_, ElementA_, PRINT(tCrB); PRINT(tBrB); - PRINT(mainloop.copy_B); + PRINT(mainloop.tiled_copy_b); } #undef PRINT #endif