Merge pull request #7 from intelligent-machine-learning/wengang/flash-attention-rebase

Support flash-attention custom call
zjjott authored Apr 16, 2024
2 parents 9225332 + 1369b27 commit 4370014
Showing 25 changed files with 4,546 additions and 4 deletions.
37 changes: 37 additions & 0 deletions third_party/cutlass.BUILD
@@ -0,0 +1,37 @@
# Description:
#   CUTLASS is a collection of CUDA C++ template abstractions for implementing
#   high-performance matrix-matrix multiplication (GEMM) and related computations
#   at all levels and scales within CUDA.

package(
    default_visibility = ["//visibility:public"],
)

licenses(["notice"])  # MIT

exports_files(["LICENSE.txt"])

filegroup(
    name = "cutlass_header_files",
    srcs = glob([
        "include/**",
    ]),
)

filegroup(
    name = "cutlass_util_header_files",
    srcs = glob([
        "tools/util/include/**",
    ]),
)

cc_library(
    name = "cutlass",
    hdrs = [
        ":cutlass_header_files",
        ":cutlass_util_header_files",
    ],
    includes = [
        "include",
        "tools/util/include",
    ],
)
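For context, cutlass.BUILD is an overlay build file: it only takes effect once the WORKSPACE attaches it to an external CUTLASS checkout. A minimal sketch of that wiring, assuming a plain http_archive fetch; the repository name @cutlass_for_flash_attn matches the dep used by flash_attn.BUILD below, but the CUTLASS tag, URL, and digest are illustrative, not values from this commit:

    # WORKSPACE (sketch): attach cutlass.BUILD to an external CUTLASS checkout.
    load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

    http_archive(
        name = "cutlass_for_flash_attn",
        build_file = "//third_party:cutlass.BUILD",
        strip_prefix = "cutlass-3.4.0",  # placeholder version, not from this PR
        urls = ["https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.4.0.tar.gz"],
        # sha256 = "...",  # pin the real digest when fetching
    )

With that in place, the :cutlass target puts the include/ and tools/util/include/ trees on the include path of any dependent.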
Empty file added third_party/flash_attn/BUILD (an empty BUILD file marks the directory as a Bazel package, so labels such as //third_party/flash_attn:flash_attn.BUILD can be referenced).
95 changes: 95 additions & 0 deletions third_party/flash_attn/flash_attn.BUILD
@@ -0,0 +1,95 @@
load("@local_config_cuda//cuda:build_defs.bzl", "cuda_library", "if_cuda_is_configured")

package(default_visibility = ["//visibility:public"])

licenses(["notice"])

cuda_library(
    name = "flash_attn",
    hdrs = if_cuda_is_configured([
        "csrc/flash_attn/src/alibi.h",
        "csrc/flash_attn/src/block_info.h",
        "csrc/flash_attn/src/dropout.h",
        "csrc/flash_attn/src/flash_bwd_kernel.h",
        "csrc/flash_attn/src/flash_bwd_launch_template.h",
        "csrc/flash_attn/src/flash_bwd_preprocess_kernel.h",
        "csrc/flash_attn/src/flash_fwd_kernel.h",
        "csrc/flash_attn/src/flash_fwd_launch_template.h",
        "csrc/flash_attn/src/flash_utils.h",
        "csrc/flash_attn/src/flash.h",
        "csrc/flash_attn/src/kernel_traits.h",
        "csrc/flash_attn/src/mask.h",
        "csrc/flash_attn/src/philox.cuh",
        "csrc/flash_attn/src/rotary.h",
        "csrc/flash_attn/src/softmax.h",
        "csrc/flash_attn/src/static_switch.h",
        "csrc/flash_attn/src/utils.h",
    ]),
    srcs = if_cuda_is_configured([
        "csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu.cc",
        "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu.cc",
        "csrc/flash_attn/src/utils.cc",
    ]),
    # https://github.com/Dao-AILab/flash-attention/blob/v2.5.7/setup.py#L193-L199
    copts = [
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "-U__CUDA_NO_HALF2_OPERATORS__",
        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
        "--use_fast_math",
    ],
    include_prefix = "flash_attn",
    strip_include_prefix = "csrc/flash_attn/src",
    deps = if_cuda_is_configured([
        "@cutlass_for_flash_attn//:cutlass",
        "@local_config_cuda//cuda:cuda_headers",
    ]),
)
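Two details of this target are easy to miss. First, the sources carry a .cu.cc suffix rather than upstream flash-attention's .cu, which suggests the files are renamed when the external archive is fetched so that cuda_library accepts them. Second, include_prefix = "flash_attn" combined with strip_include_prefix means dependents include the headers as flash_attn/flash.h and so on. A hedged sketch of both, assuming the repository is fetched via http_archive; the repository name, the patch_cmds rename, and the consumer target are illustrative, not shown in this PR (only the v2.5.7 pin is suggested by the copts comment above):

    # WORKSPACE (sketch): fetch flash-attention and rename .cu -> .cu.cc.
    http_archive(
        name = "flash_attn",  # repository name assumed, not shown in this PR
        build_file = "//third_party/flash_attn:flash_attn.BUILD",
        strip_prefix = "flash-attention-2.5.7",
        urls = ["https://github.com/Dao-AILab/flash-attention/archive/refs/tags/v2.5.7.tar.gz"],
        patch_cmds = [
            # One possible way to produce the .cu.cc names listed in srcs above.
            "for f in csrc/flash_attn/src/*.cu; do mv \"$f\" \"${f%.cu}.cu.cc\"; done",
        ],
    )

    # BUILD (sketch): a hypothetical consumer. Its C++ sources would write
    #   #include "flash_attn/flash.h"
    # thanks to the include_prefix / strip_include_prefix pair above.
    cc_library(
        name = "flash_attn_custom_call",  # illustrative name
        srcs = ["flash_attn_custom_call.cc"],
        deps = ["@flash_attn//:flash_attn"],
    )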
