v2.7.3 build failed in NGC pytorch:24.12-py3 #1452

Closed
xuchunmei000 opened this issue Jan 21, 2025 · 6 comments

@xuchunmei000

Here are the steps to reproduce:

docker pull nvcr.io/nvidia/pytorch:24.12-py3
docker run -ti --gpus all nvcr.io/nvidia/pytorch:24.12-py3
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention
git checkout v2.7.3
cd hopper
MAX_JOBS=8 python3 setup.py bdist_wheel
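For context, the toolchain shipped in the container can be checked with commands like these (a quick sanity check; the output is not included here):

nvcc --version
gcc --version
python3 -c "import torch; print(torch.__version__)"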

The full log is attached (hopper-build.log). The error message looks like this:

/workspace/flash-attention/hopper/instantiations/flash_bwd_hdim128_bf16_sm90.cu:10:54:   required from here
/usr/local/cuda/include/cuda_runtime.h:287:28: error: call of overloaded ‘forward<const flash::FlashAttnBwdPreprocess<cute::tuple<cute::C<64>, cute::C<128> >, cutlass::bfloat16_t, float, cutlass::arch::Sm90, true, true>::Params&>(const flash::FlashAttnBwdPreprocess<cute::tuple<cute::C<64>, cute::C<128> >, cutlass::bfloat16_t, float, cutlass::arch::Sm90, true, true>::Params&)’ is ambiguous
  287 |     }(std::forward<ActTypes>(args)...);
      |     ~~~~~~~~~~~~~~~~~~~~~~~^~~~~~
/usr/include/c++/13/bits/move.h:70:1: note: candidate: ‘constexpr _Tp&& std::forward(typename remove_reference<_Functor>::type&) [with _Tp = const flash::FlashAttnBwdPreprocess<cute::tuple<cute::C<64>, cute::C<128> >, cutlass::bfloat16_t, float, cutlass::arch::Sm90, true, true>::Params&; typename remove_reference<_Functor>::type = const flash::FlashAttnBwdPreprocess<cute::tuple<cute::C<64>, cute::C<128> >, cutlass::bfloat16_t, float, cutlass::arch::Sm90, true, true>::Params&]’
   70 |     forward(typename std::remove_reference<_Tp>::type& __t) noexcept
      | ^   ~~~
/usr/include/c++/13/bits/move.h:82:1: note: candidate: ‘constexpr _Tp&& std::forward(typename remove_reference<_Functor>::type&&) [with _Tp = const flash::FlashAttnBwdPreprocess<cute::tuple<cute::C<64>, cute::C<128> >, cutlass::bfloat16_t, float, cutlass::arch::Sm90, true, true>::Params&; typename remove_reference<_Functor>::type = const flash::FlashAttnBwdPreprocess<cute::tuple<cute::C<64>, cute::C<128> >, cutlass::bfloat16_t, float, cutlass::arch::Sm90, true, true>::Params&]’
   82 |     forward(typename std::remove_reference<_Tp>::type&& __t) noexcept
      | ^   ~~~
/usr/local/cuda/include/cuda_runtime.h: In instantiation of ‘cudaError_t cudaLaunchKernelEx(const cudaLaunchConfig_t*, void (*)(ExpTypes ...), ActTypes&& ...) [with ExpTypes = {flash::FlashAttnBwdPreprocess<cute::tuple<cute::C<64>, cute::C<128> >, cutlass::bfloat16_t, float, cutlass::arch::Sm90, true, false>::Params}; ActTypes = {const flash::FlashAttnBwdPreprocess<cute::tuple<cute::C<64>, cute::C<128> >, cutlass::bfloat16_t, float, cutlass::arch::Sm90, true, false>::Params&}; cudaError_t = cudaError; cudaLaunchConfig_t = cudaLaunchConfig_st]’:
/workspace/flash-attention/csrc/cutlass/include/cutlass/kernel_launch.h:116:47:   required from ‘cutlass::Status cutlass::kernel_launch(dim3, dim3, size_t, cudaStream_t, const Params&, bool) [with GemmKernel = flash::FlashAttnBwdPreprocess<cute::tuple<cute::C<64>, cute::C<128> >, bfloat16_t, float, arch::Sm90, true, false>; Params = flash::FlashAttnBwdPreprocess<cute::tuple<cute::C<64>, cute::C<128> >, bfloat16_t, float, arch::Sm90, true, false>::Params; dim3 = dim3; size_t = long unsigned int; cudaStream_t = CUstream_st*]’
/workspace/flash-attention/hopper/flash_bwd_launch_template.h:74:224:   required from ‘void run_flash_bwd(Flash_bwd_params&, cudaStream_t) [with int Arch = 90; int kHeadDim = 128; int kBlockM = 64; int kBlockN = 128; Element = cutlass::bfloat16_t; bool Is_causal = true; bool Is_local = false; bool Has_softcap = false; bool Varlen = false; bool Deterministic = false; bool GQA = true; int Stages_dO = 2; int Stages_dS_or_QSm80 = 2; bool SdP_swapAB = true; bool dKV_swapAB = false; bool dQ_swapAB = false; int NumMmaWarpGroups = 2; int AtomLayoutMSdP = 1; int AtomLayoutNdKV = 2; int AtomLayoutMdQ = 1; bool V_in_regs = false; cudaStream_t = CUstream_st*]’
/workspace/flash-attention/hopper/flash_bwd_launch_template.h:289:1243:   required from ‘void run_mha_bwd_dispatch(Flash_bwd_params&, cudaStream_t) [with int Arch = 90; T = cutlass::bfloat16_t; int kBlockM = 64; int kBlockN = 128; int kHeadDim = 128; bool Is_causal = true; bool Is_local = false; bool Has_softcap = false; int Stages_dO = 2; int Stages_dS_or_QSm80 = 2; bool SdP_swapAB = true; bool dKV_swapAB = false; bool dQ_swapAB = false; int NumMmaWarpGroups = 2; int AtomLayoutMSdP = 1; int AtomLayoutNdKV = 2; int AtomLayoutMdQ = 1; bool V_in_regs = false; cudaStream_t = CUstream_st*]’
/workspace/flash-attention/hopper/flash_bwd_launch_template.h:337:332:   required from ‘void run_mha_bwd_hdim128(Flash_bwd_params&, cudaStream_t) [with int Arch = 90; T = cutlass::bfloat16_t; bool Has_softcap = false; cudaStream_t = CUstream_st*]’
/workspace/flash-attention/hopper/instantiations/flash_bwd_hdim128_bf16_sm90.cu:10:54:   required from here
/usr/local/cuda/include/cuda_runtime.h:287:28: error: call of overloaded ‘forward<const flash::FlashAttnBwdPreprocess<cute::tuple<cute::C<64>, cute::C<128> >, cutlass::bfloat16_t, float, cutlass::arch::Sm90, true, false>::Params&>(const flash::FlashAttnBwdPreprocess<cute::tuple<cute::C<64>, cute::C<128> >, cutlass::bfloat16_t, float, cutlass::arch::Sm90, true, false>::Params&)’ is ambiguous
  287 |     }(std::forward<ActTypes>(args)...);
      |     ~~~~~~~~~~~~~~~~~~~~~~~^~~~~~
/usr/include/c++/13/bits/move.h:70:1: note: candidate: ‘constexpr _Tp&& std::forward(typename remove_reference<_Functor>::type&) [with _Tp = const flash::FlashAttnBwdPreprocess<cute::tuple<cute::C<64>, cute::C<128> >, cutlass::bfloat16_t, float, cutlass::arch::Sm90, true, false>::Params&; typename remove_reference<_Functor>::type = const flash::FlashAttnBwdPreprocess<cute::tuple<cute::C<64>, cute::C<128> >, cutlass::bfloat16_t, float, cutlass::arch::Sm90, true, false>::Params&]’
   70 |     forward(typename std::remove_reference<_Tp>::type& __t) noexcept
      | ^   ~~~
/usr/include/c++/13/bits/move.h:82:1: note: candidate: ‘constexpr _Tp&& std::forward(typename remove_reference<_Functor>::type&&) [with _Tp = const flash::FlashAttnBwdPreprocess<cute::tuple<cute::C<64>, cute::C<128> >, cutlass::bfloat16_t, float, cutlass::arch::Sm90, true, false>::Params&; typename remove_reference<_Functor>::type = const flash::FlashAttnBwdPreprocess<cute::tuple<cute::C<64>, cute::C<128> >, cutlass::bfloat16_t, float, cutlass::arch::Sm90, true, false>::Params&]’
   82 |     forward(typename std::remove_reference<_Tp>::type&& __t) noexcept
      | ^   ~~~

v2.7.2, by contrast, compiles successfully.

hopper-build.log

@cameronshinn
Contributor

cameronshinn commented Jan 24, 2025

I am also getting similar errors building FA3. Here's my output as well, in case it helps: compile_out.txt. I am going to try switching to pytorch:24.02-py3, since that's the latest release using nvcc 12.3, which is what we're supposed to build with after the refactor.

edit: downgrading the container fixed my issue.

@tridao
Member

tridao commented Jan 26, 2025

Yup, gcc 13 in nvcr 24.12 (and 24.11) is not compatible with nvcc 12.3. You can switch to gcc 12 if you want to use those containers. Longer term we'll figure out how to get the best perf on nvcc 12.6 instead of having to pin nvcc 12.3.

apt-get install -y gcc-12 g++-12
update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 12 \
    && update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-12 12
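If the build still picks up gcc 13, it may also help to point it at gcc 12 explicitly; a rough sketch, assuming setup.py honors the standard CC/CXX environment variables:

gcc --version   # should now report 12.x
export CC=/usr/bin/gcc-12 CXX=/usr/bin/g++-12
cd flash-attention/hopper
MAX_JOBS=8 python3 setup.py bdist_wheel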

@tridao
Member

tridao commented Jan 29, 2025

We're getting good perf with nvcc 12.8 on the nvcr 25.01 image so I recommend using that docker image.
[Make sure you pull the latest commit from this repo]
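Roughly (a sketch, mirroring the original repro steps with the newer image):

docker pull nvcr.io/nvidia/pytorch:25.01-py3
docker run -ti --gpus all nvcr.io/nvidia/pytorch:25.01-py3
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/hopper
MAX_JOBS=8 python3 setup.py bdist_wheel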

@cameronshinn
Contributor

We're getting good perf with nvcc 12.8 on the nvcr 25.01 image so I recommend using that docker image. [Make sure you pull the latest commit from this repo]

Thanks for the quick update! Will give that a go.

@xuchunmei000
Author

We're getting good perf with nvcc 12.8 on the nvcr 25.01 image so I recommend using that docker image. [Make sure you pull the latest commit from this repo]

Thanks for your quick reply, and sorry for the late verification. With nvcr 25.01, the build compiles successfully at commit 02541ac.

@WissamAntoun

The following Dockerfile isn't working:

FROM nvcr.io/nvidia/pytorch:25.01-py3

WORKDIR /build

RUN apt-get update && apt-get install -y \
    git \
    htop

RUN git clone https://github.com/Dao-AILab/flash-attention.git \
    && cd flash-attention/hopper \
    && git checkout -b 02541ac \
    && MAX_JOBS=8 python3 setup.py install

Here's the build log:

build.log
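One thing worth double-checking: git checkout -b 02541ac creates a new branch named 02541ac at the current HEAD rather than checking out that commit, so the build above is not actually pinned to 02541ac. A variant that pins the commit (an untested sketch, assuming the intent was to build that commit):

git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention
git checkout 02541ac
cd hopper
MAX_JOBS=8 python3 setup.py install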
