Re-add initial Flash Attention support on ROCM (pytorch#115981)
Note about the Updates:

This PR:
1. Skips more Flash Attention-related unit tests on MI200.
2. Fixes additional ATen compilation errors after hipification.
3. Fixes the author "root" of a specific commit.
4. Includes the patch from Nikita in favor of block-level static initialization.

CAVEAT: This revised PR has a commit that modifies the CI to force it to run on MI200 nodes. That commit must be reverted before merging.

Original PR (pytorch#114309) Note:

This pull request adds initial Flash Attention support for the AMD/ROCm platform. It adds a specialized Triton repository/branch as a compile-time dependency for the Flash Attention math library on AMD/ROCm. This Triton submodule is not used at runtime and will not be shipped in the final PyTorch package. We plan to release this specialized Triton as a separate project.

Known limitations (see the usage sketch after this list):

- Only supports MI200 series GPUs (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`).
- Only supports power-of-two sequence lengths.
- No support for varlen APIs.
- Only supports head dimensions 16, 32, 64, and 128.
- Performance is still being optimized.
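
As a companion to the list above (not part of this PR's diff), here is a minimal Python sketch of a call that stays within these limits. It assumes a ROCm build of PyTorch with `USE_FLASH_ATTENTION` enabled and an MI200 GPU; the shapes, dtype, and the `torch.backends.cuda.sdp_kernel` context manager are illustrative choices only.

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: sequence length 256 is a power of two and head
# dimension 64 is one of the supported sizes (16/32/64/128).
q = torch.randn(2, 8, 256, 64, device="cuda", dtype=torch.float16)  # "cuda" also addresses HIP devices on ROCm
k = torch.randn(2, 8, 256, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 256, 64, device="cuda", dtype=torch.float16)

# Restrict SDPA to the flash backend so a silent fallback to the math path
# would surface as an error instead of being hidden.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(out.shape)  # torch.Size([2, 8, 256, 64])
```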

Fixes pytorch#112997

Pull Request resolved: pytorch#115981
Approved by: https://github.com/malfet
xinyazhang authored and pytorchmergebot committed Jan 4, 2024
1 parent 8195a0a commit e3ca734
Showing 21 changed files with 885 additions and 38 deletions.
18 changes: 11 additions & 7 deletions CMakeLists.txt
@@ -732,19 +732,23 @@ if(MSVC)
append_cxx_flag_if_supported("/utf-8" CMAKE_CXX_FLAGS)
endif()

include(cmake/Dependencies.cmake)

# Moved this cmake set option down here because CMAKE_CUDA_COMPILER_VERSION is not avaialble until now
# CAVEAT: do NOT check USE_ROCM here, because USE_ROCM is always True until
# include(cmake/Dependencies.cmake)
cmake_dependent_option(
USE_FLASH_ATTENTION
"Whether to build the flash_attention kernel for scaled dot product attention" ON
"USE_CUDA AND NOT ROCM AND NOT MSVC AND NOT CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6" OFF)
"Whether to build the flash_attention kernel for scaled dot product attention.\
Will be disabled if not supported by the platform" ON
"USE_CUDA AND NOT MSVC" OFF)

# CAVEAT: Again, do not check USE_ROCM here
# Flash Attention2 will error while building for sm52 while Mem Eff Attention won't
cmake_dependent_option(
USE_MEM_EFF_ATTENTION
"Enable memory-efficient attention for scaled dot product attention" ON
"USE_CUDA AND NOT ROCM AND NOT CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6" OFF)
"Enable memory-efficient attention for scaled dot product attention.\
Will be disabled if not supported by the platform" ON
"USE_CUDA" OFF)

include(cmake/Dependencies.cmake)

if(DEBUG_CUDA)
string(APPEND CMAKE_CUDA_FLAGS_DEBUG " -lineinfo")
37 changes: 34 additions & 3 deletions aten/src/ATen/CMakeLists.txt
@@ -164,6 +164,10 @@ file(GLOB flash_attention_cuda_cu "native/transformers/cuda/flash_attn/*.cu")
file(GLOB flash_attention_cuda_kernels_cu "native/transformers/cuda/flash_attn/kernels/*.cu")
file(GLOB flash_attention_cuda_cpp "native/transformers/cuda/flash_attn/*.cpp")

# flash_attention sources
file(GLOB flash_attention_hip_hip "native/transformers/hip/flash_attn/*.hip")
file(GLOB flash_attention_src_hip_hip "native/transformers/hip/flash_attn/src/*.hip")

#Mem_eff attention sources
file(GLOB mem_eff_attention_cuda_cu "native/transformers/cuda/mem_eff_attention/*.cu")
file(GLOB mem_eff_attention_cuda_kernels_cu "native/transformers/cuda/mem_eff_attention/kernels/*.cu")
@@ -175,6 +179,9 @@ if(USE_FLASH_ATTENTION)
list(APPEND native_transformers_cuda_cpp ${flash_attention_cuda_cpp})
list(APPEND FLASH_ATTENTION_CUDA_SOURCES ${flash_attention_cuda_cu} ${flash_attention_cuda_kernels_cu})
list(APPEND ATen_ATTENTION_KERNEL_SRCS ${flash_attention_cuda_kernels_cu})

list(APPEND native_transformers_hip_hip ${flash_attention_hip_hip})
list(APPEND native_transformers_src_hip_hip ${flash_attention_src_hip_hip})
endif()

if(USE_MEM_EFF_ATTENTION)
@@ -284,10 +291,34 @@ endif()

if(USE_ROCM)
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/hip)
set(ATen_HIP_SRCS ${ATen_HIP_SRCS} ${hip_hip} ${native_hip_hip} ${native_nested_hip_hip} ${native_sparse_hip_hip} ${native_quantized_hip_hip} ${native_transformers_hip_hip})
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include)
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include)
list(APPEND ATen_HIP_SRCS
${ATen_HIP_SRCS}
${hip_hip}
${native_hip_hip}
${native_nested_hip_hip}
${native_sparse_hip_hip}
${native_quantized_hip_hip}
${native_transformers_hip_hip} ${native_transformers_src_hip_hip}
)
# TODO: Codegen separate files for HIP and use those (s/cuda_generated_sources/hip_generated_sources)
set(all_hip_cpp ${native_nested_hip_cpp} ${native_sparse_hip_cpp} ${native_quantized_hip_cpp} ${native_transformers_hip_cpp} ${native_quantized_cudnn_hip_cpp} ${hip_cpp} ${native_hip_cpp} ${native_hip_linalg_cpp} ${cuda_generated_sources} ${ATen_HIP_SRCS})
set(all_hip_cpp ${native_miopen_cpp} ${native_cudnn_hip_cpp} ${miopen_cpp} ${all_hip_cpp})
list(APPEND all_hip_cpp
${native_nested_hip_cpp}
${native_sparse_hip_cpp}
${native_quantized_hip_cpp}
${native_transformers_hip_cpp}
${native_quantized_cudnn_hip_cpp}
${hip_cpp}
${native_hip_cpp}
${native_hip_linalg_cpp}
${cuda_generated_sources}
${ATen_HIP_SRCS}
${native_miopen_cpp}
${native_cudnn_hip_cpp}
${miopen_cpp}
${all_hip_cpp}
)
endif()

list(APPEND ATen_CPU_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/..)
8 changes: 4 additions & 4 deletions aten/src/ATen/cuda/CUDABlas.cpp
@@ -1103,9 +1103,9 @@ void int8_gemm(

// ROCm 5.6 hipblas matches the const Dtype *A API, but prior hipblas does not.
#if defined(USE_ROCM) && ROCM_VERSION < 50600
#define ROCM_CONST_BUG
#define ROCM_CONST_BUG_CAST(Type, Input) const_cast<Type>(reinterpret_cast<const Type>(Input))
#else
#define ROCM_CONST_BUG const
#define ROCM_CONST_BUG_CAST(Type, Input) reinterpret_cast<const Type>(Input)
#endif

template <>
@@ -1131,7 +1131,7 @@ void trsm<c10::complex<float>>(CUDABLAS_TRSM_ARGTYPES(c10::complex<float>)) {
m,
n,
reinterpret_cast<const cuComplex*>(alpha),
reinterpret_cast<ROCM_CONST_BUG cuComplex*>(A),
ROCM_CONST_BUG_CAST(cuComplex*, A),
lda,
reinterpret_cast<cuComplex*>(B),
ldb));
@@ -1148,7 +1148,7 @@ void trsm<c10::complex<double>>(CUDABLAS_TRSM_ARGTYPES(c10::complex<double>)) {
m,
n,
reinterpret_cast<const cuDoubleComplex*>(alpha),
reinterpret_cast<ROCM_CONST_BUG cuDoubleComplex*>(A),
ROCM_CONST_BUG_CAST(cuDoubleComplex*, A),
lda,
reinterpret_cast<cuDoubleComplex*>(B),
ldb));
11 changes: 10 additions & 1 deletion aten/src/ATen/native/transformers/attention.cpp
@@ -445,6 +445,13 @@ int64_t _fused_sdp_choice_meta(
bool is_causal,
c10::optional<double> scale) {
auto query_key_set = query_.key_set();
#if defined(USE_ROCM)
bool has_rocm = query_key_set.has(c10::DispatchKey::HIP);
if (has_rocm) {
auto choice_int = _fused_sdp_choice_stub(at::kHIP, query_, key, value, attn_mask_, dropout_p, is_causal, scale);
return choice_int;
}
#else
bool has_cuda = query_key_set.has(c10::DispatchKey::CUDA);
if (has_cuda) {
auto choice_int = _fused_sdp_choice_stub(
@@ -458,6 +465,7 @@ scale);
scale);
return choice_int;
}
#endif
return static_cast<int64_t>(sdp::SDPBackend::math);
}
namespace {
@@ -625,7 +633,8 @@ Tensor scaled_dot_product_attention(
validate_sdpa_input(query_, key, value, attn_mask_, dropout_p, is_causal, scale);
int64_t choice_int = static_cast<int64_t>(sdp::SDPBackend::math);
if (query_.device().type() == DeviceType::CUDA
|| query_.device().type() == DeviceType::CPU){
|| query_.device().type() == DeviceType::CPU
|| query_.device().type() == DeviceType::HIP){
choice_int = _fused_sdp_choice_stub(query_.device().type(),
query_, key, value, attn_mask_, dropout_p, is_causal, scale);
}
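
For context only (not part of the diff): on a ROCm build, AMD GPUs are addressed through the familiar "cuda" device strings while dispatching internally as HIP, which is why `scaled_dot_product_attention` above now also routes `DeviceType::HIP` through `_fused_sdp_choice_stub`. A small sketch, assuming a ROCm build of PyTorch:

```python
import torch

# On a ROCm/HIP build, torch.version.hip is a version string (it is None on
# CUDA builds) and AMD GPUs are exposed through "cuda" device strings.
print(torch.version.hip)
print(torch.cuda.is_available())   # True when an AMD GPU is visible

# Tensors created this way dispatch internally as HIP, the device type that
# scaled_dot_product_attention and _fused_sdp_choice_meta now accept.
x = torch.randn(4, device="cuda")
print(x.device)                    # cuda:0 (HIP-backed on ROCm)
```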
27 changes: 27 additions & 0 deletions aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
@@ -14,6 +14,7 @@
#include <c10/util/Exception.h>
#include <c10/util/env.h>
#include <c10/util/irange.h>
#include <c10/util/CallOnce.h>

#include <c10/core/SymInt.h>
#include <c10/util/string_view.h>
@@ -181,6 +182,31 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug
using sm80 = SMVersion<8, 0>;
using sm90 = SMVersion<9, 0>;
auto dprops = at::cuda::getCurrentDeviceProperties();
#if USE_ROCM
constexpr std::string_view mi200 = "gfx90a:sramecc+:xnack-";
static const char *over_arch = [] {
auto rc = std::getenv("PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE");
if (rc) {
TORCH_WARN("SDPA functions only loads value from PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE once. "
"Later changes to this environment variable with os.environ "
"(or other methods) will not affect SDPA function's behavior.");
}
return rc;
}();
const char* real_arch = dprops->gcnArchName;
const char* arch = over_arch ? over_arch : real_arch;
if (mi200 != arch) {
if (debug) {
TORCH_WARN(
"Flash attention only supports gpu architecture gfx90a, for now. Attempting to run on a ",
arch,
".",
over_arch ? " This is overrided by PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE. Real architecture is " : "",
over_arch ? real_arch : "");
}
return false;
}
#else
if (!check_sm_version<sm80, sm90>(dprops)) {
if (debug) {
TORCH_WARN(
@@ -192,6 +218,7 @@ }
}
return false;
}
#endif
return true;
}

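
One usage note on the override introduced above (not part of the diff): because `PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE` is read only once and then cached, it must be set before the first SDPA support check. A minimal sketch, assuming a ROCm build; the gfx90a value simply mirrors the MI200 string checked in the code:

```python
import os

# Set the override before torch is imported and before the first SDPA call;
# later changes via os.environ are ignored because the value is cached on
# first read (see the TORCH_WARN in check_flash_attention_hardware_support).
os.environ["PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE"] = "gfx90a:sramecc+:xnack-"

import torch  # noqa: E402  (imported after setting the variable on purpose)
```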