diff --git a/README.md b/README.md
index 9cc10d337..7ce145caa 100644
--- a/README.md
+++ b/README.md
@@ -94,6 +94,11 @@ variable `MAX_JOBS`:
 MAX_JOBS=4 pip install flash-attn --no-build-isolation
 ```
 
+**Dedicated build (research only):** compile only the kernels for a single head dimension and dtype:
+```sh
+MAX_JOBS=8 HEADDIM=64 DTYPE=fp16 ENABLE_SM90=FALSE pip install -e . -v
+```
+
 **Interface:** `src/flash_attention_interface.py`
 
 ### NVIDIA CUDA Support
diff --git a/csrc/flash_attn/src/static_switch.h b/csrc/flash_attn/src/static_switch.h
index a57702f6c..9bde503ec 100644
--- a/csrc/flash_attn/src/static_switch.h
+++ b/csrc/flash_attn/src/static_switch.h
@@ -76,39 +76,60 @@
 #define LOCAL_SWITCH BOOL_SWITCH
 #endif
 
-#define FP16_SWITCH(COND, ...)               \
-  [&] {                                      \
-    if (COND) {                              \
-      using elem_type = cutlass::half_t;     \
-      return __VA_ARGS__();                  \
-    } else {                                 \
-      using elem_type = cutlass::bfloat16_t; \
-      return __VA_ARGS__();                  \
-    }                                        \
+
+#define DTYPE(COND, cond, dtype, ...) \
+  else if (COND == cond) { using elem_type = dtype; return __VA_ARGS__(); }
+
+#if defined(DTYPE_FP16)
+#define FP16_SWITCH(COND, ...) [&] { if (false) {} DTYPE(COND, true, cutlass::half_t, __VA_ARGS__) }()
+
+#elif defined(DTYPE_BF16)
+#define FP16_SWITCH(COND, ...) [&] { if (false) {} DTYPE(COND, false, cutlass::bfloat16_t, __VA_ARGS__) }()
+
+#else
+#define FP16_SWITCH(COND, ...)                            \
+  [&] {                                                   \
+    if (false) {}                                         \
+    DTYPE(COND, true, cutlass::half_t, __VA_ARGS__)       \
+    DTYPE(COND, false, cutlass::bfloat16_t, __VA_ARGS__)  \
   }()
+#endif
+
+
+#define HEAD(HEADDIM, dim, ...) \
+  else if (HEADDIM <= dim) { constexpr static int kHeadDim = dim; return __VA_ARGS__(); }
+
+#if defined(HEADDIM_32)
+#define HEADDIM_SWITCH(HEADDIM, ...) [&] { if (false) {} HEAD(HEADDIM, 32, __VA_ARGS__) }()
+
+#elif defined(HEADDIM_64)
+#define HEADDIM_SWITCH(HEADDIM, ...) [&] { if (false) {} HEAD(HEADDIM, 64, __VA_ARGS__) }()
+
+#elif defined(HEADDIM_96)
+#define HEADDIM_SWITCH(HEADDIM, ...) [&] { if (false) {} HEAD(HEADDIM, 96, __VA_ARGS__) }()
+
+#elif defined(HEADDIM_128)
+#define HEADDIM_SWITCH(HEADDIM, ...) [&] { if (false) {} HEAD(HEADDIM, 128, __VA_ARGS__) }()
+
+#elif defined(HEADDIM_160)
+#define HEADDIM_SWITCH(HEADDIM, ...) [&] { if (false) {} HEAD(HEADDIM, 160, __VA_ARGS__) }()
+
+#elif defined(HEADDIM_192)
+#define HEADDIM_SWITCH(HEADDIM, ...) [&] { if (false) {} HEAD(HEADDIM, 192, __VA_ARGS__) }()
+
+#elif defined(HEADDIM_256)
+#define HEADDIM_SWITCH(HEADDIM, ...) [&] { if (false) {} HEAD(HEADDIM, 256, __VA_ARGS__) }()
+
+#else
 #define HEADDIM_SWITCH(HEADDIM, ...)   \
-  [&] {                                     \
-    if (HEADDIM <= 32) {                    \
-      constexpr static int kHeadDim = 32;   \
-      return __VA_ARGS__();                 \
-    } else if (HEADDIM <= 64) {             \
-      constexpr static int kHeadDim = 64;   \
-      return __VA_ARGS__();                 \
-    } else if (HEADDIM <= 96) {             \
-      constexpr static int kHeadDim = 96;   \
-      return __VA_ARGS__();                 \
-    } else if (HEADDIM <= 128) {            \
-      constexpr static int kHeadDim = 128;  \
-      return __VA_ARGS__();                 \
-    } else if (HEADDIM <= 160) {            \
-      constexpr static int kHeadDim = 160;  \
-      return __VA_ARGS__();                 \
-    } else if (HEADDIM <= 192) {            \
-      constexpr static int kHeadDim = 192;  \
-      return __VA_ARGS__();                 \
-    } else if (HEADDIM <= 256) {            \
-      constexpr static int kHeadDim = 256;  \
-      return __VA_ARGS__();                 \
-    }                                       \
+  [&] {                                     \
+    if (false) {}                           \
+    HEAD(HEADDIM, 32, __VA_ARGS__)          \
+    HEAD(HEADDIM, 64, __VA_ARGS__)          \
+    HEAD(HEADDIM, 96, __VA_ARGS__)          \
+    HEAD(HEADDIM, 128, __VA_ARGS__)         \
+    HEAD(HEADDIM, 160, __VA_ARGS__)         \
+    HEAD(HEADDIM, 192, __VA_ARGS__)         \
+    HEAD(HEADDIM, 256, __VA_ARGS__)         \
   }()
+#endif
\ No newline at end of file
diff --git a/setup.py b/setup.py
index a567063fb..d997456ab 100644
--- a/setup.py
+++ b/setup.py
@@ -10,6 +10,7 @@
 from pathlib import Path
 from packaging.version import parse, Version
 import platform
+import itertools
 
 from setuptools import setup, find_packages
 import subprocess
@@ -62,6 +63,35 @@
 # For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
 FORCE_CXX11_ABI = os.getenv("FLASH_ATTENTION_FORCE_CXX11_ABI", "FALSE") == "TRUE"
 USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE"
+ENABLE_SM90 = os.getenv("ENABLE_SM90", "TRUE") == "TRUE"
+
+METHOD = ['fwd', 'fwd_split', 'bwd']
+HEADDIM = [32, 64, 96, 128, 160, 192, 256]
+DTYPE = ['bf16', 'fp16']
+CAUSAL = [True, False]
+
+HEADDIM_FLAG = "-DHEADDIM_ALL"
+DTYPE_FLAG = "-DDTYPE_ALL"
+
+methods = METHOD
+
+headdims = os.getenv('HEADDIM')
+if headdims:
+    headdims = [int(headdims)]
+    HEADDIM_FLAG = "-DHEADDIM_" + str(headdims[0])
+else:
+    headdims = HEADDIM
+
+dtypes = os.getenv('DTYPE')
+if dtypes:
+    dtypes = [dtypes]
+    DTYPE_FLAG = "-DDTYPE_" + str(dtypes[0]).upper()
+else:
+    dtypes = DTYPE
+
+print(f"{HEADDIM_FLAG=} {DTYPE_FLAG=}")
+
+causals = CAUSAL
 
 def get_platform():
     """
@@ -163,7 +193,7 @@ def validate_and_update_archs(archs):
     # cc_flag.append("arch=compute_75,code=sm_75")
     cc_flag.append("-gencode")
     cc_flag.append("arch=compute_80,code=sm_80")
-    if CUDA_HOME is not None:
+    if CUDA_HOME is not None and ENABLE_SM90:
         if bare_metal_version >= Version("11.8"):
             cc_flag.append("-gencode")
             cc_flag.append("arch=compute_90,code=sm_90")
@@ -173,101 +203,35 @@ def validate_and_update_archs(archs):
     # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
     if FORCE_CXX11_ABI:
         torch._C._GLIBCXX_USE_CXX11_ABI = True
+
+    sources = ['csrc/flash_attn/flash_api.cpp']
+    for method, headdim, dtype, causal in itertools.product(methods, headdims, dtypes, causals):
+        assert method in METHOD, f"{method} not supported"
+        assert headdim in HEADDIM, f"{headdim} not supported"
+        assert dtype in DTYPE, f"{dtype} not supported"
+        assert causal in CAUSAL, f"{causal} not supported"
+        if causal:
+            filename = f"csrc/flash_attn/src/flash_{method}_hdim{headdim}_{dtype}_causal_sm80.cu"
+        else:
+            filename = f"csrc/flash_attn/src/flash_{method}_hdim{headdim}_{dtype}_sm80.cu"
+        sources.append(filename)
+    print("\n\nsources = {}\n\n".format(sources))
+
     ext_modules.append(
         CUDAExtension(
            name="flash_attn_2_cuda",
-            sources=[
"csrc/flash_attn/flash_api.cpp", - "csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim32_fp16_causal_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim32_bf16_causal_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim64_fp16_causal_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim64_bf16_causal_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim96_fp16_causal_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim96_bf16_causal_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim128_fp16_causal_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim192_fp16_causal_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim192_bf16_causal_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim256_fp16_causal_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim256_bf16_causal_sm80.cu", - "csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu", - 
"csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu", - "csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu", - "csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu", - "csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu", - "csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu", - "csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu", - "csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu", - "csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu", - "csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu", - "csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu", - "csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu", - "csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu", - "csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu", - "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu", - ], + sources=sources, extra_compile_args={ - "cxx": ["-O3", "-std=c++17"], + "cxx": ["-O3", "-std=c++17", + HEADDIM_FLAG, + DTYPE_FLAG, + ], "nvcc": append_nvcc_threads( [ "-O3", + HEADDIM_FLAG, + DTYPE_FLAG, + "-DDTYPE_FP16", "-std=c++17", "-U__CUDA_NO_HALF_OPERATORS__", "-U__CUDA_NO_HALF_CONVERSIONS__",