Support dedicated compile [For Research] #1384

Draft · wants to merge 2 commits into main

5 changes: 5 additions & 0 deletions README.md
@@ -94,6 +94,11 @@ variable `MAX_JOBS`:
MAX_JOBS=4 pip install flash-attn --no-build-isolation
```

**Dedicated build (only for research):**
```sh
MAX_JOBS=8 HEADDIM=64 DTYPE=fp16 ENABLE_SM90=FALSE pip install -e . -v
```
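
For the example command above, a sketch of the intended mapping (based on the `setup.py` changes in this PR): `HEADDIM=64` becomes the `-DHEADDIM_64` preprocessor flag, `DTYPE=fp16` becomes `-DDTYPE_FP16`, and `ENABLE_SM90=FALSE` skips the `sm_90` gencode target so only `sm_80` kernels are built:

```python
import os

# Sketch of how the dedicated-build variables map to compile flags
# (mirrors the logic this PR adds to setup.py; comments show the values
# for the command above: HEADDIM=64 DTYPE=fp16 ENABLE_SM90=FALSE).
headdim = os.getenv("HEADDIM")                            # "64"
dtype = os.getenv("DTYPE")                                # "fp16"
enable_sm90 = os.getenv("ENABLE_SM90", "TRUE") == "TRUE"  # False

headdim_flag = f"-DHEADDIM_{headdim}" if headdim else "-DHEADDIM_ALL"
dtype_flag = f"-DDTYPE_{dtype.upper()}" if dtype else "-DDTYPE_ALL"

print(headdim_flag, dtype_flag, f"sm90={enable_sm90}")
# -> -DHEADDIM_64 -DDTYPE_FP16 sm90=False
```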

**Interface:** `src/flash_attention_interface.py`

### NVIDIA CUDA Support
85 changes: 53 additions & 32 deletions csrc/flash_attn/src/static_switch.h
@@ -76,39 +76,60 @@
#define LOCAL_SWITCH BOOL_SWITCH
#endif

#define FP16_SWITCH(COND, ...)               \
  [&] {                                      \
    if (COND) {                              \
      using elem_type = cutlass::half_t;     \
      return __VA_ARGS__();                  \
    } else {                                 \
      using elem_type = cutlass::bfloat16_t; \
      return __VA_ARGS__();                  \
    }                                        \

#define DTYPE(COND, cond, dtype, ...) \
  else if (COND == cond) { using elem_type = dtype; return __VA_ARGS__(); }

#if defined(DTYPE_FP16)
#define FP16_SWITCH(COND, ...) [&] { if(false){} DTYPE(COND, true, cutlass::half_t, __VA_ARGS__)}()

#elif defined(DTYPE_BF16)
#define FP16_SWITCH(COND, ...) [&] { if(false){} DTYPE(COND, false, cutlass::bfloat16_t, __VA_ARGS__)}()

#else
#define FP16_SWITCH(COND, ...)                            \
  [&] {                                                   \
    if (false) {}                                         \
    DTYPE(COND, true, cutlass::half_t, __VA_ARGS__)       \
    DTYPE(COND, false, cutlass::bfloat16_t, __VA_ARGS__)  \
  }()
#endif


#define HEAD(HEADDIM, dim, ...) \
  else if (HEADDIM <= dim) { constexpr static int kHeadDim = dim; return __VA_ARGS__(); } \

#if defined(HEADDIM_32)
#define HEADDIM_SWITCH(HEADDIM, ...) [&]{ if(false){} HEAD(HEADDIM, 32, __VA_ARGS__)}()

#elif defined(HEADDIM_64)
#define HEADDIM_SWITCH(HEADDIM, ...) [&]{ if(false){} HEAD(HEADDIM, 64, __VA_ARGS__)}()

#elif defined(HEADDIM_96)
#define HEADDIM_SWITCH(HEADDIM, ...) [&]{ if(false){} HEAD(HEADDIM, 96, __VA_ARGS__)}()

#elif defined(HEADDIM_128)
#define HEADDIM_SWITCH(HEADDIM, ...) [&]{ if(false){} HEAD(HEADDIM, 128, __VA_ARGS__)}()

#elif defined(HEADDIM_160)
#define HEADDIM_SWITCH(HEADDIM, ...) [&]{ if(false){} HEAD(HEADDIM, 160, __VA_ARGS__)}()

#elif defined(HEADDIM_192)
#define HEADDIM_SWITCH(HEADDIM, ...) [&]{ if(false){} HEAD(HEADDIM, 192, __VA_ARGS__)}()

#elif defined(HEADDIM_256)
#define HEADDIM_SWITCH(HEADDIM, ...) [&]{ if(false){} HEAD(HEADDIM, 256, __VA_ARGS__)}()

#else
#define HEADDIM_SWITCH(HEADDIM, ...) \
  [&] {                                     \
    if (HEADDIM <= 32) {                    \
      constexpr static int kHeadDim = 32;   \
      return __VA_ARGS__();                 \
    } else if (HEADDIM <= 64) {             \
      constexpr static int kHeadDim = 64;   \
      return __VA_ARGS__();                 \
    } else if (HEADDIM <= 96) {             \
      constexpr static int kHeadDim = 96;   \
      return __VA_ARGS__();                 \
    } else if (HEADDIM <= 128) {            \
      constexpr static int kHeadDim = 128;  \
      return __VA_ARGS__();                 \
    } else if (HEADDIM <= 160) {            \
      constexpr static int kHeadDim = 160;  \
      return __VA_ARGS__();                 \
    } else if (HEADDIM <= 192) {            \
      constexpr static int kHeadDim = 192;  \
      return __VA_ARGS__();                 \
    } else if (HEADDIM <= 256) {            \
      constexpr static int kHeadDim = 256;  \
      return __VA_ARGS__();                 \
    }                                       \
  [&] {                                     \
    if (false) {}                           \
    HEAD(HEADDIM, 32, __VA_ARGS__)          \
    HEAD(HEADDIM, 64, __VA_ARGS__)          \
    HEAD(HEADDIM, 96, __VA_ARGS__)          \
    HEAD(HEADDIM, 128, __VA_ARGS__)         \
    HEAD(HEADDIM, 160, __VA_ARGS__)         \
    HEAD(HEADDIM, 192, __VA_ARGS__)         \
    HEAD(HEADDIM, 256, __VA_ARGS__)         \
  }()
#endif
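
Net effect of the rewritten switches: when a dedicated define such as `-DDTYPE_FP16` or `-DHEADDIM_64` is present, the generated `if (false) {} else if ...` chain contains a single branch, so only that dtype/head-dimension specialization is instantiated, and any other runtime combination falls through without matching a branch (i.e. it is simply not available in a dedicated build). A minimal Python model of that dispatch, with illustrative names that are not part of the PR:

```python
# COMPILED_HEADDIMS / COMPILED_DTYPES stand in for the -DHEADDIM_* /
# -DDTYPE_* defines selected at build time (here: -DHEADDIM_64 -DDTYPE_FP16).
COMPILED_HEADDIMS = [64]
COMPILED_DTYPES = {"fp16"}

def headdim_switch(headdim):
    # Mirrors HEADDIM_SWITCH: the first compiled bucket with headdim <= dim wins.
    for dim in COMPILED_HEADDIMS:
        if headdim <= dim:
            return dim
    return None  # no branch matched: no kernel was compiled for this head dim

def fp16_switch(is_fp16):
    # Mirrors FP16_SWITCH: only the compiled dtype is dispatchable.
    dtype = "fp16" if is_fp16 else "bf16"
    return dtype if dtype in COMPILED_DTYPES else None

print(headdim_switch(64), fp16_switch(True))    # 64 fp16   -> dispatchable
print(headdim_switch(128), fp16_switch(False))  # None None -> not compiled
```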
143 changes: 54 additions & 89 deletions setup.py
@@ -10,6 +10,7 @@
from pathlib import Path
from packaging.version import parse, Version
import platform
import itertools

from setuptools import setup, find_packages
import subprocess
@@ -62,6 +63,35 @@
# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
FORCE_CXX11_ABI = os.getenv("FLASH_ATTENTION_FORCE_CXX11_ABI", "FALSE") == "TRUE"
USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE"
ENABLE_SM90 = os.getenv("ENABLE_SM90", "TRUE") == "TRUE"

METHOD = ['fwd', 'fwd_split', 'bwd']
HEADDIM = [32, 64, 96, 128, 160, 192, 256]
DTYPE = ['bf16', 'fp16']
CAUSAL = [True, False]

HEADDIM_FLAG="-DHEADDIM_ALL"
DTYPE_FLAG="-DDTYPE_ALL"

methods = METHOD

headdims = os.getenv('HEADDIM')
if headdims:
    headdims = [int(headdims)]
    HEADDIM_FLAG="-DHEADDIM_" + str(headdims[0])
else:
    headdims = HEADDIM

dtypes = os.getenv('DTYPE')
if dtypes:
    dtypes = [dtypes]
    DTYPE_FLAG="-DDTYPE_" + str(dtypes[0]).upper()
else:
    dtypes = DTYPE

print(f"{HEADDIM_FLAG=} {DTYPE_FLAG=}")

causals = CAUSAL

def get_platform():
"""
@@ -163,7 +193,7 @@ def validate_and_update_archs(archs):
# cc_flag.append("arch=compute_75,code=sm_75")
cc_flag.append("-gencode")
cc_flag.append("arch=compute_80,code=sm_80")
if CUDA_HOME is not None:
if CUDA_HOME is not None and ENABLE_SM90:
    if bare_metal_version >= Version("11.8"):
        cc_flag.append("-gencode")
        cc_flag.append("arch=compute_90,code=sm_90")
Expand All @@ -173,101 +203,36 @@ def validate_and_update_archs(archs):
# https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
if FORCE_CXX11_ABI:
    torch._C._GLIBCXX_USE_CXX11_ABI = True

sources = ['csrc/flash_attn/flash_api.cpp']
for method, headdim, dtype, causal in itertools.product(methods, headdims, dtypes, causals):
    assert method in METHOD, f"{method} not supported"
    assert headdim in HEADDIM, f"{headdim} not supported"
    assert dtype in DTYPE, f"{dtype} not supported"
    assert causal in CAUSAL, f"{causal} not supported"
    filename = ''
    if causal:
        filename = f"csrc/flash_attn/src/flash_{method}_hdim{headdim}_{dtype}_causal"+"_sm80.cu"
    else:
        filename = f"csrc/flash_attn/src/flash_{method}_hdim{headdim}_{dtype}"+"_sm80.cu"
    sources.append(filename)
print("\n\nsources = {}\n\n".format(sources))

ext_modules.append(
CUDAExtension(
name="flash_attn_2_cuda",
sources=[
"csrc/flash_attn/flash_api.cpp",
"csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim32_fp16_causal_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim32_bf16_causal_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim64_fp16_causal_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim64_bf16_causal_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim96_fp16_causal_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim96_bf16_causal_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim128_fp16_causal_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim192_fp16_causal_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim192_bf16_causal_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim256_fp16_causal_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim256_bf16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu",
],
sources=sources,
extra_compile_args={
"cxx": ["-O3", "-std=c++17"],
"cxx": ["-O3", "-std=c++17",
HEADDIM_FLAG,
DTYPE_FLAG,
],
"nvcc": append_nvcc_threads(
[
"-O3",
HEADDIM_FLAG,
DTYPE_FLAG,
"-DDTYPE_FP16",
"-std=c++17",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",