* first
* fix causal mask
* disable flash attention2 on sm70
* fix 2
* update readme
* clang-format
* disable ft2 on windows
* fix lint
* fix build
* fix build
* fix long kv seq
* fix lint
* sync copy output

---------

Co-authored-by: grimoire <[email protected]>
Co-authored-by: irexyc <[email protected]>
1 parent d4d609b · commit 452822a
Showing 24 changed files with 1,923 additions and 134 deletions.
15 changes: 15 additions & 0 deletions
src/turbomind/models/llama/flash_attention2/CMakeLists.txt
@@ -0,0 +1,15 @@
cmake_minimum_required(VERSION 3.8)
project(flash_attention2)

add_library(${PROJECT_NAME} STATIC
    flash_api.cpp
    flash_fwd_hdim32_fp16_sm80.cu
    flash_fwd_hdim64_fp16_sm80.cu
    flash_fwd_hdim128_fp16_sm80.cu
    flash_fwd_hdim256_fp16_sm80.cu
)
target_include_directories(${PROJECT_NAME} PRIVATE ${CUTLASS_DIR}/include)
target_link_libraries(${PROJECT_NAME} PRIVATE nvidia::cutlass::cutlass)
set_property(TARGET ${PROJECT_NAME} PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET ${PROJECT_NAME} PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
7 changes: 7 additions & 0 deletions
src/turbomind/models/llama/flash_attention2/README.md
@@ -0,0 +1,7 @@
# Flash Attention 2

This is a Flash Attention 2 implementation modified from https://github.com/Dao-AILab/flash-attention

- dropout removed
- backward pass removed
- built against CUTLASS 3.1.0
45 changes: 45 additions & 0 deletions
src/turbomind/models/llama/flash_attention2/block_info.h
@@ -0,0 +1,45 @@
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#pragma once

namespace flash {

////////////////////////////////////////////////////////////////////////////////////////////////////

template<bool Varlen = true>
struct BlockInfo {

    template<typename Params>
    __device__ BlockInfo(const Params& params, const int bidb):
        sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]),
        sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ? -1 : params.cu_seqlens_k[bidb]),
        actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q :
                            params.cu_seqlens_q[bidb + 1] - sum_s_q),
        actual_seqlen_k(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k :
                            params.cu_seqlens_k[bidb + 1] - sum_s_k)
    {
    }

    template<typename index_t>
    inline __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const
    {
        return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
    }

    template<typename index_t>
    inline __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const
    {
        return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
    }

    const int sum_s_q;
    const int sum_s_k;
    const int actual_seqlen_q;
    const int actual_seqlen_k;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace flash
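`BlockInfo` is the variable-length ("ragged batch") bookkeeping: when the `cu_seqlens` arrays are set, all sequences are packed back-to-back along the row axis and `cu_seqlens[b]` / `cu_seqlens[b + 1]` bracket sequence `b`; otherwise each batch entry simply starts at `bidb * batch_stride`. A minimal host-side sketch of the same arithmetic (illustrative standalone code, not part of the commit):

```cpp
#include <cstdint>
#include <cstdio>

// Mirror of BlockInfo's offset logic on the host, for three packed sequences
// of lengths 3, 5 and 2. cu_seqlens has length b+1 and holds running offsets.
int main()
{
    const int      cu_seqlens[] = {0, 3, 8, 10};  // b = 3 sequences
    const uint32_t row_stride   = 128;            // elements per token row (heads * head_dim)

    for (int bidb = 0; bidb < 3; ++bidb) {
        const int      sum_s         = cu_seqlens[bidb];
        const int      actual_seqlen = cu_seqlens[bidb + 1] - sum_s;
        const uint32_t offset        = uint32_t(sum_s) * row_stride;  // same as BlockInfo::q_offset
        std::printf("seq %d: start row %d, len %d, element offset %u\n",
                    bidb, sum_s, actual_seqlen, (unsigned)offset);
    }
    return 0;
}
```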
88 changes: 88 additions & 0 deletions
src/turbomind/models/llama/flash_attention2/flash.h
@@ -0,0 +1,88 @@
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/
// modified from: https://github.com/Dao-AILab/flash-attention

#pragma once

#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>

constexpr int TOTAL_DIM = 0;
constexpr int H_DIM     = 1;
constexpr int D_DIM     = 2;

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Qkv_params {
    using index_t = uint32_t;
    // The QKV matrices.
    void* __restrict__ q_ptr;
    void* __restrict__ k_ptr;
    void* __restrict__ v_ptr;

    // Batched pointer inputs.
    void** __restrict__ k_batched_ptr = nullptr;
    void** __restrict__ v_batched_ptr = nullptr;
    int k_batched_offset = 0;
    int v_batched_offset = 0;

    // The strides between batches, rows and heads of the Q, K and V matrices.
    index_t q_batch_stride;
    index_t k_batch_stride;
    index_t v_batch_stride;
    index_t q_row_stride;
    index_t k_row_stride;
    index_t v_row_stride;
    index_t q_head_stride;
    index_t k_head_stride;
    index_t v_head_stride;

    // The number of heads.
    int h, h_k;
    // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be
    // different from nheads (query).
    int h_h_k_ratio;  // precomputed h / h_k
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Flash_fwd_params: public Qkv_params {

    // The O matrix (output).
    void* __restrict__ o_ptr;

    // The stride between rows of O.
    index_t o_batch_stride;
    index_t o_row_stride;
    index_t o_head_stride;

    // The pointer to the P matrix.
    void* __restrict__ p_ptr;

    // The dimensions.
    int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded;

    // The scaling factors for the kernel.
    float scale_softmax;
    float scale_softmax_log2;

    // Arrays of length b+1 holding the starting offset of each sequence.
    int* __restrict__ cu_seqlens_q;
    int* __restrict__ cu_seqlens_k;

    void* __restrict__ blockmask;

    bool is_bf16;
    bool is_causal;

    // Enable the ragged (cu_seqlens-indexed) layout for the input/output.
    bool q_enable_seqlen;
    bool o_enable_seqlen;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T, int Headdim>
void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream);
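For a contiguous `[batch, seqlen, heads, head_dim]` tensor the stride fields follow directly from the shape, and for MQA/GQA `h_h_k_ratio` is the number of query heads sharing one KV head. A hedged sketch of how a caller could populate the Q-side fields (hypothetical helper; the commit's real wiring is in `flash_api.cpp` below):

```cpp
#include "flash.h"  // assumes this header is on the include path

// Populate the Q-side fields of Flash_fwd_params for a contiguous
// [b, seqlen, h, d] tensor. Strides are in elements, matching index_t usage.
static void set_q_strides_bshd(Flash_fwd_params& p, int b, int seqlen, int h, int h_k, int d)
{
    p.q_head_stride  = d;               // next head within the same token
    p.q_row_stride   = h * d;           // next token within the same sequence
    p.q_batch_stride = seqlen * h * d;  // next sequence within the batch

    p.b           = b;
    p.seqlen_q    = seqlen;
    p.h           = h;
    p.h_k         = h_k;      // number of KV heads (MQA: 1, plain MHA: h)
    p.h_h_k_ratio = h / h_k;  // query heads per KV head
    p.d           = d;
}
```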
164 changes: 164 additions & 0 deletions
src/turbomind/models/llama/flash_attention2/flash_api.cpp
@@ -0,0 +1,164 @@
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/
// modified from: https://github.com/Dao-AILab/flash-attention

#include "flash.h"
#include "src/turbomind/models/llama/llama_kernels.h"
#include "static_switch.h"
#include <cuda_runtime.h>
#include <cutlass/numeric_types.h>
#include <math.h>

void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream)
{
    FP16_SWITCH(true,
                [&] { FWD_HEADDIM_SWITCH(params.d, [&] { run_mha_fwd_<elem_type, kHeadDim>(params, stream); }); });
}

namespace turbomind {

static constexpr int FMHA_VERSION = 2;

template<typename T>
class FlashAttentionOpImpl<T, FMHA_VERSION> {

public:
    using AttentionLayout = BaseAttentionLayout<T>;
    using Params          = BaseAttentionParams<T>;

public:
    FlashAttentionOpImpl(int batch_size, int head_num, int key_len, int seq_len, int size_per_head);
    ~FlashAttentionOpImpl();

    int get_workspace_size() const;

    void operator()(Params& params, cudaStream_t st) const;

private:
    class impl;
    std::unique_ptr<impl> pimpl;
};

template<typename T>
class FlashAttentionOpImpl<T, FMHA_VERSION>::impl {

private:
    using scalar_t =
        typename std::conditional_t<std::is_same<half, typename std::decay<T>::type>::value, cutlass::half_t, T>;
    using Params = typename FlashAttentionOpImpl<T, FMHA_VERSION>::Params;

    int batch_size_;
    int head_num_;
    int key_len_;
    int seq_len_;
    int size_per_head_;

public:
    impl(int batch_size, int head_num, int key_len, int seq_len, int size_per_head):
        batch_size_(batch_size),
        head_num_(head_num),
        key_len_(key_len),
        seq_len_(seq_len),
        size_per_head_(size_per_head)
    {
    }

    ~impl() {}

    int get_workspace_size() const
    {
        return 0;
    }

    void operator()(Params& params, cudaStream_t st) const
    {
        const float qk_scale = static_cast<float>(1.f / sqrtf(size_per_head_ * 1.f));

        Flash_fwd_params fwd_params;
        memset(&fwd_params, 0, sizeof(fwd_params));

        fwd_params.q_ptr = reinterpret_cast<void*>(params.query);
        fwd_params.k_ptr = reinterpret_cast<void*>(params.key);
        fwd_params.v_ptr = reinterpret_cast<void*>(params.val);

        fwd_params.k_batched_ptr    = reinterpret_cast<void**>(params.layout_k.batch_seqs);
        fwd_params.v_batched_ptr    = reinterpret_cast<void**>(params.layout_v.batch_seqs);
        fwd_params.k_batched_offset = params.layout_k.batch_seqs_offset;
        fwd_params.v_batched_offset = params.layout_v.batch_seqs_offset;

        fwd_params.q_batch_stride = params.layout_q.stride_batch;
        fwd_params.k_batch_stride = params.layout_k.stride_batch;
        fwd_params.v_batch_stride = params.layout_v.stride_batch;
        fwd_params.q_row_stride   = params.layout_q.stride_seq;
        fwd_params.k_row_stride   = params.layout_k.stride_seq;
        fwd_params.v_row_stride   = params.layout_v.stride_seq;
        fwd_params.q_head_stride  = params.layout_q.stride_head;
        fwd_params.k_head_stride  = params.layout_k.stride_head;
        fwd_params.v_head_stride  = params.layout_v.stride_head;

        fwd_params.h           = head_num_;
        fwd_params.h_k         = head_num_ / params.group_size;
        fwd_params.h_h_k_ratio = params.group_size;

        fwd_params.o_ptr = reinterpret_cast<void*>(params.attn_out);

        fwd_params.o_batch_stride = params.layout_o.stride_batch;
        fwd_params.o_row_stride   = params.layout_o.stride_seq;
        fwd_params.o_head_stride  = params.layout_o.stride_head;

        fwd_params.p_ptr = nullptr;

        fwd_params.b                = batch_size_;
        fwd_params.seqlen_q         = seq_len_;
        fwd_params.seqlen_k         = key_len_;
        fwd_params.d                = size_per_head_;
        fwd_params.seqlen_q_rounded = 0;
        fwd_params.seqlen_k_rounded = 0;

        fwd_params.scale_softmax      = qk_scale;
        fwd_params.scale_softmax_log2 = qk_scale * M_LOG2E;

        fwd_params.cu_seqlens_q = params.cu_seqlens_q;
        fwd_params.cu_seqlens_k = params.cu_seqlens_k;

        fwd_params.blockmask = reinterpret_cast<void*>(params.mask);

        fwd_params.is_bf16   = false;
        fwd_params.is_causal = true;

        fwd_params.q_enable_seqlen = params.layout_q.use_seqlens;
        fwd_params.o_enable_seqlen = params.layout_o.use_seqlens;

        run_mha_fwd(fwd_params, st);
    }
};

template<typename T>
FlashAttentionOpImpl<T, FMHA_VERSION>::FlashAttentionOpImpl(
    int batch_size, int head_num, int key_len, int seq_len, int size_per_head):
    pimpl{std::make_unique<FlashAttentionOpImpl<T, FMHA_VERSION>::impl>(
        batch_size, head_num, key_len, seq_len, size_per_head)}
{
}

template<typename T>
FlashAttentionOpImpl<T, FMHA_VERSION>::~FlashAttentionOpImpl()
{
}

template<typename T>
int FlashAttentionOpImpl<T, FMHA_VERSION>::get_workspace_size() const
{
    return pimpl->get_workspace_size();
}

template<typename T>
void FlashAttentionOpImpl<T, FMHA_VERSION>::operator()(Params& params, cudaStream_t st) const
{
    pimpl->operator()(params, st);
}

template class FlashAttentionOpImpl<float, FMHA_VERSION>;
template class FlashAttentionOpImpl<half, FMHA_VERSION>;

}  // namespace turbomind
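One detail in `operator()` worth noting: besides `scale_softmax`, the kernel receives `scale_softmax_log2 = qk_scale * M_LOG2E`. Flash attention kernels typically evaluate the softmax in base 2, exploiting exp(x) = 2^(x · log2 e) so the inner loop can use the fast `exp2f` instruction. A self-contained check of that identity (illustrative only):

```cpp
#include <cmath>
#include <cstdio>

int main()
{
    const float kLog2e             = 1.4426950408889634f;  // log2(e), i.e. M_LOG2E
    const float scale_softmax      = 0.125f;               // e.g. 1 / sqrt(64)
    const float scale_softmax_log2 = scale_softmax * kLog2e;

    const float qk = 3.7f;  // a raw Q.K^T score
    // expf on the plainly scaled score and exp2f on the log2-scaled score agree.
    std::printf("%f vs %f\n", std::exp(qk * scale_softmax), std::exp2(qk * scale_softmax_log2));
    return 0;
}
```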
11 changes: 11 additions & 0 deletions
src/turbomind/models/llama/flash_attention2/flash_fwd_hdim128_fp16_sm80.cu
@@ -0,0 +1,11 @@
// Copyright (c) 2023, Tri Dao.

// Splitting the different head dimensions to different files to speed up compilation.

#include "flash_fwd_launch_template.h"

template<>
void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params& params, cudaStream_t stream)
{
    run_mha_fwd_hdim128<cutlass::half_t>(params, stream);
}
11 changes: 11 additions & 0 deletions
src/turbomind/models/llama/flash_attention2/flash_fwd_hdim256_fp16_sm80.cu
@@ -0,0 +1,11 @@
// Copyright (c) 2023, Tri Dao.

// Splitting the different head dimensions to different files to speed up compilation.

#include "flash_fwd_launch_template.h"

template<>
void run_mha_fwd_<cutlass::half_t, 256>(Flash_fwd_params& params, cudaStream_t stream)
{
    run_mha_fwd_hdim256<cutlass::half_t>(params, stream);
}
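The per-head-dim `.cu` files each provide one explicit specialization of `run_mha_fwd_`, so the heavily templated kernels compile in parallel. `static_switch.h` is not shown in this excerpt; a plausible shape for the `FWD_HEADDIM_SWITCH` dispatcher used in `flash_api.cpp` (a sketch under that assumption, not the file's actual contents) is:

```cpp
// Hypothetical reconstruction of a FWD_HEADDIM_SWITCH-style dispatcher:
// map a runtime head_dim to one of the explicitly instantiated kernels
// (32/64/128/256), each living in its own .cu to parallelize compilation.
#define FWD_HEADDIM_SWITCH(HEADDIM, ...)                                       \
    [&] {                                                                      \
        if (HEADDIM <= 32) {                                                   \
            constexpr static int kHeadDim = 32;                                \
            return __VA_ARGS__();                                              \
        }                                                                      \
        else if (HEADDIM <= 64) {                                              \
            constexpr static int kHeadDim = 64;                                \
            return __VA_ARGS__();                                              \
        }                                                                      \
        else if (HEADDIM <= 128) {                                             \
            constexpr static int kHeadDim = 128;                               \
            return __VA_ARGS__();                                              \
        }                                                                      \
        else {                                                                 \
            constexpr static int kHeadDim = 256;                               \
            return __VA_ARGS__();                                              \
        }                                                                      \
    }()
```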