From 295715d8505774062211f80133a12ad2f95a3f86 Mon Sep 17 00:00:00 2001
From: grimoire <streetyao@live.com>
Date: Fri, 24 Nov 2023 13:18:45 +0800
Subject: [PATCH] fix

---
 .../llama/flash_attention2/flash_fwd_launch_template.h | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/src/turbomind/models/llama/flash_attention2/flash_fwd_launch_template.h b/src/turbomind/models/llama/flash_attention2/flash_fwd_launch_template.h
index 4a94da08b2..9c5acbcd6a 100644
--- a/src/turbomind/models/llama/flash_attention2/flash_fwd_launch_template.h
+++ b/src/turbomind/models/llama/flash_attention2/flash_fwd_launch_template.h
@@ -14,7 +14,13 @@
 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax>
 __global__ void flash_fwd_kernel(Flash_fwd_params params)
 {
+
+#if __CUDA_ARCH__ >= 800
     flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K, Return_softmax>(params);
+#else
+// TODO: support flash attention2 on sm<80
+    assert(false);
+#endif
 }
 
 template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
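
Below is a minimal standalone sketch (not part of the patch) of the `__CUDA_ARCH__` guard pattern the hunk uses: the kernel body is only compiled in device passes targeting sm_80 or newer, while older architectures fall through to a device-side assert. The kernel name, `printf` body, and host launch code are illustrative assumptions, not code from the repository.

```cuda
// guard_sketch.cu -- illustrative only; mirrors the #if __CUDA_ARCH__ >= 800 guard.
#include <cassert>
#include <cstdio>

__global__ void guarded_kernel()
{
#if __CUDA_ARCH__ >= 800
    // Compiled only for sm_80+ device passes; the real kernel work would go here.
    printf("running on sm_%d\n", __CUDA_ARCH__ / 10);
#else
    // Any device pass below sm_80 gets this branch and traps at runtime.
    assert(false);
#endif
}

int main()
{
    guarded_kernel<<<1, 1>>>();
    cudaDeviceSynchronize();
    return 0;
}
```

Because `__CUDA_ARCH__` is evaluated per compilation pass, building with `-gencode arch=compute_75,code=sm_75` still links, but launching the kernel on such a device hits the `assert(false)` branch, which matches the TODO left in the patch for pre-sm_80 support.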