modify seqlentraits for gqa parallelism
jayhshah committed Aug 19, 2024
1 parent 62f4fe9 commit 4375324
Showing 1 changed file with 50 additions and 7 deletions: hopper/seq_len.h

@@ -4,13 +4,20 @@

#pragma once

#include <array>
#include <algorithm>

#include <cutlass/cutlass.h>
#include <cute/layout.hpp>

namespace flash {

static constexpr int kMaxTileSize = 128;

// Tag values selecting the traits flavor: fixed sequence lengths, variable
// (ragged) sequence lengths, and decoding with grouped-query attention (GQA),
// where query heads sharing a KV head are tiled together.
static constexpr int FixedSeqLenType = 0;
static constexpr int VarSeqLenType = 1;
static constexpr int DecodingGQASeqLenType = 2;

template <int SeqLenType> class SeqLenTraits {
public:
static_assert(SeqLenType == 0 || SeqLenType == 1 || SeqLenType == 2,
@@ -27,16 +34,25 @@ template <int SeqLenType> class SeqLenTraits {

// Whether this is for fixed-seq-len, var-seq-len, or GQA decoding.
static constexpr bool UseVarSeqLen = SeqLenType == 1;
static constexpr bool DecodingGQA = SeqLenType == 2;

using ShapeT = std::conditional_t<
    UseVarSeqLen,
    cute::Shape<int32_t, int32_t, int32_t>,
    std::conditional_t<
        DecodingGQA,
        cute::Shape<int32_t, int32_t, int32_t, int32_t, int32_t>,
        cute::Shape<int32_t, int32_t, int32_t, int32_t>
    >
>;
using StrideT = std::conditional_t<
    UseVarSeqLen,
    cute::Shape<int64_t, _1, int64_t>,
    std::conditional_t<
        DecodingGQA,
        cute::Shape<int64_t, int64_t, _1, int64_t, int64_t>,
        cute::Shape<int64_t, _1, int64_t, int64_t>
    >
>;
using LayoutT = cute::Layout<ShapeT, StrideT>;
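
// For DecodingGQA, the gmem shape gains a fifth mode: the query-head
// dimension H is factored into (H/HK, HK), so the h_h_k_ratio query heads
// that share one KV head form their own mode and can be tiled together
// (see get_query_gmem_layout in the specialization below).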

@@ -64,7 +80,7 @@ template <int SeqLenType> class SeqLenTraits {
int m, int k, int h, int b,
int64_t m_stride, int64_t h_stride, int64_t b_stride,
bool padded = false) const {
static_assert(!UseVarSeqLen, "Specialize default implementation for VarSeqLen.");
return make_layout(make_shape(m, k, h, b),
make_stride(m_stride, cute::_1{}, h_stride, b_stride));
}
@@ -73,7 +89,7 @@
// padded: only useful for var-seq-len for dq_accum and softmax_d.
CUTLASS_HOST_DEVICE auto get_lse_gmem_layout(
int m, int h, int b, bool padded = false) const {
static_assert(!UseVarSeqLen, "Specialize default implementation for VarSeqLen.");
return make_layout(make_shape(b, h, m),
make_stride(int64_t(h * m), int64_t(m), cute::_1()));
}
@@ -106,8 +122,35 @@
}
};

template <>
class SeqLenTraits<DecodingGQASeqLenType> {
public:

// Returns the layout of QO tensor in (M,H/HK,K,HK,B) format in global memory.
CUTLASS_HOST_DEVICE auto get_query_gmem_layout(
int m, int k, int h_k, int b, int h_h_k_ratio,
int64_t m_stride, int64_t h_stride, int64_t b_stride) const {
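// Adjacent query heads that share a KV head (the H/HK mode) sit
// h_stride apart; the HK mode then skips all h_h_k_ratio of them.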
return make_layout(make_shape(m, h_h_k_ratio, k, h_k, b),
make_stride(m_stride, h_stride, cute::_1{},
h_stride * h_h_k_ratio, b_stride));
}

// Tile shape should be (bM/HQ, bHQ, bK).
// Returns the local tile of that shape at coordinate (m_block, bidh_local).
template <typename MTensor, typename Shape>
CUTLASS_DEVICE auto get_query_local_tile_tensor(
const MTensor &m_tensor, const Shape &tile_shape,
int bidh_kv, int bidb, int m_block, int bidh_local) const {
// Expects bidh_local == bidh % qhead_per_khead.
auto g_tensor = local_tile(
m_tensor(_, _, _, bidh_kv, bidb), tile_shape, make_coord(m_block, bidh_local, _0{}));
return g_tensor;
}
};
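
To make the (M, H/HK, K, HK, B) factorization concrete, here is a minimal host-side CuTe sketch. It is not part of this commit, and the sizes and the row-major (B, M, H, K) storage for Q are illustrative assumptions; it repeats the shape/stride arithmetic of get_query_gmem_layout and evaluates one coordinate:

#include <cstdio>

#include <cute/layout.hpp>

// Illustrative sizes: 2 KV heads, 4 query heads per KV head (GQA),
// head dim 8, seqlen 16, batch 2; Q assumed stored (B, M, H, K) row-major.
int main() {
  using namespace cute;
  int const m = 16, k = 8, h_k = 2, b = 2, h_h_k_ratio = 4;
  int const h = h_k * h_h_k_ratio;
  int64_t const m_stride = int64_t(h) * k;      // to the next query row
  int64_t const h_stride = k;                   // to the next query head
  int64_t const b_stride = int64_t(m) * h * k;  // to the next batch entry

  // Same arithmetic as get_query_gmem_layout: (M, H/HK, K, HK, B), where
  // stepping the H/HK mode moves across query heads sharing one KV head.
  auto layout = make_layout(
      make_shape(m, h_h_k_ratio, k, h_k, b),
      make_stride(m_stride, h_stride, _1{}, h_stride * h_h_k_ratio, b_stride));

  print(layout); printf("\n");
  // (row 0, local head 1, dim 0, KV head 0, batch 0) sits one head
  // stride past the origin: prints 8.
  printf("%ld\n", (long)layout(0, 1, 0, 0, 0));
  return 0;
}

A kernel block indexed by (bidh_kv, bidb) would then slice this layout via get_query_local_tile_tensor, with bidh_local = bidh % qhead_per_khead selecting the query head within the group.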

using FixedSeqLenTraits = SeqLenTraits<FixedSeqLenType>;
using VarSeqLenTraits = SeqLenTraits<VarSeqLenType>;
using DecodingGQASeqLenTraits = SeqLenTraits<DecodingGQASeqLenType>;
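
Nothing in this excerpt chooses among the three aliases; as a hedged sketch of how a caller might dispatch at compile time (SelectSeqLenTraits is a hypothetical helper, not from this commit):

// Hypothetical compile-time dispatch (illustrative only; <type_traits> is
// already available transitively, since the file uses std::conditional_t).
template <bool IsVarSeqLen, bool IsDecodingGQA>
using SelectSeqLenTraits = std::conditional_t<
    IsVarSeqLen, VarSeqLenTraits,
    std::conditional_t<IsDecodingGQA, DecodingGQASeqLenTraits,
                       FixedSeqLenTraits>>;

static_assert(std::is_same_v<SelectSeqLenTraits<false, true>,
                             DecodingGQASeqLenTraits>);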

// Returns the static layout of a var-seq-len tensor in global memory based on
// max_seq_len and max_batch_size.
