Skip to content

Commit

Permalink
added support for small page size. (#53)
Browse files Browse the repository at this point in the history
*apply change from pull request :
Dao-AILab/flash-attention#824
guocuimi authored Feb 27, 2024
1 parent 11dff3f commit 23362e2
Showing 5 changed files with 173 additions and 62 deletions.
2 changes: 1 addition & 1 deletion src/kernels/flash_attn/flash_api.cpp
Original file line number Diff line number Diff line change
@@ -276,7 +276,7 @@ mha_varlen_fwd(at::Tensor &q, // [n_tokens, n_heads, head_dim]
const int n_blocks = !paged_KV ? 0 : k.size(0);
const int block_size = !paged_KV ? 1 : k.size(1);
// TODO: support smaller block sizes
TORCH_CHECK(!paged_KV || block_size % 256 == 0, "Paged KV cache block size must be divisible by 256");
TORCH_CHECK(!paged_KV || block_size % 16 == 0, "Paged KV cache block size must be divisible by 16");

// [n_tokens, n_heads, head_dim]
const auto sizes = q.sizes();
136 changes: 76 additions & 60 deletions src/kernels/flash_attn/src/flash_fwd_kernel.h
Original file line number Diff line number Diff line change
@@ -515,16 +515,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// We move K and V to the last block.
const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb];
const int *block_table = params.block_table == nullptr ? nullptr : params.block_table + bidb * params.block_table_batch_stride;
const int block_table_idx = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN / params.page_block_size;
const int block_table_offset = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN - block_table_idx * params.page_block_size;
const index_t row_offset_k = block_table == nullptr
? binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache)
+ (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride
: block_table[block_table_idx] * params.k_batch_stride + block_table_offset * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
: (bidh / params.h_h_k_ratio) * params.k_head_stride; // block addresses are later resolved per-thread

const index_t row_offset_v = block_table == nullptr
? binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache)
+ (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride
: block_table[block_table_idx] * params.v_batch_stride + block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
: (bidh / params.h_h_k_ratio) * params.v_head_stride;

Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
@@ -544,15 +543,30 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});

typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);

Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K)
Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K)
Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_Q;
auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx);
typename Kernel_traits::GmemTiledCopyQKVPaged gmem_tiled_copy_KV;
auto gmem_thr_copy_KV = gmem_tiled_copy_KV.get_thread_slice(tidx);

Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ);
Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ);

Tensor tKgK_ = gmem_thr_copy_KV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K)
Tensor tKsK_ = gmem_thr_copy_KV.partition_D(sK);
Tensor tVgV_ = gmem_thr_copy_KV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K)
Tensor tVsV_ = gmem_thr_copy_KV.partition_D(sV);

Tensor tKgK = make_tensor(tKgK_.data(), reshape_thread_tile(tKgK_.layout()));
Tensor tKsK = make_tensor(tKsK_.data(), reshape_thread_tile(tKsK_.layout()));
Tensor tVgV = make_tensor(tVgV_.data(), reshape_thread_tile(tVgV_.layout()));
Tensor tVsV = make_tensor(tVsV_.data(), reshape_thread_tile(tVsV_.layout()));

if (block_table != nullptr) {
tKgK.data() = gK.data() + flash::init_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block_max, params.page_block_size,
block_table, params.k_batch_stride, params.k_row_stride);
tVgV.data() = gV.data() + flash::init_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block_max, params.page_block_size,
block_table, params.v_batch_stride, params.v_row_stride);
}

typename Kernel_traits::TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tidx);
@@ -590,8 +604,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)

// Repeat the partitioning with identity layouts
Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
Tensor tKVcKV_ = gmem_thr_copy_KV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
Tensor tKVcKV = make_tensor(tKVcKV_.data(), reshape_thread_tile(tKVcKV_.layout()));

// Allocate predicate tensors for k
Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
@@ -608,11 +623,12 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// Prologue

// Copy from Knew to K, optionally apply rotary embedding.
typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary;
auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx);
typename Kernel_traits::GmemTiledCopyRotcossinCont gmem_tiled_copy_rotary_cont;
auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx);
if constexpr (Append_KV) {
typename Kernel_traits::GmemTiledCopyRotcossinPaged gmem_tiled_copy_rotary;
auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx);
typename Kernel_traits::GmemTiledCopyRotcossinContPaged gmem_tiled_copy_rotary_cont;
auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx);

// Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to
// gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe.
// We want to do this so that all threadblocks can proceed right after they finish writing the KV cache.
@@ -629,10 +645,17 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.rotary_dim / 2, _1{}));
Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos);
Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin);
Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont);
Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont);

Tensor tRgCos_ = gmem_thr_copy_rotary.partition_S(gCos);
Tensor tRgSin_ = gmem_thr_copy_rotary.partition_S(gSin);
Tensor tRgCosCont_ = gmem_thr_copy_rotary_cont.partition_S(gCosCont);
Tensor tRgSinCont_ = gmem_thr_copy_rotary_cont.partition_S(gSinCont);

Tensor tRgCos = make_tensor(tRgCos_.data(), reshape_thread_tile(tRgCos_.layout()));
Tensor tRgSin = make_tensor(tRgSin_.data(), reshape_thread_tile(tRgSin_.layout()));
Tensor tRgCosCont = make_tensor(tRgCosCont_.data(), reshape_flatten_thread_tile(tRgCosCont_.layout()));
Tensor tRgSinCont = make_tensor(tRgSinCont_.data(), reshape_flatten_thread_tile(tRgSinCont_.layout()));

// if (cute::thread(0, 0)) { printf("rotary_cos_ptr = %p, gCos.data() = %p, tRgCos.data() = %p, rotary_dim = %d\n", params.rotary_cos_ptr, gCos.data(), tRgCos.data(), params.rotary_dim); }
// if (cute::thread(8, 0)) { print_tensor(gCos); }
// if (cute::thread(0, 0)) { print_tensor(tRgCos); }
@@ -653,8 +676,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
+ row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.vnew_row_stride, _1{}));
Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K)
Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K)
typename Kernel_traits::GmemTiledCopyQKVPaged gmem_tiled_copy_KV_new;
auto gmem_thr_copy_KV_new = gmem_tiled_copy_KV_new.get_thread_slice(tidx);
Tensor tKgKnew_ = gmem_thr_copy_KV_new.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K)
Tensor tVgVnew_ = gmem_thr_copy_KV_new.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K)

auto tKgKnew = make_tensor(tKgKnew_.data(), reshape_thread_tile(tKgKnew_.layout()));
auto tVgVnew = make_tensor(tVgVnew_.data(), reshape_thread_tile(tVgVnew_.layout()));

const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN);
auto tKgK_data = tKgK.data();
@@ -694,14 +722,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
} else {
if (n_block > n_block_copy_min) {
const int block_table_idx_cur = n_block * kBlockN / params.page_block_size;
const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size;
const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;
const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
const int table_diff = block_table[block_table_idx_next] - block_table[block_table_idx_cur];
const int offset_diff = block_table_offset_next - block_table_offset_cur;
tVgV.data() = tVgV.data() + table_diff * params.v_batch_stride + offset_diff * params.v_row_stride;
tKgK.data() = tKgK.data() + table_diff * params.k_batch_stride + offset_diff * params.k_row_stride;
tVgV.data() = tVgV.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
block_table, params.v_batch_stride, params.v_row_stride);
tKgK.data() = tKgK.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
block_table, params.k_batch_stride, params.k_row_stride);
}
}
}
@@ -714,9 +738,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// Read Q from gmem to smem, optionally apply rotary embedding.
if (!Append_KV || params.rotary_dim == 0) {
// We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ,
binfo.actual_seqlen_q - m_block * kBlockM);
} else {
typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary;
auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx);
typename Kernel_traits::GmemTiledCopyRotcossinCont gmem_tiled_copy_rotary_cont;
auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx);
const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
// If not causal, all the queries get the same the cos/sin, taken at location seqlen_k_cache.
// We do this by setting the row stride of gCos / gSin to 0.
@@ -751,7 +779,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons

int n_block = n_block_max - 1;
// We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV,
binfo.actual_seqlen_k - n_block * kBlockN);
cute::cp_async_fence();

@@ -790,17 +818,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
if (block_table == nullptr) {
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
} else {
const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size;
const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size;
const int block_table_idx_next = n_block * kBlockN / params.page_block_size;
const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size;
tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride;
tVgV.data() = tVgV.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block + 1, params.page_block_size,
block_table, params.v_batch_stride, params.v_row_stride);
}
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV);
} else {
// Clear the smem tiles to account for predicated off loads
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
);
}
cute::cp_async_fence();
@@ -825,13 +850,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
if (block_table == nullptr) {
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
} else {
const int block_table_idx_cur = n_block * kBlockN / params.page_block_size;
const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size;
const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;
const int block_table_offset_next =(n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride;
tKgK.data() = tKgK.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
block_table, params.k_batch_stride, params.k_row_stride);
}
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV);
// This cp_async_fence needs to be in the if block, otherwise the synchronization
// isn't right and we get race conditions.
cute::cp_async_fence();
@@ -868,13 +890,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
if (block_table == nullptr) {
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
} else {
const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size;
const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size;
const int block_table_idx_next = n_block * kBlockN / params.page_block_size;
const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size;
tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride;
tVgV.data() = tVgV.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block + 1, params.page_block_size,
block_table, params.v_batch_stride, params.v_row_stride);
}
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV);
cute::cp_async_fence();

flash::gemm(
@@ -889,13 +908,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
if (block_table == nullptr) {
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
} else {
const int block_table_idx_cur = n_block * kBlockN / params.page_block_size;
const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size;
const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;
const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride;
tKgK.data() = tKgK.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
block_table, params.k_batch_stride, params.k_row_stride);
}
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV);
// This cp_async_fence needs to be in the if block, otherwise the synchronization
// isn't right and we get race conditions.
cute::cp_async_fence();
20 changes: 20 additions & 0 deletions src/kernels/flash_attn/src/kernel_traits.h
Original file line number Diff line number Diff line change
@@ -127,6 +127,18 @@ struct Flash_fwd_kernel_traits : public Base {
make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read

// from how many rows does each thread have to fetch
static constexpr int kGmemRowsPerThread = kBlockN / (kNThreads / kGmemThreadsPerRow);
// Here we assign a contiguous tile to each thread, rather than a 1x8 row every
// (kNThreads / kGmemThreadsPerRow) rows, ensuring that the elements assigned to each thread
// do not cross a page boundary. This way, each thread need only fetch 1 page index per
// mainloop iteration. R>udimentary testing shows no slowdown.
using GmemTiledCopyQKVPaged = decltype(
make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{},
GmemLayoutAtom{},
Layout<Shape<Int<kGmemRowsPerThread>, _8>, Stride<_8, _1>>{}));

using GmemTiledCopyO = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
GmemLayoutAtom{},
@@ -152,6 +164,14 @@ struct Flash_fwd_kernel_traits : public Base {
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
GmemLayoutAtomRotcossin{},
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per load
using GmemTiledCopyRotcossinPaged = decltype(
make_tiled_copy(Copy_Atom<UniversalCopy<uint64_t>, Element>{},
GmemLayoutAtomRotcossin{},
Layout<Shape<Int<kGmemRowsPerThread>, _4>, Stride<_4, _1>>{})); // Val layout, 4 vals per load
using GmemTiledCopyRotcossinContPaged = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
GmemLayoutAtomRotcossin{},
Layout<Shape<Int<kGmemRowsPerThread>, _8>, Stride<_8, _1>>{})); // Val layout, 8 vals per load
};

////////////////////////////////////////////////////////////////////////////////////////////////////
75 changes: 75 additions & 0 deletions src/kernels/flash_attn/src/utils.h
Original file line number Diff line number Diff line change
@@ -379,4 +379,79 @@ __forceinline__ __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S

////////////////////////////////////////////////////////////////////////////////////////////////////

// resolves initial base offset of a slice of a paged kv copy from gmem.
// assumes that the tensor has already been positioned at the correct head.
template <typename Kernel_traits>
__forceinline__ __device__
int init_thread_kv_page_slice_offset(const int tidx, const int n_block_max, const int page_block_size,
const int* block_table, const int page_stride, const int row_stride) {
constexpr int kGmemThreadsPerRow = Kernel_traits::kGmemThreadsPerRow;
constexpr int kGmemRowsPerThread = Kernel_traits::kGmemRowsPerThread;
constexpr int kGmemElemsPerLoad = Kernel_traits::kGmemElemsPerLoad;
constexpr int kBlockN = Kernel_traits::kBlockN;

const int col_offset = tidx % kGmemThreadsPerRow * kGmemElemsPerLoad;
const int block_row_offset = tidx / kGmemThreadsPerRow * kGmemRowsPerThread;
const int global_row_offset = block_row_offset + (n_block_max - 1) * kBlockN;
const int page_offset = global_row_offset % page_block_size;
const int virtual_page_idx = global_row_offset / page_block_size;

return block_table[virtual_page_idx] * page_stride
+ page_offset * row_stride
+ col_offset;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// advances base address of a slice of a paged copy from gmem
template <typename Kernel_traits>
__forceinline__ __device__
int advance_thread_kv_page_slice_offset(const int tidx, const int n_block, const int page_block_size,
const int* block_table, const int page_stride, const int row_stride) {
constexpr int kGmemThreadsPerRow = Kernel_traits::kGmemThreadsPerRow;
constexpr int kGmemRowsPerThread = Kernel_traits::kGmemRowsPerThread;
constexpr int kBlockN = Kernel_traits::kBlockN;

const int block_row_offset = tidx / kGmemThreadsPerRow * kGmemRowsPerThread;

const int global_row_offset_cur = block_row_offset + n_block * kBlockN;
const int global_row_offset_next = block_row_offset + (n_block - 1) * kBlockN;

const int page_offset_cur = global_row_offset_cur % page_block_size;
const int page_offset_next = global_row_offset_next % page_block_size;

const int virtual_page_idx_cur = global_row_offset_cur / page_block_size;
const int virtual_page_idx_next = global_row_offset_next / page_block_size;

const int table_diff = block_table[virtual_page_idx_next] - block_table[virtual_page_idx_cur];
const int offset_diff = page_offset_next - page_offset_cur;

return table_diff * page_stride + offset_diff * row_stride;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Layout reshape function. Given a layout with modes ((v1, v2), m, k), returns (v1, v2, k),
// where v2 may be a tuple itself, in the case of swizzled smem-backed thread tiles. This ensures
// that paged and non-paged copies result in equivalently shaped, if not necessarily strided, tensors.
template <class Shape, class Stride>
__forceinline__ __device__
auto reshape_thread_tile(Layout<Shape, Stride> l) {
return make_layout(append(get<0>(l.shape()), get<2>(l.shape())),
append(get<0>(l.stride()), get<2>(l.stride())));
}

// reshapes and flattens the thread tile layout. A separate function is needed for the case where
// one of the modes of l is a layout itself and must be flattened, as opposed to keeping it intact
// for the case of swizzled layouts
template <class Shape, class Stride>
__forceinline__ __device__
auto reshape_flatten_thread_tile(Layout<Shape, Stride> l) {
auto mode_0 = filter(flatten(get<0>(l)));
return make_layout(append(mode_0.shape(), get<2>(l.shape())),
append(mode_0.stride(), get<2>(l.stride())));
}

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace flash
2 changes: 1 addition & 1 deletion src/layers/attention_test.cpp
Original file line number Diff line number Diff line change
@@ -322,7 +322,7 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Values(torch::kCUDA),
::testing::Values(torch::kHalf, torch::kBFloat16),
::testing::Values(1, 10), // batch_size
::testing::Values(256), // block_size
::testing::Values(16, 80, 256), // block_size
::testing::Values(1, 10), // q_max_seq_len
::testing::Values(100, 1000), // k_max_seq_len
::testing::Values(6), // n_heads

0 comments on commit 23362e2

Please sign in to comment.