Skip to content

Commit 1b62e2f

Browse files
authored
Add base framework for FlashAttention unit tests (#295)
This PR adds the foundational framework for unit testing FlashAttention. While it doesn't cover all test scenarios it provides a starting point for validating core functionality and ensuring correctness.
1 parent 34cf56d commit 1b62e2f

File tree

11 files changed

+1084
-4
lines changed

11 files changed

+1084
-4
lines changed
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
/***************************************************************************************************
2+
* Copyright (c) 2025 - 2025 Codeplay Software Ltd. All rights reserved.
3+
* SPDX-License-Identifier: BSD-3-Clause
4+
*
5+
* Redistribution and use in source and binary forms, with or without
6+
* modification, are permitted provided that the following conditions are met:
7+
*
8+
* 1. Redistributions of source code must retain the above copyright notice, this
9+
* list of conditions and the following disclaimer.
10+
*
11+
* 2. Redistributions in binary form must reproduce the above copyright notice,
12+
* this list of conditions and the following disclaimer in the documentation
13+
* and/or other materials provided with the distribution.
14+
*
15+
* 3. Neither the name of the copyright holder nor the names of its
16+
* contributors may be used to endorse or promote products derived from
17+
* this software without specific prior written permission.
18+
*
19+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29+
*
30+
**************************************************************************************************/
31+
#pragma once
32+
33+
#include "cutlass/cutlass.h"
34+
#include <cute/tensor.hpp>
35+
36+
namespace cutlass::fmha::collective {
37+
38+
using namespace cute;
39+
40+
struct VariableLength {
41+
int max_length;
42+
int* cumulative_length = nullptr;
43+
44+
CUTE_HOST_DEVICE operator int() const {
45+
return max_length;
46+
}
47+
};
48+
49+
template<class T> struct is_variable_length : std::false_type {};
50+
template<> struct is_variable_length<VariableLength> : std::true_type {};
51+
template<class T> constexpr bool is_variable_length_v = is_variable_length<T>::value;
52+
53+
template<class Shape, class Idx>
54+
CUTE_HOST_DEVICE
55+
constexpr auto
56+
apply_variable_length(Shape const& shape, Idx const& idx) {
57+
return transform_leaf(shape, [&](auto const& s) {
58+
if constexpr (is_variable_length_v<remove_cvref_t<decltype(s)>>) {
59+
return s.cumulative_length[idx+1] - s.cumulative_length[idx];
60+
}
61+
else {
62+
return s;
63+
}
64+
});
65+
}
66+
67+
template<class Shape, class Coord, class Idx>
68+
CUTE_HOST_DEVICE
69+
constexpr auto
70+
apply_variable_length(Shape const& shape, Coord const& coord, Idx const& idx) {
71+
auto new_shape = apply_variable_length(shape, idx);
72+
auto new_coord = transform_leaf(shape, coord, [&](auto const& s, auto const& c) {
73+
if constexpr (is_variable_length_v<remove_cvref_t<decltype(s)>>) {
74+
return cute::make_tuple(c, s.cumulative_length[idx]);
75+
}
76+
else {
77+
return c;
78+
}
79+
});
80+
return cute::make_tuple(new_shape, new_coord);
81+
}
82+
83+
} // namespace cutlass::fmha::collective
84+
85+
namespace cute {
86+
87+
template<>
88+
struct is_integral<cutlass::fmha::collective::VariableLength> : true_type {};
89+
90+
CUTE_HOST_DEVICE
91+
void print(cutlass::fmha::collective::VariableLength a) {
92+
printf("Varlen<%d, %p>", a.max_length, a.cumulative_length);
93+
}
94+
95+
}

applications/flash_attention_v2/collective/xe_flash_attn_mma.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,9 @@
3636
#include "cute/algorithm/functional.hpp"
3737
#include "cute/atom/mma_atom.hpp"
3838
#include "cute/algorithm/gemm.hpp"
39-
#include "cute/tensor_predicate.hpp"
39+
#include "cutlass/util/packed_stride.hpp"
40+
41+
#include "fmha_fusion.hpp"
4042

4143
/////////////////////////////////////////////////////////////////////////////////////////////////
4244

applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include "cutlass/kernel_hardware_info.hpp"
3737

3838
#include "flash_attention_v2/collective/xe_flash_attn_mma.hpp"
39+
#include "flash_attention_v2/kernel/tile_scheduler.hpp"
3940

4041
namespace cutlass::flash_attention::kernel {
4142

benchmarks/pvc/flash_attention_v2/benchmark_runner.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232

3333
#include "cutlass/epilogue/collective/default_epilogue.hpp"
3434
#include "cutlass/epilogue/fusion/xe_callbacks.hpp"
35-
#include "77_blackwell_fmha/collective/fmha_fusion.hpp"
3635
#include "flash_attention_v2/kernel/tile_scheduler.hpp"
3736
#include "cutlass/gemm/device/gemm_universal_adapter.h"
3837
#include "cutlass/util/packed_stride.hpp"

examples/sycl/06_pvc_flash_attention/pvc_flash_attn_runner.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232

3333
#include "cutlass/epilogue/collective/default_epilogue.hpp"
3434
#include "cutlass/gemm/device/gemm_universal_adapter.h"
35-
#include "77_blackwell_fmha/collective/fmha_fusion.hpp"
35+
#include "flash_attention_v2/collective/fmha_fusion.hpp"
3636
#include "flash_attention_v2/kernel/tile_scheduler.hpp"
3737
#include "cutlass/util/packed_stride.hpp"
3838
#include "flash_attention_v2/kernel/xe_flash_attn_gemm.hpp"

test/unit/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ function(cutlass_test_unit_add_executable NAME)
8585
PRIVATE
8686
${CUTLASS_UNIT_TEST_COMMON_DIR}
8787
${__EXTRA_INCLUDE_DIRS}
88+
${CUTLASS_APPLICATIONS_DIR}
8889
)
8990
if (__WITHOUT_CUDA)
9091
# Avoid CUDA dependencies for host-only unit tests that provide the
@@ -108,6 +109,7 @@ function(cutlass_test_unit_add_executable NAME)
108109
endif()
109110

110111
if(CUTLASS_ENABLE_SYCL)
112+
add_onemkl_to_target(TARGET ${NAME})
111113
add_sycl_to_target(TARGET ${NAME})
112114
endif()
113115

@@ -188,6 +190,7 @@ if (CUTLASS_ENABLE_SYCL)
188190
set(SUBDIRS
189191
cute
190192
gemm
193+
flash_attention
191194
)
192195
else()
193196
set(SUBDIRS
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Copyright (c) 2025 - 2025 Codeplay Software Ltd. All rights reserved.
2+
# SPDX-License-Identifier: BSD-3-Clause
3+
#
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted provided that the following conditions are met:
6+
#
7+
# 1. Redistributions of source code must retain the above copyright notice, this
8+
# list of conditions and the following disclaimer.
9+
#
10+
# 2. Redistributions in binary form must reproduce the above copyright notice,
11+
# this list of conditions and the following disclaimer in the documentation
12+
# and/or other materials provided with the distribution.
13+
#
14+
# 3. Neither the name of the copyright holder nor the names of its
15+
# contributors may be used to endorse or promote products derived from
16+
# this software without specific prior written permission.
17+
#
18+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28+
29+
if(SYCL_INTEL_TARGET)
30+
31+
set(CUTLASS_APPLICATIONS_DIR ${CMAKE_SOURCE_DIR}/applications)
32+
33+
cutlass_test_unit_add_executable(
34+
cutlass_test_unit_flash_attention_xe
35+
xe_flash_attention_bf16.cpp
36+
xe_flash_attention_fp16.cpp
37+
)
38+
39+
add_custom_target(
40+
cutlass_test_unit_flash_attention
41+
DEPENDS
42+
cutlass_test_unit_flash_attention_xe
43+
)
44+
45+
add_custom_target(
46+
test_unit_flash_attention
47+
DEPENDS
48+
test_unit_flash_attention_xe
49+
)
50+
51+
else()
52+
# Dummy targets if not building for Intel
53+
add_custom_target(cutlass_test_unit_flash_attention)
54+
add_custom_target(test_unit_flash_attention)
55+
56+
endif()

0 commit comments

Comments
 (0)