/***************************************************************************************************
 * Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Collective epilogue for flash attention on Intel PVC: normalizes the
    accumulated output by the softmax row sums and stores it to global memory.
*/

#pragma once

#include <sycl/sycl.hpp>
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_epilogue.hpp"
#include "cutlass/epilogue/collective/detail.hpp"
#include "cutlass/detail/layout.hpp"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace flash_attention {
namespace collective {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <class DispatchPolicy, class... Args> class CollectiveEpilogueAttention {
  static_assert(cutlass::detail::dependent_false<DispatchPolicy>, "Could not find an epilogue specialization.");
};

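// Intel PVC specialization.
//
// A minimal instantiation sketch (illustrative only: the tile shape, element
// types, stride, and copy op below are assumptions, not values mandated by
// this header):
//
//   using Epilogue = cutlass::flash_attention::collective::CollectiveEpilogueAttention<
//       cutlass::epilogue::IntelPVCEpilogue,
//       Shape<_128, _64, _64, _32>,   // CtaTileMNK: [CTA_M_Q, CTA_N_V, CTA_N_QK, CTA_K_QK] (assumed sizes)
//       cutlass::half_t,              // ElementO (assumed)
//       Stride<int, _1, int>,         // StrideO: row-major [seq_len_qo, head_size_vo, batch * num_heads] (assumed)
//       float,                        // ElementLSE (assumed)
//       XE_2D_U16x8x16_ST_N>;         // CopyOpO: an Xe block 2D store atom (assumed)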
template <class CtaTileMNK_, class ElementO_, class StrideO_, class ElementLSE_, class CopyOpO_>
class CollectiveEpilogueAttention<epilogue::IntelPVCEpilogue, CtaTileMNK_, ElementO_, StrideO_, ElementLSE_, CopyOpO_> {
public:
  //
  // Type Aliases
  //
  using DispatchPolicy = epilogue::IntelPVCEpilogue;
  using CtaTileMNK = CtaTileMNK_;
  using ElementO = ElementO_;
  using ElementAccumulator = ElementO_;
  using StrideO = StrideO_;
  using ElementLSE = ElementLSE_;
  using CopyOpO = CopyOpO_;

  using GmemTiledCopyO = CopyOpO;
  using ElementOutput = ElementO_;
  using ElementCompute = ElementO_;

  static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize;

  static_assert(cute::rank(CtaTileMNK{}) == 4, "CtaTileMNK must be rank-4: [CTA_M_Q, CTA_N_V, CTA_N_QK, CTA_K_QK]");
  static_assert(cute::rank(StrideO{}) == 3, "StrideO must be rank-3: [seq_len_qo, head_size_vo, batch * num_heads]");

  using CopyThreadShape = Shape<_1, Int<SubgroupSize>>;

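  // Build the tiled copy used to store O: the copy atom is distributed across
  // one subgroup-wide row of work-items, each owning BlockShape / CopyThreadShape
  // values of the atom's block.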
  using traits_store_O = Copy_Traits<GmemTiledCopyO, StrideO>;
  using atom_store_O = Copy_Atom<traits_store_O, ElementO>;
  using val_layout_store_O = decltype(make_layout(shape_div(typename traits_store_O::BlockShape{}, CopyThreadShape{})));
  using XE_Copy_O = decltype(make_tiled_copy(atom_store_O{}, Layout<CopyThreadShape>{}, val_layout_store_O{}));

private:
  constexpr static bool is_destination_supported = not cute::is_void_v<ElementO>;

public:
  using EmptyType = cute::tuple<>;

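  // This epilogue stores straight from registers to global memory, so no
  // shared local memory is needed; TensorStorage is deliberately empty.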
  struct TensorStorageImpl : cute::tuple<EmptyType, EmptyType> {};

  struct SharedStorage {
    using TensorStorage = TensorStorageImpl;

    TensorStorage tensors;
  };
  using TensorStorage = typename SharedStorage::TensorStorage;

  // Host side epilogue arguments
  struct Arguments {
    ElementO const *ptr_O;
    StrideO dO;
  };

  // Device side epilogue params
  struct Params {
    XE_Copy_O xe_store_o;
  };

  //
  // Methods
  //

  template <class ProblemShape>
  static constexpr Params to_underlying_arguments(ProblemShape const &problem_shape, Arguments const &args,
                                                  [[maybe_unused]] void *workspace) {
    auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache, head_size_qk, head_size_vo] = problem_shape;

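    // View the output as a (seq_len_qo, head_size_vo, batch * num_heads_q)
    // global tensor and pre-bind it to the block 2D store.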
    auto tensorO = make_tensor(make_gmem_ptr(static_cast<ElementO const*>(args.ptr_O)),
                               make_layout(make_shape(seq_len_qo, head_size_vo, batch * num_heads_q),
                                           args.dO));
    XE_Copy_O xe_store_o{XE_Copy_O{}.with(tensorO)};
    return {
        xe_store_o,
    };
  }

  template <class ProblemShape>
  static size_t get_workspace_size(ProblemShape const &problem_shape, Arguments const &args) {
    return 0;
  }

  template <class ProblemShape>
  static cutlass::Status initialize_workspace(ProblemShape const &problem_shape, Arguments const &args, void *workspace,
                                              cudaStream_t stream, CudaHostAdapter *cuda_adapter = nullptr) {
    return Status::kSuccess;
  }

  template <class ProblemShape>
  CUTLASS_HOST_DEVICE static bool can_implement(ProblemShape const &problem_shape,
                                                [[maybe_unused]] Arguments const &args) {
    return true;
  }

  CUTLASS_HOST_DEVICE
  CollectiveEpilogueAttention(Params const &params_, TensorStorage const &) : params(params_) {}

  template <class ProblemShape, class TileCoord, class FragOut, class FragMax, class FragSum, class TiledMma>
  CUTLASS_DEVICE void operator()(ProblemShape problem_shape, TileCoord tile_coord, FragOut &out, FragMax const &max,
                                 FragSum &sum, TiledMma tiled_mma, ElementCompute const &softmax_scale) {

    using namespace cute;

    using MmaAtomShape = typename TiledMma::AtomShape_MNK;
    using SubgroupTileShape = decltype(cute::shape_div(take<0, 3>(CtaTileMNK{}), take<1, 4>(typename TiledMma::ThrLayoutVMNK{}.shape())));
    using FragsShape = decltype(cute::shape_div(take<0, 2>(SubgroupTileShape{}), take<0, 2>(MmaAtomShape())));
    // seq_len_qo is element 3 of the problem shape; it is the component that may be VariableLength
    // (get_updated_copies below reads get<3>(problem_shape).cumulative_length).
    static constexpr bool is_var_len = cutlass::fmha::collective::is_variable_length_v<tuple_element_t<3, ProblemShape>>;

    static constexpr int FragsM = get<0>(FragsShape{}); // A frags per sub_group
    static constexpr int FragsN = get<1>(FragsShape{}); // B frags per sub_group
    static constexpr int Vec = (get<0>(MmaAtomShape()) * get<1>(MmaAtomShape())) / SubgroupSize;

    auto g = syclcompat::get_nd_item<1>().get_sub_group();

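    // Finish the softmax normalization: each work-item holds partial row sums,
    // so reduce across the subgroup and rescale the output fragment by
    // 1 / row_sum. A zero or NaN sum (cur_sum != cur_sum) falls back to a
    // scale of 1 to avoid dividing by zero.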
    CUTLASS_PRAGMA_UNROLL
    for (int y = 0; y < FragsM; y++) {
      CUTLASS_PRAGMA_UNROLL
      for (int x = 0; x < Vec; x++) {
        int indx = y * Vec + x;
        auto cur_sum = reduce_over_group(g, sum(indx), sycl::plus<>());
        auto cur_scale = (cur_sum == 0.f || cur_sum != cur_sum) ? 1.f : sycl::native::recip(cur_sum);
        CUTLASS_PRAGMA_UNROLL
        for (int z = 0; z < FragsN; z++) {
          out(x, y, z) *= cur_scale;
        }
      }
    }

    // Indexing variables
    auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache, head_size_qk, head_size_vo] = problem_shape;
    // Represent the full output tensor
    Tensor mO_mnl = cute::get_pvc_tensor(make_shape(seq_len_qo, head_size_vo, (is_var_len ? batch : 1) * num_heads_q));

    auto [m_coord, n_coord, k_coord, l_coord] = tile_coord;
    // Tile the output tensor per WG
    Tensor g_wg_O = local_tile(mO_mnl, select<0,1>(CtaTileMNK{}), make_coord(m_coord,n_coord,l_coord)); // (BLK_M,BLK_N)
    static constexpr auto ATOM_N = get<2>(typename TiledMma::ThrLayoutVMNK{}.shape());
    static constexpr auto ATOM_K = get<3>(typename TiledMma::ThrLayoutVMNK{}.shape());
    auto m_sg = get_sub_group_id() / ATOM_N;
    auto n_sg = get_sub_group_id() % ATOM_N;
    // Tile the output tensor per SG
    Tensor gO = local_tile(g_wg_O, SubgroupTileShape{}, make_coord(m_sg,n_sg,_), Step<_1,_1, X>{}); // (SG_M,SG_N)

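    // Partition the subgroup tile across the work-items of the tiled copy and
    // issue the block 2D store from registers to global memory.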
    auto thread_xe_store_o = params.xe_store_o.get_thread_slice(ThreadIdxX());
    Tensor tOgO = thread_xe_store_o.partition_D(gO);

    copy(params.xe_store_o, out, tOgO);
  }

  template <bool VarLen, class ProblemShapeType>
  CUTLASS_DEVICE static constexpr Params get_updated_copies(Params const& params, ProblemShapeType const& problem_shape, int const& l_coord) {
    if constexpr (!VarLen) {
      return params;
    } else {
      auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache, head_size_qk, head_size_vo] = problem_shape;

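      // For variable-length input, recompute the O base pointer using the
      // cumulative query-sequence length of this batch entry, then rebuild
      // the store copy with the per-batch shape.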
      auto qo_cumulative_length = get<3>(problem_shape).cumulative_length;
      int offset_o = num_heads_q * head_size_vo * qo_cumulative_length[l_coord];
      auto store_traits = static_cast<traits_store_O const&>(params.xe_store_o);

      ElementO* base_ptr = (ElementO*)store_traits.base_ptr;
      auto shape_o = make_shape(static_cast<int>(seq_len_qo), head_size_vo, num_heads_q);
      StrideO stride_o = cutlass::make_cute_packed_stride(StrideO{}, shape_o);

      auto tensorO = make_tensor(make_gmem_ptr(base_ptr + offset_o), make_layout(shape_o, stride_o));
      XE_Copy_O xe_store_o{XE_Copy_O{}.with(tensorO)};

      return Params{xe_store_o};
    }
  }

private:
  Params const &params;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace collective
} // namespace flash_attention
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////