From b47b419e6f7b4ae1c4067b3b23c17a4cb08d48ab Mon Sep 17 00:00:00 2001 From: skrider Date: Tue, 26 Mar 2024 01:33:51 +0000 Subject: [PATCH] resolve page offsets absolutely not relatively --- csrc/flash_attn/src/flash_fwd_kernel.h | 16 ++++++------- csrc/flash_attn/src/utils.h | 32 ++------------------------ 2 files changed, 10 insertions(+), 38 deletions(-) diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index 9ae7dd27e..34922d519 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -609,9 +609,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons Tensor tVsV = make_tensor(tVsV_.data(), reshape_thread_tile(tVsV_.layout())); if (block_table != nullptr) { - tKgK.data() = gK.data() + flash::init_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block_max, params.page_block_size, + tKgK.data() = gK.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block_max, params.page_block_size, block_table, params.k_batch_stride, params.k_row_stride); - tVgV.data() = gV.data() + flash::init_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block_max, params.page_block_size, + tVgV.data() = gV.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block_max, params.page_block_size, block_table, params.v_batch_stride, params.v_row_stride); } @@ -769,9 +769,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); } else { if (n_block > n_block_copy_min) { - tVgV.data() = tVgV.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size, + tVgV.data() = gV.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size, block_table, params.v_batch_stride, params.v_row_stride); - tKgK.data() = tKgK.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size, + tKgK.data() = gK.data() +
flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size, block_table, params.k_batch_stride, params.k_row_stride); } } @@ -865,7 +865,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons if (block_table == nullptr) { tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); } else { - tVgV.data() = tVgV.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block + 1, params.page_block_size, + tVgV.data() = gV.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block + 1, params.page_block_size, block_table, params.v_batch_stride, params.v_row_stride); } flash::copy(gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV); @@ -897,7 +897,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons if (block_table == nullptr) { tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); } else { - tKgK.data() = tKgK.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size, + tKgK.data() = gK.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size, block_table, params.k_batch_stride, params.k_row_stride); } flash::copy(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV); @@ -937,7 +937,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons if (block_table == nullptr) { tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); } else { - tVgV.data() = tVgV.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block + 1, params.page_block_size, + tVgV.data() = gV.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block + 1, params.page_block_size, block_table, params.v_batch_stride, params.v_row_stride); } flash::copy(gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV); @@ -955,7 +955,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons if (block_table == nullptr) { tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); } else
{ - tKgK.data() = tKgK.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size, + tKgK.data() = gK.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size, block_table, params.k_batch_stride, params.k_row_stride); } flash::copy(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV); diff --git a/csrc/flash_attn/src/utils.h b/csrc/flash_attn/src/utils.h index 46b2ea039..4f999a6b7 100644 --- a/csrc/flash_attn/src/utils.h +++ b/csrc/flash_attn/src/utils.h @@ -292,11 +292,11 @@ void cp_async_wait() { //////////////////////////////////////////////////////////////////////////////////////////////////// -// resolves initial base offset of a slice of a paged kv copy from gmem. +// resolves offset of a slice of a paged kv copy from gmem. // assumes that the tensor has already been positioned at the correct head. template <typename Kernel_traits> __forceinline__ __device__ -int init_thread_kv_page_slice_offset(const int tidx, const int n_block_max, const int page_block_size, +int resolve_thread_kv_page_slice_offset(const int tidx, const int n_block_max, const int page_block_size, const int* block_table, const int page_stride, const int row_stride) { constexpr int kGmemThreadsPerRow = Kernel_traits::kGmemThreadsPerRow; constexpr int kGmemRowsPerThread = Kernel_traits::kGmemRowsPerThread; @@ -313,34 +313,6 @@ int init_thread_kv_page_slice_offset(const int tidx, const int n_block_max, cons + page_offset * row_stride + col_offset; } - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// advances base address of a slice of a paged copy from gmem -template <typename Kernel_traits> -__forceinline__ __device__ -int advance_thread_kv_page_slice_offset(const int tidx, const int n_block, const int page_block_size, - const int* block_table, const int page_stride, const int row_stride) { - constexpr int kGmemThreadsPerRow = Kernel_traits::kGmemThreadsPerRow; - constexpr int kGmemRowsPerThread =
Kernel_traits::kGmemRowsPerThread; - constexpr int kBlockN = Kernel_traits::kBlockN; - - const int block_row_offset = tidx / kGmemThreadsPerRow * kGmemRowsPerThread; - - const int global_row_offset_cur = block_row_offset + n_block * kBlockN; - const int global_row_offset_next = block_row_offset + (n_block - 1) * kBlockN; - - const int page_offset_cur = global_row_offset_cur % page_block_size; - const int page_offset_next = global_row_offset_next % page_block_size; - - const int virtual_page_idx_cur = global_row_offset_cur / page_block_size; - const int virtual_page_idx_next = global_row_offset_next / page_block_size; - - const int table_diff = block_table[virtual_page_idx_next] - block_table[virtual_page_idx_cur]; - const int offset_diff = page_offset_next - page_offset_cur; - - return table_diff * page_stride + offset_diff * row_stride; -} ////////////////////////////////////////////////////////////////////////////////////////////////////