diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index ba674beb7..b747c8d60 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -7,7 +7,8 @@ #include #include #include -#include +#include // For at::Generator and at::PhiloxCudaState +#include // For at::cuda::philox::unpack #include diff --git a/csrc/flash_attn/src/flash.h b/csrc/flash_attn/src/flash.h index 8838f59b6..9a503998c 100644 --- a/csrc/flash_attn/src/flash.h +++ b/csrc/flash_attn/src/flash.h @@ -7,7 +7,7 @@ #include #include -#include // For at::cuda::philox::unpack +#include // For at::Generator and at::PhiloxCudaState constexpr int TOTAL_DIM = 0; constexpr int H_DIM = 1; diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index 788f3790e..30217f3f6 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -4,6 +4,8 @@ #pragma once +#include // For at::cuda::philox::unpack + #include #include diff --git a/flash_attn/__init__.py b/flash_attn/__init__.py index 34fdfef70..e3c0e8e04 100644 --- a/flash_attn/__init__.py +++ b/flash_attn/__init__.py @@ -1,4 +1,4 @@ -__version__ = "2.7.1.post2" +__version__ = "2.7.1.post3" from flash_attn.flash_attn_interface import ( flash_attn_func,