diff --git a/candle-flash-attn/build.rs b/candle-flash-attn/build.rs index 773c56388..2a3b7eb17 100644 --- a/candle-flash-attn/build.rs +++ b/candle-flash-attn/build.rs @@ -6,7 +6,7 @@ use rayon::prelude::*; use std::path::PathBuf; use std::str::FromStr; -const KERNEL_FILES: [&str; 9] = [ +const KERNEL_FILES: [&str; 17] = [ "flash_api.cu", "flash_fwd_hdim128_fp16_sm80.cu", "flash_fwd_hdim160_fp16_sm80.cu", @@ -16,14 +16,14 @@ const KERNEL_FILES: [&str; 9] = [ "flash_fwd_hdim32_fp16_sm80.cu", "flash_fwd_hdim64_fp16_sm80.cu", "flash_fwd_hdim96_fp16_sm80.cu", - // "flash_fwd_hdim128_bf16_sm80.cu", - // "flash_fwd_hdim160_bf16_sm80.cu", - // "flash_fwd_hdim192_bf16_sm80.cu", - // "flash_fwd_hdim224_bf16_sm80.cu", - // "flash_fwd_hdim256_bf16_sm80.cu", - // "flash_fwd_hdim32_bf16_sm80.cu", - // "flash_fwd_hdim64_bf16_sm80.cu", - // "flash_fwd_hdim96_bf16_sm80.cu", + "flash_fwd_hdim128_bf16_sm80.cu", + "flash_fwd_hdim160_bf16_sm80.cu", + "flash_fwd_hdim192_bf16_sm80.cu", + "flash_fwd_hdim224_bf16_sm80.cu", + "flash_fwd_hdim256_bf16_sm80.cu", + "flash_fwd_hdim32_bf16_sm80.cu", + "flash_fwd_hdim64_bf16_sm80.cu", + "flash_fwd_hdim96_bf16_sm80.cu", ]; fn main() -> Result<()> { diff --git a/candle-flash-attn/kernels/flash_api.cu b/candle-flash-attn/kernels/flash_api.cu index d928bcb60..72991257a 100644 --- a/candle-flash-attn/kernels/flash_api.cu +++ b/candle-flash-attn/kernels/flash_api.cu @@ -1,20 +1,19 @@ #include "flash_fwd_launch_template.h" -// TODO: Switch back to handling bf16. 
-void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) { - FWD_HEADDIM_SWITCH(params.d, [&] { - run_mha_fwd_(params, stream); - }); -} - // void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) { -// FP16_SWITCH(!params.is_bf16, [&] { -// FWD_HEADDIM_SWITCH(params.d, [&] { -// run_mha_fwd_(params, stream); -// }); +// FWD_HEADDIM_SWITCH(params.d, [&] { +// run_mha_fwd_(params, stream); // }); // } +void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) { + FP16_SWITCH(!params.is_bf16, [&] { + FWD_HEADDIM_SWITCH(params.d, [&] { + run_mha_fwd_(params, stream); + }); + }); +} + extern "C" void run_mha( void *q_ptr, void *k_ptr, @@ -52,7 +51,8 @@ extern "C" void run_mha( uint32_t seqlen_q_rounded, uint32_t seqlen_k_rounded, - int is_causal + int is_causal, + int is_bf16 ) { Flash_fwd_params params; // Reset the parameters @@ -102,7 +102,7 @@ extern "C" void run_mha( params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0)); params.rp_dropout = 1.f / params.p_dropout; params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax; - params.is_bf16 = 0; + params.is_bf16 = is_bf16; params.cu_seqlens_q = cu_seqlens_q_ptr; params.cu_seqlens_k = cu_seqlens_k_ptr; params.p_ptr = nullptr; // used for `return_softmax`. 
diff --git a/candle-flash-attn/src/ffi.rs b/candle-flash-attn/src/ffi.rs index ae61c405b..90f34e434 100644 --- a/candle-flash-attn/src/ffi.rs +++ b/candle-flash-attn/src/ffi.rs @@ -38,6 +38,7 @@ extern "C" { seqlen_k_rounded: u32, is_causal: c_int, + is_bf16: c_int, ); } diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs index 3c5fd4550..cdb4b083f 100644 --- a/candle-flash-attn/src/lib.rs +++ b/candle-flash-attn/src/lib.rs @@ -146,6 +146,7 @@ impl candle::CustomOp3 for FlashAttn { /* seqlen_q_rounded */ seqlen_q_rounded as u32, /* seqlen_k_rounded */ seqlen_k_rounded as u32, /* is_causal */ causal, + /* is_bf16 */ 0, ) } @@ -354,6 +355,7 @@ impl candle::CustomOp3 for FlashAttnVarLen { /* seqlen_q_rounded */ seqlen_q_rounded as u32, /* seqlen_k_rounded */ seqlen_k_rounded as u32, /* is_causal */ causal, + /* is_bf16 */ 0, ) }