Support flash-attention custom call (#8)
ApsarasX authored May 8, 2024
1 parent b43d42d commit 5628d13
Showing 21 changed files with 4,056 additions and 1 deletion.
Empty file added third_party/flash_attn/BUILD
95 changes: 95 additions & 0 deletions third_party/flash_attn/flash_attn.BUILD
@@ -0,0 +1,95 @@
load("@local_config_cuda//cuda:build_defs.bzl", "cuda_library", "if_cuda_is_configured")

package(default_visibility = ["//visibility:public"])

licenses(["notice"])

cuda_library(
    name = "flash_attn",
    hdrs = if_cuda_is_configured([
        "csrc/flash_attn/src/alibi.h",
        "csrc/flash_attn/src/block_info.h",
        "csrc/flash_attn/src/dropout.h",
        "csrc/flash_attn/src/flash_bwd_kernel.h",
        "csrc/flash_attn/src/flash_bwd_launch_template.h",
        "csrc/flash_attn/src/flash_bwd_preprocess_kernel.h",
        "csrc/flash_attn/src/flash_fwd_kernel.h",
        "csrc/flash_attn/src/flash_fwd_launch_template.h",
        "csrc/flash_attn/src/flash_utils.h",
        "csrc/flash_attn/src/flash.h",
        "csrc/flash_attn/src/kernel_traits.h",
        "csrc/flash_attn/src/mask.h",
        "csrc/flash_attn/src/philox.cuh",
        "csrc/flash_attn/src/rotary.h",
        "csrc/flash_attn/src/softmax.h",
        "csrc/flash_attn/src/static_switch.h",
        "csrc/flash_attn/src/utils.h",
    ]),
    # One translation unit per kernel variant (forward, backward, split-KV
    # forward) x head dimension (32..256) x dtype (fp16, bf16), mirroring
    # upstream's sharding of the template instantiations across files.
    srcs = if_cuda_is_configured([
        "csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu.cc",
        "csrc/flash_attn/src/utils.cc",
    ]),
    # nvcc flags matching upstream's build:
    # https://github.com/Dao-AILab/flash-attention/blob/v2.5.7/setup.py#L193-L199
    copts = [
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "-U__CUDA_NO_HALF2_OPERATORS__",
        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
        "--use_fast_math",
    ],
    # Consumers include these headers as "flash_attn/<name>.h".
    include_prefix = "flash_attn",
    strip_include_prefix = "csrc/flash_attn/src",
    deps = if_cuda_is_configured([
        "@cutlass_for_flash_attn//:cutlass",
        "@local_config_cuda//cuda:cuda_headers",
    ]),
)
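Since flash_attn.BUILD is an overlay build file rather than one checked into the flash-attention tree, the workspace presumably attaches it to the upstream sources with a repository rule. Below is a minimal sketch of that wiring and of a consumer target; the repository name, strip_prefix, checksum, and consumer target names are illustrative assumptions, not taken from this commit. Note also that the upstream v2.5.7 archive ships .cu files, so the .cu.cc names in srcs suggest a rename step (e.g. repository patch commands), elided here.

# WORKSPACE (sketch, hypothetical wiring): overlay the BUILD file above
# onto pinned upstream flash-attention sources.
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

http_archive(
    name = "flash_attn",  # assumed repository name
    build_file = "//third_party/flash_attn:flash_attn.BUILD",
    strip_prefix = "flash-attention-2.5.7",  # v2.5.7 matches the setup.py link above
    urls = ["https://github.com/Dao-AILab/flash-attention/archive/refs/tags/v2.5.7.tar.gz"],
    # sha256 = "...",  # pin the archive checksum in a real setup
)

# Consumer BUILD (sketch): depend on the library and include its headers
# under the "flash_attn/" prefix established by include_prefix.
cc_library(
    name = "flash_attn_custom_call",  # hypothetical target name
    srcs = ["flash_attn_custom_call.cc"],  # would #include "flash_attn/flash.h"
    deps = ["@flash_attn//:flash_attn"],
)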
