From e782d28692e7b6ce6cd8a5f095b211b7510fef9e Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Fri, 6 Dec 2024 21:39:38 -0800 Subject: [PATCH] [CI] Change torch #include to make it work with torch 2.1 Philox --- csrc/flash_attn/flash_api.cpp | 3 ++- csrc/flash_attn/src/flash.h | 2 +- csrc/flash_attn/src/flash_fwd_kernel.h | 2 ++ flash_attn/__init__.py | 2 +- 4 files changed, 6 insertions(+), 3 deletions(-) diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index ba674beb7..b747c8d60 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -7,7 +7,8 @@ #include #include #include -#include +#include // For at::Generator and at::PhiloxCudaState +#include // For at::cuda::philox::unpack #include diff --git a/csrc/flash_attn/src/flash.h b/csrc/flash_attn/src/flash.h index 8838f59b6..9a503998c 100644 --- a/csrc/flash_attn/src/flash.h +++ b/csrc/flash_attn/src/flash.h @@ -7,7 +7,7 @@ #include #include -#include // For at::cuda::philox::unpack +#include // For at::Generator and at::PhiloxCudaState constexpr int TOTAL_DIM = 0; constexpr int H_DIM = 1; diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index 788f3790e..30217f3f6 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -4,6 +4,8 @@ #pragma once +#include // For at::cuda::philox::unpack + #include #include diff --git a/flash_attn/__init__.py b/flash_attn/__init__.py index 34fdfef70..e3c0e8e04 100644 --- a/flash_attn/__init__.py +++ b/flash_attn/__init__.py @@ -1,4 +1,4 @@ -__version__ = "2.7.1.post2" +__version__ = "2.7.1.post3" from flash_attn.flash_attn_interface import ( flash_attn_func,