Add Flash Attention support on ROCM (pytorch#121561)
This patch addresses the major limitations of our previous [PR pytorch#115981](pytorch#115981) through the new dedicated repository [AOTriton](https://github.com/ROCm/aotriton):

- [x] Only supports MI200 series GPUs (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`).
    * MI300X is now supported. More architectures will be added once Triton supports them.
- [x] Only supports power-of-two sequence lengths.
    * It now supports arbitrary sequence lengths.
- [ ] No support for varlen APIs.
    * The varlen APIs will be supported in a future release of AOTriton.
- [x] Only supports head dimensions 16, 32, 64, and 128.
    * It now supports arbitrary head dimensions <= 256.
- [x] Performance is still being optimized.
    * Kernels are now selected according to autotuning information from Triton.

Other improvements from AOTriton include:
* More flexible Tensor storage layouts
* A more flexible API

This is a more extensive fix for pytorch#112997.
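
As an illustration of the user-visible effect (not part of this commit), the sketch below exercises the new ROCm flash attention path through the public SDPA API. It assumes a ROCm build of PyTorch with `USE_FLASH_ATTENTION` enabled and a supported GPU (e.g., MI200/MI300X); the shapes are arbitrary and only chosen to show that non-power-of-two sequence lengths and head dimensions up to 256 are accepted:

```python
# Minimal sketch, assuming a ROCm build with flash attention compiled in via AOTriton.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Non-power-of-two sequence length (1000) and head dimension 72 (<= 256) now work.
q = torch.randn(2, 8, 1000, 72, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 1000, 72, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 1000, 72, device="cuda", dtype=torch.float16)

# Restrict dispatch to the flash kernel only; on an unsupported GPU or build this
# raises a RuntimeError instead of silently falling back to another backend.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 1000, 72])
```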

Pull Request resolved: pytorch#121561
Approved by: https://github.com/huydhn
xinyazhang authored and pytorchmergebot committed Mar 28, 2024
1 parent 8d676a6 commit 12116ae
Showing 9 changed files with 264 additions and 326 deletions.
23 changes: 18 additions & 5 deletions CMakeLists.txt
```diff
@@ -742,13 +742,28 @@ if(MSVC)
   append_cxx_flag_if_supported("/utf-8" CMAKE_CXX_FLAGS)
 endif()
 
-# CAVEAT: do NOT check USE_ROCM here, because USE_ROCM is always True until
-# include(cmake/Dependencies.cmake)
+# Note for the ROCM platform:
+#   1. USE_ROCM is always ON until include(cmake/Dependencies.cmake)
+#   2. USE_CUDA will become OFF during re-configuration
+# Truth table:
+#   CUDA 1st pass: USE_CUDA=True;  USE_ROCM=True,  FLASH evaluates to ON by default
+#   CUDA 2nd pass: USE_CUDA=True;  USE_ROCM=False, FLASH evaluates to ON by default
+#   ROCM 1st pass: USE_CUDA=True;  USE_ROCM=True,  FLASH evaluates to ON by default
+#   ROCM 2nd pass: USE_CUDA=False; USE_ROCM=True,  FLASH evaluates to ON by default
+#   CPU  1st pass: USE_CUDA=False(Cmd Option); USE_ROCM=True,  FLASH evaluates to OFF by default
+#   CPU  2nd pass: USE_CUDA=False(Cmd Option); USE_ROCM=False, FLASH evaluates to OFF by default
+# Thus we cannot tell the ROCM 2nd pass apart from the CPU 1st pass.
+#
+# The only solution is to include(cmake/Dependencies.cmake) here and defer the
+# aotriton build decision until later.
+
+include(cmake/Dependencies.cmake)
+
 cmake_dependent_option(
   USE_FLASH_ATTENTION
   "Whether to build the flash_attention kernel for scaled dot product attention.\
   Will be disabled if not supported by the platform" ON
-  "USE_CUDA AND NOT MSVC" OFF)
+  "USE_CUDA OR USE_ROCM;NOT MSVC" OFF)
 
 # We are currently not using alibi attention for Flash
 # So we disable this feature by default
@@ -764,8 +779,6 @@ cmake_dependent_option(
   Will be disabled if not supported by the platform" ON
   "USE_CUDA" OFF)
 
-include(cmake/Dependencies.cmake)
-
 if(DEBUG_CUDA)
   string(APPEND CMAKE_CUDA_FLAGS_DEBUG " -lineinfo")
   string(APPEND CMAKE_CUDA_FLAGS_RELWITHDEBINFO " -lineinfo")
```
36 changes: 13 additions & 23 deletions aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
```diff
@@ -21,6 +21,10 @@
 #include <cmath>
 #include <functional>
 
+#if USE_ROCM
+#include <aotriton/flash.h>
+#endif
+
 /**
  * Note [SDPA Runtime Dispatch]
  * SDPA relies on a runtime dispatch mechanism to select the appropriate
@@ -182,32 +186,18 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug
   // Check that the gpu is capable of running flash attention
   using sm80 = SMVersion<8, 0>;
   using sm90 = SMVersion<9, 0>;
-  auto dprops = at::cuda::getCurrentDeviceProperties();
 #if USE_ROCM
-  constexpr std::string_view mi200 = "gfx90a:sramecc+:xnack-";
-  static const char *over_arch = [] {
-    auto rc = std::getenv("PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE");
-    if (rc) {
-      TORCH_WARN("SDPA functions only loads value from PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE once. "
-                 "Later changes to this environment variable with os.environ "
-                 "(or other methods) will not affect SDPA function's behavior.");
-    }
-    return rc;
-  }();
-  const char* real_arch = dprops->gcnArchName;
-  const char* arch = over_arch ? over_arch : real_arch;
-  if (mi200 != arch) {
-    if (debug) {
-      TORCH_WARN(
-          "Flash attention only supports gpu architecture gfx90a, for now. Attempting to run on a ",
-          arch,
-          ".",
-          over_arch ? " This is overrided by PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE. Real architecture is " : "",
-          over_arch ? real_arch : "");
-    }
-    return false;
+  auto stream = at::cuda::getCurrentCUDAStream().stream();
+  if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) {
+      auto dprops = at::cuda::getCurrentDeviceProperties();
+      if (debug) {
+          TORCH_WARN(
+              "Flash attention was not compiled for current AMD GPU architecture. Attempting to run on architecture ", dprops->gcnArchName);
+      }
+      return false;
   }
 #else
+  auto dprops = at::cuda::getCurrentDeviceProperties();
   if (!check_sm_version<sm80, sm90>(dprops)) {
     if (debug) {
       TORCH_WARN(
```
(Diffs for the remaining 7 changed files are not shown here.)

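A side note on the sdp_utils.cpp change above (not part of the commit): the old `PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE` environment override is removed, and the architecture check is now delegated to AOTriton's `check_gpu`. From Python, the pre-existing per-backend switches still apply on top of that gate; a small sketch, assuming a ROCm build with flash attention compiled in:

```python
# Sketch of the existing public toggles; the hardware gate itself is internal
# (aotriton::v2::flash::check_gpu on the C++ side) and is not overridable from Python.
import torch

print(torch.backends.cuda.flash_sdp_enabled())  # True by default
torch.backends.cuda.enable_flash_sdp(False)     # opt out of the flash kernel globally
torch.backends.cuda.enable_flash_sdp(True)      # re-enable it

# On a GPU that fails the AOTriton check, the default dispatcher simply skips the
# flash kernel and falls back to another enabled backend (e.g., the math kernel).
```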