Skip to content

Commit f57f308

Browse files
committed
spirv APIs
1 parent a107db3 commit f57f308

File tree

9 files changed

+1541
-862
lines changed

9 files changed

+1541
-862
lines changed

cmake/FindDPCPP.cmake

+5-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,11 @@ endif()
6565
# For the intel_gpu_pvc, spir64 and intel_gpu_bmg_g21 SYCL targets, append
# SPIR-V translator options (-Xspirv-translator) that enable SPIR-V extensions.
if("${DPCPP_SYCL_TARGET}" STREQUAL "intel_gpu_pvc" OR
6666
"${DPCPP_SYCL_TARGET}" STREQUAL "spir64" OR
6767
"${DPCPP_SYCL_TARGET}" STREQUAL "intel_gpu_bmg_g21")
68-
list(APPEND DPCPP_FLAGS "-Xspirv-translator;-spirv-ext=+SPV_INTEL_split_barrier")
68+
# IntelLLVM compilers older than 2025.2 only get SPV_INTEL_split_barrier.
if (CMAKE_CXX_COMPILER_ID MATCHES "IntelLLVM" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 2025.2)
69+
list(APPEND DPCPP_FLAGS "-Xspirv-translator;-spirv-ext=+SPV_INTEL_split_barrier")
70+
else()
71+
# Every other case additionally enables SPV_INTEL_2d_block_io and
# SPV_INTEL_subgroup_matrix_multiply_accumulate.
# NOTE(review): this branch is also taken for any compiler whose ID does not
# match "IntelLLVM", regardless of its version -- confirm that is intended.
list(APPEND DPCPP_FLAGS "-Xspirv-translator;-spirv-ext=+SPV_INTEL_split_barrier,+SPV_INTEL_2d_block_io,+SPV_INTEL_subgroup_matrix_multiply_accumulate")
72+
endif()
6973
endif()
7074

7175
if(UNIX)

include/cute/arch/copy_xe.hpp

+1-6
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,6 @@
3333
#include <cute/arch/xe_copy_2B.hpp>
3434
#include <cute/arch/xe_copy_4B.hpp>
3535
#include <cute/arch/xe_copy_8B.hpp>
36-
// SYCL_DEVICE_BUILTIN(x) wraps the declaration `x` of a device built-in.
// (These lines carry '-' diff markers, i.e. they are on the removed side of
// this commit.)
#ifdef __SYCL_DEVICE_ONLY__
37-
// Device compilation: `x` becomes an extern "C" SYCL_EXTERNAL declaration.
#define SYCL_DEVICE_BUILTIN(x) SYCL_EXTERNAL extern "C" x
38-
#else
39-
// Host compilation: `x` becomes an inline stub whose body is assert(false).
#define SYCL_DEVICE_BUILTIN(x) inline x { assert(false); }
40-
#endif
4136

4237
// prefetch
4338
SYCL_DEVICE_BUILTIN(void __builtin_IB_lsc_prefetch_global_uchar(
@@ -70,7 +65,7 @@ SYCL_DEVICE_BUILTIN(void __builtin_IB_lsc_prefetch_global_ulong4(
7065
// Declares the ulong8 global prefetch built-in through SYCL_DEVICE_BUILTIN.
// Parameters: `base` is a pointer to const uint64_t in the opencl_global
// address space, `immElemOff` is an int offset (presumably an immediate
// element offset -- TODO confirm), and `cacheOpt` selects a CacheControl value.
SYCL_DEVICE_BUILTIN(void __builtin_IB_lsc_prefetch_global_ulong8(
7166
const __attribute__((opencl_global)) uint64_t *base, int immElemOff,
7267
enum CacheControl cacheOpt));
73-
#undef SYCL_DEVICE_BUILTIN
68+
7469

7570
#ifdef __SYCL_DEVICE_ONLY__
7671
// Device-only declaration of __spirv_ControlBarrierWaitINTEL, marked
// convergent. It takes the execution scope, memory scope and memory semantics
// as plain ints. Presumably the "wait" half of the SPV_INTEL_split_barrier
// extension -- TODO confirm against the extension specification.
SYCL_EXTERNAL __attribute__((convergent)) void __spirv_ControlBarrierWaitINTEL(int execution_scope, int memory_scope, int memory_semantics);

include/cute/arch/mma_xe.hpp

+160-48
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,7 @@
3232

3333
#include <cute/config.hpp>
3434
#include <cute/arch/mma.hpp>
35-
#include <cute/util/sycl_vec.hpp>
36-
37-
// SYCL_DEVICE_OCL(x) wraps the declaration `x` of an OpenCL-style built-in.
// (These lines carry '-' diff markers, i.e. they are on the removed side of
// this commit.)
#ifdef __SYCL_DEVICE_ONLY__
38-
// Device compilation: `x` is declared SYCL_EXTERNAL.
#define SYCL_DEVICE_OCL(x) SYCL_EXTERNAL x
39-
#else
40-
// Host compilation: `x` becomes an inline stub that invokes
// CUTE_INVALID_CONTROL_PATH with a fixed message.
#define SYCL_DEVICE_OCL(x) inline x { CUTE_INVALID_CONTROL_PATH("Trying to use XE built-in on non-XE hardware"); }
41-
#endif
35+
#include <cute/arch/xe_config.hpp>
4236

4337
// mma_bf16
4438
// Overload of the bf16 x bf16 k16 matrix-mad built-in: takes a short8 `a`,
// an int8 `b` and a float8 accumulator `acc`, and returns a float8.
// NOTE(review): SYCL_DEVICE_OCL is no longer defined in this file after this
// commit; presumably it now comes from <cute/arch/xe_config.hpp> -- confirm.
SYCL_DEVICE_OCL(cute::intel::float8 intel_sub_group_bf16_bf16_matrix_mad_k16(cute::intel::short8 a, cute::intel::int8 b, cute::intel::float8 acc));
@@ -66,7 +60,125 @@ SYCL_DEVICE_OCL(cute::intel::float4 intel_sub_group_tf32_tf32_matrix_mad_k8(cute
6660
// Two overloads of the tf32 x tf32 k8 matrix-mad built-in. Both take a scalar
// float `a` and a float8 `b`; they differ in the accumulator/result type
// (float2 for the first, scalar float for the second).
SYCL_DEVICE_OCL(cute::intel::float2 intel_sub_group_tf32_tf32_matrix_mad_k8(float a, cute::intel::float8 b, cute::intel::float2 acc));
6761
SYCL_DEVICE_OCL(float intel_sub_group_tf32_tf32_matrix_mad_k8(float a, cute::intel::float8 b, float acc));
6862

69-
#undef SYCL_DEVICE_OCL
63+
// SPIR-V code path, compiled only when CUTE_ARCH_MMA_XE_SPIRV_ENABLED is
// defined. Each explicit specialization of XeSubgroupMatrixMultiplyAccumulate
// below forwards its (a, b, c) register arguments to
// __spirv_SubgroupMatrixMultiplyAccumulateINTEL together with an integer
// constant and a set of SPIRV_MMAOperands flags.
//
// The primary template, SPIRV_MMAOperands and the __spirv_* declaration are not
// defined in this block (presumably provided via <cute/arch/xe_config.hpp> --
// TODO confirm). The four template arguments presumably follow the D, A, B, C
// element-type order used by the atom names -- TODO confirm.
//
// NOTE(review): every operator() only has a return statement under
// __SYCL_DEVICE_ONLY__. In any other compilation pass the body is empty, so
// the `auto` return type deduces to void. Callers that assign the result must
// therefore only be compiled for device -- confirm that the enabling macros
// guarantee this.
#if defined(CUTE_ARCH_MMA_XE_SPIRV_ENABLED)
64+
namespace cute::detail
65+
{
66+
// bf16 inputs, float accumulator: constant 16 (presumably the K extent, cf.
// the *_k16 built-in name -- TODO confirm), flags MatrixABf16 | MatrixBBf16.
template<>
67+
struct XeSubgroupMatrixMultiplyAccumulate<float, bfloat16_t, bfloat16_t, float> {
68+
template<typename ARegisters, typename BRegisters, typename CRegisters>
69+
CUTE_HOST_DEVICE
70+
auto operator()(ARegisters a, BRegisters b, CRegisters c) {
71+
#ifdef __SYCL_DEVICE_ONLY__
72+
return __spirv_SubgroupMatrixMultiplyAccumulateINTEL(16, a, b, c, SPIRV_MMAOperands::SPIRV_MatrixABf16 | SPIRV_MMAOperands::SPIRV_MatrixBBf16 );
73+
#endif
74+
}
75+
};
76+
77+
// fp16 inputs, float accumulator: constant 16, flags MatrixAFp16 | MatrixBFp16.
template<>
78+
struct XeSubgroupMatrixMultiplyAccumulate<float, half_t, half_t, float> {
79+
template<typename ARegisters, typename BRegisters, typename CRegisters>
80+
CUTE_HOST_DEVICE
81+
auto operator()(ARegisters a, BRegisters b, CRegisters c) {
82+
#ifdef __SYCL_DEVICE_ONLY__
83+
return __spirv_SubgroupMatrixMultiplyAccumulateINTEL(16, a, b, c, SPIRV_MMAOperands::SPIRV_MatrixAFp16 | SPIRV_MMAOperands::SPIRV_MatrixBFp16);
84+
#endif
85+
}
86+
};
87+
88+
// Signed int8 inputs, int32 accumulator: constant 32, and both the Signed and
// the Int8 flags are set for A and for B.
template<>
89+
struct XeSubgroupMatrixMultiplyAccumulate<int32_t, int8_t, int8_t, int32_t> {
90+
template<typename ARegisters, typename BRegisters, typename CRegisters>
91+
CUTE_HOST_DEVICE
92+
auto operator()(ARegisters a, BRegisters b, CRegisters c) {
93+
#ifdef __SYCL_DEVICE_ONLY__
94+
return __spirv_SubgroupMatrixMultiplyAccumulateINTEL(32, a, b, c, SPIRV_MMAOperands::SPIRV_MatrixASigned | SPIRV_MMAOperands::SPIRV_MatrixBSigned | SPIRV_MMAOperands::SPIRV_MatrixAInt8 | SPIRV_MMAOperands::SPIRV_MatrixBInt8);
95+
#endif
96+
}
97+
};
98+
99+
// Unsigned int8 inputs, int32 accumulator: constant 32, only the Int8 flags
// are set (no Signed flags, unlike the int8_t specialization).
template<>
100+
struct XeSubgroupMatrixMultiplyAccumulate<int32_t, uint8_t, uint8_t, int32_t> {
101+
template<typename ARegisters, typename BRegisters, typename CRegisters>
102+
CUTE_HOST_DEVICE
103+
auto operator()(ARegisters a, BRegisters b, CRegisters c) {
104+
#ifdef __SYCL_DEVICE_ONLY__
105+
return __spirv_SubgroupMatrixMultiplyAccumulateINTEL(32, a, b, c, SPIRV_MMAOperands::SPIRV_MatrixAInt8 | SPIRV_MMAOperands::SPIRV_MatrixBInt8);
106+
#endif
107+
}
108+
};
109+
110+
// tf32 inputs, float accumulator: constant 8, flags MatrixATf32 | MatrixBTf32.
template<>
111+
struct XeSubgroupMatrixMultiplyAccumulate<float, tfloat32_t, tfloat32_t, float> {
112+
template<typename ARegisters, typename BRegisters, typename CRegisters>
113+
CUTE_HOST_DEVICE
114+
auto operator()(ARegisters a, BRegisters b, CRegisters c) {
115+
#ifdef __SYCL_DEVICE_ONLY__
116+
return __spirv_SubgroupMatrixMultiplyAccumulateINTEL(8, a, b, c, SPIRV_MMAOperands::SPIRV_MatrixATf32 | SPIRV_MMAOperands::SPIRV_MatrixBTf32);
117+
#endif
118+
}
119+
};
120+
} // namespace cute::detail end
121+
#endif
122+
123+
// Built-in code path, compiled only when CUTE_ARCH_MMA_XE_BUILTIN_ENABLED is
// defined. Each explicit specialization of XeSubgroupMatrixMultiplyAccumulate
// below forwards its (a, b, c) register arguments unchanged to the matching
// intel_sub_group_*_matrix_mad_k* built-in; the concrete overload is chosen by
// the argument types the caller passes.
//
// The primary template is not defined in this block (presumably provided via
// <cute/arch/xe_config.hpp> -- TODO confirm).
//
// NOTE(review): every operator() only has a return statement under
// __SYCL_DEVICE_ONLY__. In any other compilation pass the body is empty, so
// the `auto` return type deduces to void. Callers that assign the result must
// therefore only be compiled for device -- confirm that the enabling macros
// guarantee this.
#if defined(CUTE_ARCH_MMA_XE_BUILTIN_ENABLED)
124+
namespace cute::detail
125+
{
126+
// bf16 inputs, float accumulator -> intel_sub_group_bf16_bf16_matrix_mad_k16.
template<>
127+
struct XeSubgroupMatrixMultiplyAccumulate<float, bfloat16_t, bfloat16_t, float> {
128+
template<typename ARegisters, typename BRegisters, typename CRegisters>
129+
CUTE_HOST_DEVICE
130+
auto operator()(ARegisters a, BRegisters b, CRegisters c) {
131+
#ifdef __SYCL_DEVICE_ONLY__
132+
return intel_sub_group_bf16_bf16_matrix_mad_k16(a, b, c);
133+
#endif
134+
}
135+
};
136+
137+
// fp16 inputs, float accumulator -> intel_sub_group_f16_f16_matrix_mad_k16.
template<>
138+
struct XeSubgroupMatrixMultiplyAccumulate<float, half_t, half_t, float> {
139+
template<typename ARegisters, typename BRegisters, typename CRegisters>
140+
CUTE_HOST_DEVICE
141+
auto operator()(ARegisters a, BRegisters b, CRegisters c) {
142+
#ifdef __SYCL_DEVICE_ONLY__
143+
return intel_sub_group_f16_f16_matrix_mad_k16(a, b, c);
144+
#endif
145+
}
146+
};
147+
148+
// Signed int8 inputs, int32 accumulator -> intel_sub_group_i8_i8_matrix_mad_k32.
template<>
149+
struct XeSubgroupMatrixMultiplyAccumulate<int32_t, int8_t, int8_t, int32_t> {
150+
template<typename ARegisters, typename BRegisters, typename CRegisters>
151+
CUTE_HOST_DEVICE
152+
auto operator()(ARegisters a, BRegisters b, CRegisters c) {
153+
#ifdef __SYCL_DEVICE_ONLY__
154+
return intel_sub_group_i8_i8_matrix_mad_k32(a, b, c);
155+
#endif
156+
}
157+
};
158+
159+
// Unsigned int8 inputs, int32 accumulator -> intel_sub_group_u8_u8_matrix_mad_k32.
template<>
160+
struct XeSubgroupMatrixMultiplyAccumulate<int32_t, uint8_t, uint8_t, int32_t> {
161+
template<typename ARegisters, typename BRegisters, typename CRegisters>
162+
CUTE_HOST_DEVICE
163+
auto operator()(ARegisters a, BRegisters b, CRegisters c) {
164+
#ifdef __SYCL_DEVICE_ONLY__
165+
return intel_sub_group_u8_u8_matrix_mad_k32(a, b, c);
166+
#endif
167+
}
168+
};
169+
170+
// tf32 inputs, float accumulator -> intel_sub_group_tf32_tf32_matrix_mad_k8.
template<>
171+
struct XeSubgroupMatrixMultiplyAccumulate<float, tfloat32_t, tfloat32_t, float> {
172+
template<typename ARegisters, typename BRegisters, typename CRegisters>
173+
CUTE_HOST_DEVICE
174+
auto operator()(ARegisters a, BRegisters b, CRegisters c) {
175+
#ifdef __SYCL_DEVICE_ONLY__
176+
return intel_sub_group_tf32_tf32_matrix_mad_k8(a, b, c);
177+
#endif
178+
}
179+
};
180+
} // namespace cute::detail end
181+
#endif
70182

71183
namespace cute {
72184
//MxNxK_D,A,B,C
@@ -86,8 +198,8 @@ struct XE_8x16x16_F32BF16BF16F32_TT
86198
intel::int8 const& b,
87199
intel::float8 const& c)
88200
{
89-
#if defined(SYCL_INTEL_TARGET)
90-
d = intel_sub_group_bf16_bf16_matrix_mad_k16(a, b, c);
201+
#if defined(CUTE_ARCH_MMA_XE_ENABLED)
202+
d = detail::XeSubgroupMatrixMultiplyAccumulate<float, bfloat16_t, bfloat16_t, float>{}(a, b, c);
91203
#else
92204
CUTE_INVALID_CONTROL_PATH("Attempting to use XE_8x16x16_F32BF16BF16F32_TT on non-PVC hardware");
93205
#endif
@@ -106,8 +218,8 @@ struct XE_4x16x16_F32BF16BF16F32_TT
106218
intel::int8 const& b,
107219
intel::float4 const& c)
108220
{
109-
#if defined(SYCL_INTEL_TARGET)
110-
d = intel_sub_group_bf16_bf16_matrix_mad_k16(a, b, c);
221+
// Dispatch for the XE_4x16x16_F32BF16BF16F32_TT atom: when
// CUTE_ARCH_MMA_XE_ENABLED is defined, compute d through the bf16
// XeSubgroupMatrixMultiplyAccumulate specialization; otherwise report an
// invalid control path. The message names this atom (it previously named the
// 8x16x16 atom, a copy-paste error that made the diagnostic misleading).
#if defined(CUTE_ARCH_MMA_XE_ENABLED)
222+
d = detail::XeSubgroupMatrixMultiplyAccumulate<float, bfloat16_t, bfloat16_t, float>{}(a, b, c);
111223
#else
112224
CUTE_INVALID_CONTROL_PATH("Attempting to use XE_4x16x16_F32BF16BF16F32_TT on non-PVC hardware");
113225
#endif
@@ -126,8 +238,8 @@ struct XE_2x16x16_F32BF16BF16F32_TT
126238
intel::int8 const& b,
127239
intel::float2 const& c)
128240
{
129-
#if defined(SYCL_INTEL_TARGET)
130-
d = intel_sub_group_bf16_bf16_matrix_mad_k16(a, b, c);
241+
// Dispatch for the XE_2x16x16_F32BF16BF16F32_TT atom: when
// CUTE_ARCH_MMA_XE_ENABLED is defined, compute d through the bf16
// XeSubgroupMatrixMultiplyAccumulate specialization; otherwise report an
// invalid control path. The message names this atom (it previously named the
// 8x16x16 atom, a copy-paste error that made the diagnostic misleading).
#if defined(CUTE_ARCH_MMA_XE_ENABLED)
242+
d = detail::XeSubgroupMatrixMultiplyAccumulate<float, bfloat16_t, bfloat16_t, float>{}(a, b, c);
131243
#else
132244
CUTE_INVALID_CONTROL_PATH("Attempting to use XE_2x16x16_F32BF16BF16F32_TT on non-PVC hardware");
133245
#endif
@@ -147,8 +259,8 @@ struct XE_1x16x16_F32BF16BF16F32_TT
147259
intel::int8 const& b,
148260
float const& c)
149261
{
150-
#if defined(SYCL_INTEL_TARGET)
151-
d = intel_sub_group_bf16_bf16_matrix_mad_k16(a, b, c);
262+
#if defined(CUTE_ARCH_MMA_XE_ENABLED)
263+
d = detail::XeSubgroupMatrixMultiplyAccumulate<float, bfloat16_t, bfloat16_t, float>{}(a, b, c);
152264
#else
153265
CUTE_INVALID_CONTROL_PATH("Attempting to use XE_1x16x16_F32BF16BF16F32_TT on non-PVC hardware");
154266
#endif
@@ -172,8 +284,8 @@ struct XE_8x16x16_F32F16F16F32_TT
172284
intel::int8 const& b,
173285
intel::float8 const& c)
174286
{
175-
#if defined(SYCL_INTEL_TARGET)
176-
d = intel_sub_group_f16_f16_matrix_mad_k16(a, b, c);
287+
#if defined(CUTE_ARCH_MMA_XE_ENABLED)
288+
d = detail::XeSubgroupMatrixMultiplyAccumulate<float, half_t, half_t, float>{}(a, b, c);
177289
#else
178290
CUTE_INVALID_CONTROL_PATH("Attempting to use XE_8x16x16_F32F16F16F32_TT on non-PVC hardware");
179291
#endif
@@ -193,8 +305,8 @@ struct XE_4x16x16_F32F16F16F32_TT
193305
intel::int8 const& b,
194306
intel::float4 const& c)
195307
{
196-
#if defined(SYCL_INTEL_TARGET)
197-
d = intel_sub_group_f16_f16_matrix_mad_k16(a, b, c);
308+
#if defined(CUTE_ARCH_MMA_XE_ENABLED)
309+
d = detail::XeSubgroupMatrixMultiplyAccumulate<float, half_t, half_t, float>{}(a, b, c);
198310
#else
199311
CUTE_INVALID_CONTROL_PATH("Attempting to use XE_4x16x16_F32F16F16F32_TT on non-PVC hardware");
200312
#endif
@@ -214,8 +326,8 @@ struct XE_2x16x16_F32F16F16F32_TT
214326
intel::int8 const& b,
215327
intel::float2 const& c)
216328
{
217-
#if defined(SYCL_INTEL_TARGET)
218-
d = intel_sub_group_f16_f16_matrix_mad_k16(a, b, c);
329+
#if defined(CUTE_ARCH_MMA_XE_ENABLED)
330+
d = detail::XeSubgroupMatrixMultiplyAccumulate<float, half_t, half_t, float>{}(a, b, c);
219331
#else
220332
CUTE_INVALID_CONTROL_PATH("Attempting to use XE_2x16x16_F32F16F16F32_TT on non-PVC hardware");
221333
#endif
@@ -235,8 +347,8 @@ struct XE_1x16x16_F32F16F16F32_TT
235347
intel::int8 const& b,
236348
float const& c)
237349
{
238-
#if defined(SYCL_INTEL_TARGET)
239-
d = intel_sub_group_f16_f16_matrix_mad_k16(a, b, c);
350+
#if defined(CUTE_ARCH_MMA_XE_ENABLED)
351+
d = detail::XeSubgroupMatrixMultiplyAccumulate<float, half_t, half_t, float>{}(a, b, c);
240352
#else
241353
CUTE_INVALID_CONTROL_PATH("Attempting to use XE_1x16x16_F32F16F16F32_TT on non-PVC hardware");
242354
#endif
@@ -260,8 +372,8 @@ struct XE_8x16x32_S32S8S8S32_TT
260372
intel::int8 const& b,
261373
intel::int8 const& c)
262374
{
263-
#if defined(SYCL_INTEL_TARGET)
264-
d = intel_sub_group_i8_i8_matrix_mad_k32(a, b, c);
375+
#if defined(CUTE_ARCH_MMA_XE_ENABLED)
376+
d = detail::XeSubgroupMatrixMultiplyAccumulate<int32_t, int8_t, int8_t, int32_t>{}(a, b, c);
265377
#else
266378
CUTE_INVALID_CONTROL_PATH("Attempting to use XE_8x16x32_S32S8S8S32_TT on non-PVC hardware");
267379
#endif
@@ -281,8 +393,8 @@ struct XE_4x16x32_S32S8S8S32_TT
281393
intel::int8 const& b,
282394
intel::int4 const& c)
283395
{
284-
#if defined(SYCL_INTEL_TARGET)
285-
d = intel_sub_group_i8_i8_matrix_mad_k32(a, b, c);
396+
#if defined(CUTE_ARCH_MMA_XE_ENABLED)
397+
d = detail::XeSubgroupMatrixMultiplyAccumulate<int32_t, int8_t, int8_t, int32_t>{}(a, b, c);
286398
#else
287399
CUTE_INVALID_CONTROL_PATH("Attempting to use XE_4x16x32_S32S8S8S32_TT on non-PVC hardware");
288400
#endif
@@ -302,8 +414,8 @@ struct XE_2x16x32_S32S8S8S32_TT
302414
intel::int8 const& b,
303415
intel::int2 const& c)
304416
{
305-
#if defined(SYCL_INTEL_TARGET)
306-
d = intel_sub_group_i8_i8_matrix_mad_k32(a, b, c);
417+
#if defined(CUTE_ARCH_MMA_XE_ENABLED)
418+
d = detail::XeSubgroupMatrixMultiplyAccumulate<int32_t, int8_t, int8_t, int32_t>{}(a, b, c);
307419
#else
308420
CUTE_INVALID_CONTROL_PATH("Attempting to use XE_2x16x32_S32S8S8S32_TT on non-PVC hardware");
309421
#endif
@@ -323,8 +435,8 @@ struct XE_1x16x32_S32S8S8S32_TT
323435
intel::int8 const& b,
324436
int const& c)
325437
{
326-
#if defined(SYCL_INTEL_TARGET)
327-
d = intel_sub_group_i8_i8_matrix_mad_k32(a, b, c);
438+
#if defined(CUTE_ARCH_MMA_XE_ENABLED)
439+
d = detail::XeSubgroupMatrixMultiplyAccumulate<int32_t, int8_t, int8_t, int32_t>{}(a, b, c);
328440
#else
329441
CUTE_INVALID_CONTROL_PATH("Attempting to use XE_1x16x32_S32S8S8S32_TT on non-PVC hardware");
330442
#endif
@@ -344,8 +456,8 @@ struct XE_8x16x32_S32U8U8S32_TT
344456
intel::uint8 const& b,
345457
intel::int8 const& c)
346458
{
347-
#if defined(SYCL_INTEL_TARGET)
348-
d = intel_sub_group_u8_u8_matrix_mad_k32(a, b, c);
459+
#if defined(CUTE_ARCH_MMA_XE_ENABLED)
460+
d = detail::XeSubgroupMatrixMultiplyAccumulate<int32_t, uint8_t, uint8_t, int32_t>{}(a, b, c);
349461
#else
350462
CUTE_INVALID_CONTROL_PATH("Attempting to use XE_8x16x32_S32U8U8S32_TT on non-PVC hardware");
351463
#endif
@@ -365,8 +477,8 @@ struct XE_4x16x32_S32U8U8S32_TT
365477
intel::uint8 const& b,
366478
intel::int4 const& c)
367479
{
368-
#if defined(SYCL_INTEL_TARGET)
369-
d = intel_sub_group_u8_u8_matrix_mad_k32(a, b, c);
480+
#if defined(CUTE_ARCH_MMA_XE_ENABLED)
481+
d = detail::XeSubgroupMatrixMultiplyAccumulate<int32_t, uint8_t, uint8_t, int32_t>{}(a, b, c);
370482
#else
371483
CUTE_INVALID_CONTROL_PATH("Attempting to use XE_4x16x32_S32U8U8S32_TT on non-PVC hardware");
372484
#endif
@@ -386,8 +498,8 @@ struct XE_2x16x32_S32U8U8S32_TT
386498
intel::uint8 const& b,
387499
intel::int2 const& c)
388500
{
389-
#if defined(SYCL_INTEL_TARGET)
390-
d = intel_sub_group_u8_u8_matrix_mad_k32(a, b, c);
501+
#if defined(CUTE_ARCH_MMA_XE_ENABLED)
502+
d = detail::XeSubgroupMatrixMultiplyAccumulate<int32_t, uint8_t, uint8_t, int32_t>{}(a, b, c);
391503
#else
392504
CUTE_INVALID_CONTROL_PATH("Attempting to use XE_2x16x32_S32U8U8S32_TT on non-PVC hardware");
393505
#endif
@@ -407,8 +519,8 @@ struct XE_1x16x32_S32U8U8S32_TT
407519
intel::uint8 const& b,
408520
int const& c)
409521
{
410-
#if defined(SYCL_INTEL_TARGET)
411-
d = intel_sub_group_u8_u8_matrix_mad_k32(a, b, c);
522+
#if defined(CUTE_ARCH_MMA_XE_ENABLED)
523+
d = detail::XeSubgroupMatrixMultiplyAccumulate<int32_t, uint8_t, uint8_t, int32_t>{}(a, b, c);
412524
#else
413525
CUTE_INVALID_CONTROL_PATH("Attempting to use XE_1x16x32_S32U8U8S32_TT on non-PVC hardware");
414526
#endif
@@ -428,8 +540,8 @@ struct XE_8x16x8_F32TF32TF32F32_TT
428540
intel::float8 const& b,
429541
intel::float8 const& c)
430542
{
431-
#if defined(SYCL_INTEL_TARGET)
432-
d = intel_sub_group_tf32_tf32_matrix_mad_k8(a, b, c);
543+
#if defined(CUTE_ARCH_MMA_XE_ENABLED)
544+
d = detail::XeSubgroupMatrixMultiplyAccumulate<float, tfloat32_t, tfloat32_t, float>{}(a, b, c);
433545
#else
434546
CUTE_INVALID_CONTROL_PATH("Attempting to use XE_8x16x8_F32TF32TF32F32_TT on non-PVC hardware");
435547
#endif
@@ -449,8 +561,8 @@ struct XE_4x16x8_F32TF32TF32F32_TT
449561
intel::float8 const& b,
450562
intel::float4 const& c)
451563
{
452-
#if defined(SYCL_INTEL_TARGET)
453-
d = intel_sub_group_tf32_tf32_matrix_mad_k8(a, b, c);
564+
#if defined(CUTE_ARCH_MMA_XE_ENABLED)
565+
d = detail::XeSubgroupMatrixMultiplyAccumulate<float, tfloat32_t, tfloat32_t, float>{}(a, b, c);
454566
#else
455567
CUTE_INVALID_CONTROL_PATH("Attempting to use XE_4x16x8_F32TF32TF32F32_TT on non-PVC hardware");
456568
#endif
@@ -470,8 +582,8 @@ struct XE_2x16x8_F32TF32TF32F32_TT
470582
intel::float8 const& b,
471583
intel::float2 const& c)
472584
{
473-
#if defined(SYCL_INTEL_TARGET)
474-
d = intel_sub_group_tf32_tf32_matrix_mad_k8(a, b, c);
585+
#if defined(CUTE_ARCH_MMA_XE_ENABLED)
586+
d = detail::XeSubgroupMatrixMultiplyAccumulate<float, tfloat32_t, tfloat32_t, float>{}(a, b, c);
475587
#else
476588
CUTE_INVALID_CONTROL_PATH("Attempting to use XE_2x16x8_F32TF32TF32F32_TT on non-PVC hardware");
477589
#endif
@@ -491,8 +603,8 @@ struct XE_1x16x8_F32TF32TF32F32_TT
491603
intel::float8 const& b,
492604
float const& c)
493605
{
494-
#if defined(SYCL_INTEL_TARGET)
495-
d = intel_sub_group_tf32_tf32_matrix_mad_k8(a, b, c);
606+
#if defined(CUTE_ARCH_MMA_XE_ENABLED)
607+
d = detail::XeSubgroupMatrixMultiplyAccumulate<float, tfloat32_t, tfloat32_t, float>{}(a, b, c);
496608
#else
497609
CUTE_INVALID_CONTROL_PATH("Attempting to use XE_1x16x8_F32TF32TF32F32_TT on non-PVC hardware");
498610
#endif

0 commit comments

Comments
 (0)