Skip to content

Commit 9f0b460

Browse files
mehdi-goli, t4c1, aacostadiaz
authored
Fix the performance regression for flash attention for 2025.1 release compiler (#327)
This PR is a workaround for a performance regression in the 2025.1 compiler. The commit that fixes the underlying issue (intel/llvm@71ca51f) did not meet the DPC++ 2025.1 cut-off date; it will be included in the 2025.2 release. --------- Co-authored-by: Tadej Ciglarič <[email protected]> Co-authored-by: Alejandro Acosta <[email protected]>
1 parent 6ee9439 commit 9f0b460

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

applications/flash_attention_v2/collective/xe_flash_attn_mma.hpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,11 @@ using namespace cute;
4747

4848
template <typename To_type, typename Engine, typename Layout>
4949
CUTLASS_DEVICE auto convert_type(Tensor<Engine, Layout> const &tensor) {
50-
using From_type = typename Engine::value_type;
51-
constexpr int numel = decltype(size(tensor))::value;
52-
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
53-
auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
54-
return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
50+
using From_type = typename Engine::value_type;
51+
constexpr int numel = decltype(size(tensor))::value;
52+
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
53+
auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
54+
return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
5555
}
5656

5757
////////////////////////////////////////////////////////////////////////////////////////////////////

include/cutlass/numeric_conversion.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -539,7 +539,15 @@ struct NumericConverter<cutlass::bfloat16_t, float, FloatRoundStyle::round_to_ne
539539

540540
// Scalar float -> bfloat16 conversion for the round-to-nearest specialization.
CUTLASS_HOST_DEVICE
static result_type convert(source_type const & s) {
#if defined(__INTEL_LLVM_COMPILER) && (__INTEL_LLVM_COMPILER < 20250200) && defined(__SYCL_DEVICE_ONLY__)
    // Temporary patch to avoid linking in the devicelib fallback unconditionally.
    // Workaround for a performance regression in the 2025.1 Intel compiler:
    // on SYCL device code we lower directly to the SPIR-V bf16 conversion
    // intrinsic and write the raw bits into `storage`, instead of going
    // through static_cast. The upstream fix (intel/llvm@71ca51f) lands in
    // the 2025.2 release, hence the `< 20250200` version guard — this
    // branch can be dropped once 2025.2 is the minimum supported compiler.
    result_type res;
    res.storage=(__spirv_ConvertFToBF16INTEL(s));
    return res;
#else
    // Host code and fixed compilers: plain narrowing conversion.
    return static_cast<cutlass::bfloat16_t>(s);
#endif
}
544552

545553
CUTLASS_HOST_DEVICE

0 commit comments

Comments
 (0)