32
32
33
33
#include <cute/config.hpp>

#include <cute/arch/mma.hpp>
// Provides SYCL_DEVICE_OCL and the CUTE_ARCH_MMA_XE_* feature macros
// (the macro was previously defined locally in this header — TODO confirm
// xe_config.hpp is the intended single definition point).
#include <cute/arch/xe_config.hpp>

// mma_bf16
//
// SYCL_DEVICE_OCL declares an Xe subgroup-MMA builtin: on device builds it
// is an external OpenCL builtin; on non-Xe builds calling it is invalid.
SYCL_DEVICE_OCL(cute::intel::float8 intel_sub_group_bf16_bf16_matrix_mad_k16(cute::intel::short8 a, cute::intel::int8 b, cute::intel::float8 acc));
@@ -66,7 +60,125 @@ SYCL_DEVICE_OCL(cute::intel::float4 intel_sub_group_tf32_tf32_matrix_mad_k8(cute
66
60
SYCL_DEVICE_OCL (cute::intel::float2 intel_sub_group_tf32_tf32_matrix_mad_k8 (float a, cute::intel::float8 b, cute::intel::float2 acc));
67
61
SYCL_DEVICE_OCL (float intel_sub_group_tf32_tf32_matrix_mad_k8 (float a, cute::intel::float8 b, float acc));
68
62
69
- #undef SYCL_DEVICE_OCL
63
#if defined(CUTE_ARCH_MMA_XE_SPIRV_ENABLED)
namespace cute::detail
{
// Lower the Xe subgroup MMA onto __spirv_SubgroupMatrixMultiplyAccumulateINTEL.
// First intrinsic argument is K (reduction depth); last is the operand-type
// flag mask.  The body is only emitted for device compilation — on the host
// pass the body is empty (deduced void return); callers guard dispatch with
// CUTE_ARCH_MMA_XE_ENABLED.

// D = f32, A/B = bf16, C = f32, K = 16
template <>
struct XeSubgroupMatrixMultiplyAccumulate<float, bfloat16_t, bfloat16_t, float> {
  template <typename ARegisters, typename BRegisters, typename CRegisters>
  CUTE_HOST_DEVICE
  auto operator()(ARegisters a, BRegisters b, CRegisters c) {
#ifdef __SYCL_DEVICE_ONLY__
    return __spirv_SubgroupMatrixMultiplyAccumulateINTEL(16, a, b, c, SPIRV_MMAOperands::SPIRV_MatrixABf16 | SPIRV_MMAOperands::SPIRV_MatrixBBf16);
#endif
  }
};

// D = f32, A/B = f16, C = f32, K = 16
template <>
struct XeSubgroupMatrixMultiplyAccumulate<float, half_t, half_t, float> {
  template <typename ARegisters, typename BRegisters, typename CRegisters>
  CUTE_HOST_DEVICE
  auto operator()(ARegisters a, BRegisters b, CRegisters c) {
#ifdef __SYCL_DEVICE_ONLY__
    return __spirv_SubgroupMatrixMultiplyAccumulateINTEL(16, a, b, c, SPIRV_MMAOperands::SPIRV_MatrixAFp16 | SPIRV_MMAOperands::SPIRV_MatrixBFp16);
#endif
  }
};

// D = s32, A/B = s8, C = s32, K = 32 — signed flags mark s8 operands.
template <>
struct XeSubgroupMatrixMultiplyAccumulate<int32_t, int8_t, int8_t, int32_t> {
  template <typename ARegisters, typename BRegisters, typename CRegisters>
  CUTE_HOST_DEVICE
  auto operator()(ARegisters a, BRegisters b, CRegisters c) {
#ifdef __SYCL_DEVICE_ONLY__
    return __spirv_SubgroupMatrixMultiplyAccumulateINTEL(32, a, b, c, SPIRV_MMAOperands::SPIRV_MatrixASigned | SPIRV_MMAOperands::SPIRV_MatrixBSigned | SPIRV_MMAOperands::SPIRV_MatrixAInt8 | SPIRV_MMAOperands::SPIRV_MatrixBInt8);
#endif
  }
};

// D = s32, A/B = u8, C = s32, K = 32 — no signed flags for u8.
template <>
struct XeSubgroupMatrixMultiplyAccumulate<int32_t, uint8_t, uint8_t, int32_t> {
  template <typename ARegisters, typename BRegisters, typename CRegisters>
  CUTE_HOST_DEVICE
  auto operator()(ARegisters a, BRegisters b, CRegisters c) {
#ifdef __SYCL_DEVICE_ONLY__
    return __spirv_SubgroupMatrixMultiplyAccumulateINTEL(32, a, b, c, SPIRV_MMAOperands::SPIRV_MatrixAInt8 | SPIRV_MMAOperands::SPIRV_MatrixBInt8);
#endif
  }
};

// D = f32, A/B = tf32, C = f32, K = 8
template <>
struct XeSubgroupMatrixMultiplyAccumulate<float, tfloat32_t, tfloat32_t, float> {
  template <typename ARegisters, typename BRegisters, typename CRegisters>
  CUTE_HOST_DEVICE
  auto operator()(ARegisters a, BRegisters b, CRegisters c) {
#ifdef __SYCL_DEVICE_ONLY__
    return __spirv_SubgroupMatrixMultiplyAccumulateINTEL(8, a, b, c, SPIRV_MMAOperands::SPIRV_MatrixATf32 | SPIRV_MMAOperands::SPIRV_MatrixBTf32);
#endif
  }
};
} // namespace cute::detail
#endif
122
+
123
#if defined(CUTE_ARCH_MMA_XE_BUILTIN_ENABLED)
namespace cute::detail
{
// Lower the Xe subgroup MMA onto the intel_sub_group_*_matrix_mad_* OpenCL
// builtins instead of the SPIR-V intrinsic.  The builtin encodes the operand
// types and K in its name, so each (D, A, B, C) combination maps to exactly
// one builtin.  Device-only body, as in the SPIR-V path.

// D = f32, A/B = bf16, C = f32, K = 16
template <>
struct XeSubgroupMatrixMultiplyAccumulate<float, bfloat16_t, bfloat16_t, float> {
  template <typename ARegisters, typename BRegisters, typename CRegisters>
  CUTE_HOST_DEVICE
  auto operator()(ARegisters a, BRegisters b, CRegisters c) {
#ifdef __SYCL_DEVICE_ONLY__
    return intel_sub_group_bf16_bf16_matrix_mad_k16(a, b, c);
#endif
  }
};

// D = f32, A/B = f16, C = f32, K = 16
template <>
struct XeSubgroupMatrixMultiplyAccumulate<float, half_t, half_t, float> {
  template <typename ARegisters, typename BRegisters, typename CRegisters>
  CUTE_HOST_DEVICE
  auto operator()(ARegisters a, BRegisters b, CRegisters c) {
#ifdef __SYCL_DEVICE_ONLY__
    return intel_sub_group_f16_f16_matrix_mad_k16(a, b, c);
#endif
  }
};

// D = s32, A/B = s8, C = s32, K = 32
template <>
struct XeSubgroupMatrixMultiplyAccumulate<int32_t, int8_t, int8_t, int32_t> {
  template <typename ARegisters, typename BRegisters, typename CRegisters>
  CUTE_HOST_DEVICE
  auto operator()(ARegisters a, BRegisters b, CRegisters c) {
#ifdef __SYCL_DEVICE_ONLY__
    return intel_sub_group_i8_i8_matrix_mad_k32(a, b, c);
#endif
  }
};

// D = s32, A/B = u8, C = s32, K = 32
template <>
struct XeSubgroupMatrixMultiplyAccumulate<int32_t, uint8_t, uint8_t, int32_t> {
  template <typename ARegisters, typename BRegisters, typename CRegisters>
  CUTE_HOST_DEVICE
  auto operator()(ARegisters a, BRegisters b, CRegisters c) {
#ifdef __SYCL_DEVICE_ONLY__
    return intel_sub_group_u8_u8_matrix_mad_k32(a, b, c);
#endif
  }
};

// D = f32, A/B = tf32, C = f32, K = 8
template <>
struct XeSubgroupMatrixMultiplyAccumulate<float, tfloat32_t, tfloat32_t, float> {
  template <typename ARegisters, typename BRegisters, typename CRegisters>
  CUTE_HOST_DEVICE
  auto operator()(ARegisters a, BRegisters b, CRegisters c) {
#ifdef __SYCL_DEVICE_ONLY__
    return intel_sub_group_tf32_tf32_matrix_mad_k8(a, b, c);
#endif
  }
};
} // namespace cute::detail
#endif
70
182
71
183
namespace cute {
72
184
// MxNxK_D,A,B,C
@@ -86,8 +198,8 @@ struct XE_8x16x16_F32BF16BF16F32_TT
86
198
intel::int8 const & b,
87
199
intel::float8 const & c)
88
200
{
89
- #if defined(SYCL_INTEL_TARGET )
90
- d = intel_sub_group_bf16_bf16_matrix_mad_k16 (a, b, c);
201
+ #if defined(CUTE_ARCH_MMA_XE_ENABLED )
202
+ d = detail::XeSubgroupMatrixMultiplyAccumulate< float , bfloat16_t , bfloat16_t , float >{} (a, b, c);
91
203
#else
92
204
CUTE_INVALID_CONTROL_PATH (" Attempting to use XE_8x16x16_F32BF16BF16F32_TT on non-PVC hardware" );
93
205
#endif
@@ -106,8 +218,8 @@ struct XE_4x16x16_F32BF16BF16F32_TT
106
218
intel::int8 const & b,
107
219
intel::float4 const & c)
108
220
{
109
- #if defined(SYCL_INTEL_TARGET )
110
- d = intel_sub_group_bf16_bf16_matrix_mad_k16 (a, b, c);
221
+ #if defined(CUTE_ARCH_MMA_XE_ENABLED )
222
+ d = detail::XeSubgroupMatrixMultiplyAccumulate< float , bfloat16_t , bfloat16_t , float >{} (a, b, c);
111
223
#else
112
224
CUTE_INVALID_CONTROL_PATH (" Attempting to use XE_8x16x16_F32BF16BF16F32_TT on non-PVC hardware" );
113
225
#endif
@@ -126,8 +238,8 @@ struct XE_2x16x16_F32BF16BF16F32_TT
126
238
intel::int8 const & b,
127
239
intel::float2 const & c)
128
240
{
129
- #if defined(SYCL_INTEL_TARGET )
130
- d = intel_sub_group_bf16_bf16_matrix_mad_k16 (a, b, c);
241
+ #if defined(CUTE_ARCH_MMA_XE_ENABLED )
242
+ d = detail::XeSubgroupMatrixMultiplyAccumulate< float , bfloat16_t , bfloat16_t , float >{} (a, b, c);
131
243
#else
132
244
CUTE_INVALID_CONTROL_PATH (" Attempting to use XE_8x16x16_F32BF16BF16F32_TT on non-PVC hardware" );
133
245
#endif
@@ -147,8 +259,8 @@ struct XE_1x16x16_F32BF16BF16F32_TT
147
259
intel::int8 const & b,
148
260
float const & c)
149
261
{
150
- #if defined(SYCL_INTEL_TARGET )
151
- d = intel_sub_group_bf16_bf16_matrix_mad_k16 (a, b, c);
262
+ #if defined(CUTE_ARCH_MMA_XE_ENABLED )
263
+ d = detail::XeSubgroupMatrixMultiplyAccumulate< float , bfloat16_t , bfloat16_t , float >{} (a, b, c);
152
264
#else
153
265
CUTE_INVALID_CONTROL_PATH (" Attempting to use XE_1x16x16_F32BF16BF16F32_TT on non-PVC hardware" );
154
266
#endif
@@ -172,8 +284,8 @@ struct XE_8x16x16_F32F16F16F32_TT
172
284
intel::int8 const & b,
173
285
intel::float8 const & c)
174
286
{
175
- #if defined(SYCL_INTEL_TARGET )
176
- d = intel_sub_group_f16_f16_matrix_mad_k16 (a, b, c);
287
+ #if defined(CUTE_ARCH_MMA_XE_ENABLED )
288
+ d = detail::XeSubgroupMatrixMultiplyAccumulate< float , half_t , half_t , float >{} (a, b, c);
177
289
#else
178
290
CUTE_INVALID_CONTROL_PATH (" Attempting to use XE_8x16x16_F32F16F16F32_TT on non-PVC hardware" );
179
291
#endif
@@ -193,8 +305,8 @@ struct XE_4x16x16_F32F16F16F32_TT
193
305
intel::int8 const & b,
194
306
intel::float4 const & c)
195
307
{
196
- #if defined(SYCL_INTEL_TARGET )
197
- d = intel_sub_group_f16_f16_matrix_mad_k16 (a, b, c);
308
+ #if defined(CUTE_ARCH_MMA_XE_ENABLED )
309
+ d = detail::XeSubgroupMatrixMultiplyAccumulate< float , half_t , half_t , float >{} (a, b, c);
198
310
#else
199
311
CUTE_INVALID_CONTROL_PATH (" Attempting to use XE_4x16x16_F32F16F16F32_TT on non-PVC hardware" );
200
312
#endif
@@ -214,8 +326,8 @@ struct XE_2x16x16_F32F16F16F32_TT
214
326
intel::int8 const & b,
215
327
intel::float2 const & c)
216
328
{
217
- #if defined(SYCL_INTEL_TARGET )
218
- d = intel_sub_group_f16_f16_matrix_mad_k16 (a, b, c);
329
+ #if defined(CUTE_ARCH_MMA_XE_ENABLED )
330
+ d = detail::XeSubgroupMatrixMultiplyAccumulate< float , half_t , half_t , float >{} (a, b, c);
219
331
#else
220
332
CUTE_INVALID_CONTROL_PATH (" Attempting to use XE_2x16x16_F32F16F16F32_TT on non-PVC hardware" );
221
333
#endif
@@ -235,8 +347,8 @@ struct XE_1x16x16_F32F16F16F32_TT
235
347
intel::int8 const & b,
236
348
float const & c)
237
349
{
238
- #if defined(SYCL_INTEL_TARGET )
239
- d = intel_sub_group_f16_f16_matrix_mad_k16 (a, b, c);
350
+ #if defined(CUTE_ARCH_MMA_XE_ENABLED )
351
+ d = detail::XeSubgroupMatrixMultiplyAccumulate< float , half_t , half_t , float >{} (a, b, c);
240
352
#else
241
353
CUTE_INVALID_CONTROL_PATH (" Attempting to use XE_1x16x16_F32F16F16F32_TT on non-PVC hardware" );
242
354
#endif
@@ -260,8 +372,8 @@ struct XE_8x16x32_S32S8S8S32_TT
260
372
intel::int8 const & b,
261
373
intel::int8 const & c)
262
374
{
263
- #if defined(SYCL_INTEL_TARGET )
264
- d = intel_sub_group_i8_i8_matrix_mad_k32 (a, b, c);
375
+ #if defined(CUTE_ARCH_MMA_XE_ENABLED )
376
+ d = detail::XeSubgroupMatrixMultiplyAccumulate< int32_t , int8_t , int8_t , int32_t >{} (a, b, c);
265
377
#else
266
378
CUTE_INVALID_CONTROL_PATH (" Attempting to use XE_8x16x32_S32S8S8S32_TT on non-PVC hardware" );
267
379
#endif
@@ -281,8 +393,8 @@ struct XE_4x16x32_S32S8S8S32_TT
281
393
intel::int8 const & b,
282
394
intel::int4 const & c)
283
395
{
284
- #if defined(SYCL_INTEL_TARGET )
285
- d = intel_sub_group_i8_i8_matrix_mad_k32 (a, b, c);
396
+ #if defined(CUTE_ARCH_MMA_XE_ENABLED )
397
+ d = detail::XeSubgroupMatrixMultiplyAccumulate< int32_t , int8_t , int8_t , int32_t >{} (a, b, c);
286
398
#else
287
399
CUTE_INVALID_CONTROL_PATH (" Attempting to use XE_4x16x32_S32S8S8S32_TT on non-PVC hardware" );
288
400
#endif
@@ -302,8 +414,8 @@ struct XE_2x16x32_S32S8S8S32_TT
302
414
intel::int8 const & b,
303
415
intel::int2 const & c)
304
416
{
305
- #if defined(SYCL_INTEL_TARGET )
306
- d = intel_sub_group_i8_i8_matrix_mad_k32 (a, b, c);
417
+ #if defined(CUTE_ARCH_MMA_XE_ENABLED )
418
+ d = detail::XeSubgroupMatrixMultiplyAccumulate< int32_t , int8_t , int8_t , int32_t >{} (a, b, c);
307
419
#else
308
420
CUTE_INVALID_CONTROL_PATH (" Attempting to use XE_2x16x32_S32S8S8S32_TT on non-PVC hardware" );
309
421
#endif
@@ -323,8 +435,8 @@ struct XE_1x16x32_S32S8S8S32_TT
323
435
intel::int8 const & b,
324
436
int const & c)
325
437
{
326
- #if defined(SYCL_INTEL_TARGET )
327
- d = intel_sub_group_i8_i8_matrix_mad_k32 (a, b, c);
438
+ #if defined(CUTE_ARCH_MMA_XE_ENABLED )
439
+ d = detail::XeSubgroupMatrixMultiplyAccumulate< int32_t , int8_t , int8_t , int32_t >{} (a, b, c);
328
440
#else
329
441
CUTE_INVALID_CONTROL_PATH (" Attempting to use XE_1x16x32_S32S8S8S32_TT on non-PVC hardware" );
330
442
#endif
@@ -344,8 +456,8 @@ struct XE_8x16x32_S32U8U8S32_TT
344
456
intel::uint8 const & b,
345
457
intel::int8 const & c)
346
458
{
347
- #if defined(SYCL_INTEL_TARGET )
348
- d = intel_sub_group_u8_u8_matrix_mad_k32 (a, b, c);
459
+ #if defined(CUTE_ARCH_MMA_XE_ENABLED )
460
+ d = detail::XeSubgroupMatrixMultiplyAccumulate< int32_t , uint8_t , uint8_t , int32_t >{} (a, b, c);
349
461
#else
350
462
CUTE_INVALID_CONTROL_PATH (" Attempting to use XE_8x16x32_S32U8U8S32_TT on non-PVC hardware" );
351
463
#endif
@@ -365,8 +477,8 @@ struct XE_4x16x32_S32U8U8S32_TT
365
477
intel::uint8 const & b,
366
478
intel::int4 const & c)
367
479
{
368
- #if defined(SYCL_INTEL_TARGET )
369
- d = intel_sub_group_u8_u8_matrix_mad_k32 (a, b, c);
480
+ #if defined(CUTE_ARCH_MMA_XE_ENABLED )
481
+ d = detail::XeSubgroupMatrixMultiplyAccumulate< int32_t , uint8_t , uint8_t , int32_t >{} (a, b, c);
370
482
#else
371
483
CUTE_INVALID_CONTROL_PATH (" Attempting to use XE_4x16x32_S32U8U8S32_TT on non-PVC hardware" );
372
484
#endif
@@ -386,8 +498,8 @@ struct XE_2x16x32_S32U8U8S32_TT
386
498
intel::uint8 const & b,
387
499
intel::int2 const & c)
388
500
{
389
- #if defined(SYCL_INTEL_TARGET )
390
- d = intel_sub_group_u8_u8_matrix_mad_k32 (a, b, c);
501
+ #if defined(CUTE_ARCH_MMA_XE_ENABLED )
502
+ d = detail::XeSubgroupMatrixMultiplyAccumulate< int32_t , uint8_t , uint8_t , int32_t >{} (a, b, c);
391
503
#else
392
504
CUTE_INVALID_CONTROL_PATH (" Attempting to use XE_2x16x32_S32U8U8S32_TT on non-PVC hardware" );
393
505
#endif
@@ -407,8 +519,8 @@ struct XE_1x16x32_S32U8U8S32_TT
407
519
intel::uint8 const & b,
408
520
int const & c)
409
521
{
410
- #if defined(SYCL_INTEL_TARGET )
411
- d = intel_sub_group_u8_u8_matrix_mad_k32 (a, b, c);
522
+ #if defined(CUTE_ARCH_MMA_XE_ENABLED )
523
+ d = detail::XeSubgroupMatrixMultiplyAccumulate< int32_t , uint8_t , uint8_t , int32_t >{} (a, b, c);
412
524
#else
413
525
CUTE_INVALID_CONTROL_PATH (" Attempting to use XE_1x16x32_S32U8U8S32_TT on non-PVC hardware" );
414
526
#endif
@@ -428,8 +540,8 @@ struct XE_8x16x8_F32TF32TF32F32_TT
428
540
intel::float8 const & b,
429
541
intel::float8 const & c)
430
542
{
431
- #if defined(SYCL_INTEL_TARGET )
432
- d = intel_sub_group_tf32_tf32_matrix_mad_k8 (a, b, c);
543
+ #if defined(CUTE_ARCH_MMA_XE_ENABLED )
544
+ d = detail::XeSubgroupMatrixMultiplyAccumulate< float , tfloat32_t , tfloat32_t , float >{} (a, b, c);
433
545
#else
434
546
CUTE_INVALID_CONTROL_PATH (" Attempting to use XE_8x16x8_F32TF32TF32F32_TT on non-PVC hardware" );
435
547
#endif
@@ -449,8 +561,8 @@ struct XE_4x16x8_F32TF32TF32F32_TT
449
561
intel::float8 const & b,
450
562
intel::float4 const & c)
451
563
{
452
- #if defined(SYCL_INTEL_TARGET )
453
- d = intel_sub_group_tf32_tf32_matrix_mad_k8 (a, b, c);
564
+ #if defined(CUTE_ARCH_MMA_XE_ENABLED )
565
+ d = detail::XeSubgroupMatrixMultiplyAccumulate< float , tfloat32_t , tfloat32_t , float >{} (a, b, c);
454
566
#else
455
567
CUTE_INVALID_CONTROL_PATH (" Attempting to use XE_4x16x8_F32TF32TF32F32_TT on non-PVC hardware" );
456
568
#endif
@@ -470,8 +582,8 @@ struct XE_2x16x8_F32TF32TF32F32_TT
470
582
intel::float8 const & b,
471
583
intel::float2 const & c)
472
584
{
473
- #if defined(SYCL_INTEL_TARGET )
474
- d = intel_sub_group_tf32_tf32_matrix_mad_k8 (a, b, c);
585
+ #if defined(CUTE_ARCH_MMA_XE_ENABLED )
586
+ d = detail::XeSubgroupMatrixMultiplyAccumulate< float , tfloat32_t , tfloat32_t , float >{} (a, b, c);
475
587
#else
476
588
CUTE_INVALID_CONTROL_PATH (" Attempting to use XE_2x16x8_F32TF32TF32F32_TT on non-PVC hardware" );
477
589
#endif
@@ -491,8 +603,8 @@ struct XE_1x16x8_F32TF32TF32F32_TT
491
603
intel::float8 const & b,
492
604
float const & c)
493
605
{
494
- #if defined(SYCL_INTEL_TARGET )
495
- d = intel_sub_group_tf32_tf32_matrix_mad_k8 (a, b, c);
606
+ #if defined(CUTE_ARCH_MMA_XE_ENABLED )
607
+ d = detail::XeSubgroupMatrixMultiplyAccumulate< float , tfloat32_t , tfloat32_t , float >{} (a, b, c);
496
608
#else
497
609
CUTE_INVALID_CONTROL_PATH (" Attempting to use XE_1x16x8_F32TF32TF32F32_TT on non-PVC hardware" );
498
610
#endif
0 commit comments