From 4375324836497c37e565a513e268f125dc13bff4 Mon Sep 17 00:00:00 2001 From: Jay Shah Date: Mon, 19 Aug 2024 13:10:58 -0700 Subject: [PATCH] modify seqlentraits for gqa parallelism --- hopper/seq_len.h | 57 ++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 50 insertions(+), 7 deletions(-) diff --git a/hopper/seq_len.h b/hopper/seq_len.h index c7c52c7ee..a7a5a6c73 100644 --- a/hopper/seq_len.h +++ b/hopper/seq_len.h @@ -4,6 +4,9 @@ #pragma once +#include +#include + #include #include @@ -11,6 +14,10 @@ namespace flash { static constexpr int kMaxTileSize = 128; +static constexpr int FixedSeqLenType = 0; +static constexpr int VarSeqLenType = 1; +static constexpr int DecodingGQASeqLenType = 2; + template class SeqLenTraits { public: static_assert(SeqLenType == 0 || SeqLenType == 1 || SeqLenType == 2, @@ -27,16 +34,25 @@ template class SeqLenTraits { // Whether this is for fixed-seq-len or var-seq-len. static constexpr bool UseVarSeqLen = SeqLenType == 1; + static constexpr bool DecodingGQA = SeqLenType == 2; using ShapeT = std::conditional_t< UseVarSeqLen, - cute::Shape, - cute::Shape + cute::Shape, + std::conditional_t< + DecodingGQA, + cute::Shape, + cute::Shape + > >; using StrideT = std::conditional_t< UseVarSeqLen, cute::Shape, - cute::Shape + std::conditional_t< + DecodingGQA, + cute::Shape, + cute::Shape + > >; using LayoutT = cute::Layout; @@ -64,7 +80,7 @@ template class SeqLenTraits { int m, int k, int h, int b, int64_t m_stride, int64_t h_stride, int64_t b_stride, bool padded = false) const { - static_assert(!UseVarSeqLen, "Default implementation is for FixedSeqLen."); + static_assert(!UseVarSeqLen, "Specialize default implementation for VarSeqLen."); return make_layout(make_shape(m, k, h, b), make_stride(m_stride, cute::_1{}, h_stride, b_stride)); } @@ -73,7 +89,7 @@ template class SeqLenTraits { // padded: only useful for var-seq-len for dq_accum and softmax_d. CUTLASS_HOST_DEVICE auto get_lse_gmem_layout( int m, int h, int b, bool padded = false) const { - static_assert(!UseVarSeqLen, "Default implementation is for FixedSeqLen."); + static_assert(!UseVarSeqLen, "Specialize default implementation for VarSeqLen."); return make_layout(make_shape(b, h, m), make_stride(int64_t(h * m), int64_t(m), cute::_1())); } @@ -106,8 +122,35 @@ template class SeqLenTraits { } }; -using FixedSeqLenTraits = SeqLenTraits<0>; -using VarSeqLenTraits = SeqLenTraits<1>; +template <> +class SeqLenTraits { +public: + + // Returns the layout of QO tensor in (M,H/HK,K,HK,B) format in global memory. + CUTLASS_HOST_DEVICE auto get_query_gmem_layout( + int m, int k, int h_k, int b, int h_h_k_ratio, + int64_t m_stride, int64_t h_stride, int64_t b_stride) const { + return make_layout(make_shape(m, h_h_k_ratio, k, h_k, b), + make_stride(m_stride, h_stride, cute::_1{}, + h_stride * h_h_k_ratio, b_stride)); + } + + // Tile Shape should be (bM/HQ, bHQ, bK) + // Returns local tile (bM/HQ, bHQ, bK) + template + CUTLASS_DEVICE auto get_query_local_tile_tensor( + const MTensor &m_tensor, const Shape &tile_shape, + int bidh_kv, int bidb, int m_block, int bidh_local) const { + // expect bidh_local = bidh % qhead_per_khead + auto g_tensor = local_tile( + m_tensor(_, _, _, bidh_kv, bidb), tile_shape, make_coord(m_block, bidh_local, _0{})); + return g_tensor; + } +}; + +using FixedSeqLenTraits = SeqLenTraits; +using VarSeqLenTraits = SeqLenTraits; +using DecodingGQASeqLenTraits = SeqLenTraits; // Returns the static layout of a var-seq-len tensor in global memory based on // max_seq_len and max_batch_size.