Skip to content

Commit 23e2f44

Browse files
committed
add int8/tf32 transpose A copy traits
1 parent 7aed740 commit 23e2f44

File tree

3 files changed

+163
-0
lines changed

3 files changed

+163
-0
lines changed

include/cute/arch/copy_xe_U32.hpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,27 @@ struct XE_2D_U32x16x8_LD_T {
429429
};
430430
};
431431

432+
struct XE_2D_TF32x8x8_LD_T {
433+
using BlockShape = Shape<_8, _8>;
434+
using ValueShape = Shape<_4, _16>;
435+
436+
static constexpr bool is_transpose = true;
437+
438+
template <class T>
439+
CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width,
440+
int height, int pitch, intel::coord_t coord,
441+
T *dst) {
442+
#if defined(SYCL_INTEL_TARGET)
443+
static_assert(sizeof(T) == 4, "Expected T to have size 4");
444+
*reinterpret_cast<intel::uint4 *>(dst) =
445+
__builtin_IB_subgroup_block_read_flat_transpose_u32_m8k8(
446+
(intptr_t)(baseoffset), width - 1, height - 1, pitch - 1, coord);
447+
#else
448+
CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware");
449+
#endif
450+
}
451+
};
452+
432453
struct XE_2D_U32x1x16_ST_N {
433454
using BlockShape = Shape<_1, _16>;
434455

include/cute/arch/copy_xe_U8.hpp

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,72 @@ struct XE_2D_U8x32x64_LD_V {
451451
}
452452
};
453453

454+
struct XE_2D_U8x32x4_LD_T {
455+
using BlockShape = Shape<_4, _32>;
456+
using inst_dtype = uint8_t;
457+
static constexpr bool is_transpose = true;
458+
459+
template <class T>
460+
CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width,
461+
int height, int pitch, intel::coord_t coord,
462+
T *dst) {
463+
#if defined(SYCL_INTEL_TARGET)
464+
static_assert(sizeof(T) == 1, "Expected T to have size 1");
465+
detail::XeSubgroup2DBlockLoadTranspose<4, 4, 16, 1>{}(baseoffset, width, height, pitch, coord, dst);
466+
467+
*reinterpret_cast<intel::ushort4 *>(dst) =
468+
__builtin_IB_subgroup_block_read_cacheopts_transpose_u8_m32k4(
469+
(intptr_t)(baseoffset), width - 1, height - 1, pitch - 1, coord);
470+
#else
471+
CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware");
472+
#endif
473+
}
474+
};
475+
476+
struct XE_2D_U8x32x8_LD_T {
477+
using BlockShape = Shape<_8, _32>;
478+
using inst_dtype = uint8_t;
479+
static constexpr bool is_transpose = true;
480+
481+
template <class T>
482+
CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width,
483+
int height, int pitch, intel::coord_t coord,
484+
T *dst) {
485+
#if defined(SYCL_INTEL_TARGET)
486+
static_assert(sizeof(T) == 1, "Expected T to have size 1");
487+
detail::XeSubgroup2DBlockLoadTranspose<4, 4, 16, 1>{}(baseoffset, width, height, pitch, coord, dst);
488+
489+
*reinterpret_cast<intel::ushort8 *>(dst) =
490+
__builtin_IB_subgroup_block_read_cacheopts_transpose_u8_m32k8(
491+
(intptr_t)(baseoffset), width - 1, height - 1, pitch - 1, coord);
492+
#else
493+
CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware");
494+
#endif
495+
}
496+
};
497+
498+
struct XE_2D_U8x16x32_LD_T {
499+
using BlockShape = Shape<_32, _16>;
500+
using inst_dtype = uint32_t;
501+
static constexpr bool is_transpose = true;
502+
503+
template <class T>
504+
CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width,
505+
int height, int pitch, intel::coord_t coord,
506+
T *dst) {
507+
#if defined(SYCL_INTEL_TARGET)
508+
static_assert(sizeof(T) == 1, "Expected T to have size 2");
509+
detail::XeSubgroup2DBlockLoadTranspose<4, 4, 16, 1>{}(baseoffset, width, height, pitch, coord, dst);
510+
511+
*reinterpret_cast<intel::uint8 *>(dst) =
512+
__builtin_IB_subgroup_block_read_flat_transpose_u32_k8(
513+
(intptr_t)(baseoffset), width - 1, height - 1, pitch - 1, coord);
514+
#else
515+
CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware");
516+
#endif
517+
}
518+
};
519+
454520
struct XE_2D_U8x1x16_ST_N {
455521
using BlockShape = Shape<_1, _16>;
456522

include/cute/atom/copy_traits_xe.hpp

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -828,6 +828,60 @@ struct Copy_Traits_<XE_2D_U8x32x64_LD_N::PREFETCH, args_t...>
828828
using RefLayout = DstLayout;
829829
};
830830

831+
template <class... args_t>
832+
struct Copy_Traits_<XE_2D_U8x16x32_LD_T, args_t...>
833+
: XE_2D_LD_Unpack<XE_2D_U8x16x32_LD_T, args_t...> {
834+
using ThrID = Layout<_16>;
835+
// Map from (src-thr,src-val) to bit
836+
using SrcLayout = Layout<Shape <_16,Shape <_8,_32>>,
837+
Stride< _0, Stride<_1, _8>>>;
838+
// Map from (dst-thr,dst-val) to bit
839+
using DstLayout = Layout<Shape < _16,Shape <_8,_32>>,
840+
Stride<_256,Stride< _1,_8>>>;
841+
// Reference map from (thr,val) to bit
842+
using RefLayout = DstLayout;
843+
844+
template <class... ArgT>
845+
Copy_Traits_(ArgT... args)
846+
: XE_2D_LD_Unpack<XE_2D_U8x16x32_LD_T, args_t...>(args...) {}
847+
};
848+
849+
template <class... args_t>
850+
struct Copy_Traits_<XE_2D_U8x32x8_LD_T, args_t...>
851+
: XE_2D_LD_Unpack<XE_2D_U8x32x8_LD_T, args_t...> {
852+
using ThrID = Layout<_16>;
853+
// Map from (src-thr,src-val) to bit
854+
using SrcLayout = Layout<Shape <_16,Shape <_8, _2, _8>>,
855+
Stride<_0, Stride<_1, _8, _16>>>;
856+
// Map from (dst-thr,dst-val) to bit
857+
using DstLayout = Layout<Shape < _16,Shape <_8, _2, _8>>,
858+
Stride<_256,Stride<_1, _8, _16>>>;
859+
// Reference map from (thr,val) to bit
860+
using RefLayout = DstLayout;
861+
862+
template <class... ArgT>
863+
Copy_Traits_(ArgT... args)
864+
: XE_2D_LD_Unpack<XE_2D_U8x32x8_LD_T, args_t...>(args...) {}
865+
};
866+
867+
template <class... args_t>
868+
struct Copy_Traits_<XE_2D_U8x32x4_LD_T, args_t...>
869+
: XE_2D_LD_Unpack<XE_2D_U8x32x4_LD_T, args_t...> {
870+
using ThrID = Layout<_16>;
871+
// Map from (src-thr,src-val) to bit
872+
using SrcLayout = Layout<Shape <_16,Shape <_8, _2, _4>>,
873+
Stride<_0, Stride<_1, _8, _16>>>;
874+
// Map from (dst-thr,dst-val) to bit
875+
using DstLayout = Layout<Shape < _16,Shape <_8, _2, _4>>,
876+
Stride<_256,Stride<_1, _8, _16>>>;
877+
// Reference map from (thr,val) to bit
878+
using RefLayout = DstLayout;
879+
880+
template <class... ArgT>
881+
Copy_Traits_(ArgT... args)
882+
: XE_2D_LD_Unpack<XE_2D_U8x32x4_LD_T, args_t...>(args...) {}
883+
};
884+
831885
template <class... args_t>
832886
struct Copy_Traits_<XE_2D_U16x1x16_LD_N, args_t...>
833887
: XE_2D_LD_Unpack<XE_2D_U16x1x16_LD_N, args_t...> {
@@ -1403,6 +1457,24 @@ struct Copy_Traits_<XE_2D_TF32x32x16_LD_N, args_t...>
14031457
: XE_2D_LD_Unpack<XE_2D_TF32x32x16_LD_N, args_t...>(args...) {}
14041458
};
14051459

1460+
template <class... args_t>
1461+
struct Copy_Traits_<XE_2D_TF32x8x8_LD_T, args_t...>
1462+
: XE_2D_LD_Unpack<XE_2D_TF32x8x8_LD_T, args_t...> {
1463+
using ThrID = Layout<_16>;
1464+
// Map from (src-thr,src-val) to bit
1465+
using SrcLayout = Layout<Shape <_16, Shape <_4, _32>>,
1466+
Stride< _0, Stride<_32, _1>>>;
1467+
// Map from (dst-thr,dst-val) to bit
1468+
using DstLayout = Layout<Shape <_16, Shape <_4, _32>>,
1469+
Stride< _32, Stride<_32, _1>>>;
1470+
// Reference map from (thr,val) to bit
1471+
using RefLayout = DstLayout;
1472+
1473+
template <class... ArgTs>
1474+
Copy_Traits_(ArgTs... args)
1475+
: XE_2D_LD_Unpack<XE_2D_TF32x8x8_LD_T, args_t...>(args...) {}
1476+
};
1477+
14061478
template <class... args_t>
14071479
struct Copy_Traits_<XE_2D_U32x1x16_LD_N, args_t...>
14081480
: XE_2D_LD_Unpack<XE_2D_U32x1x16_LD_N, args_t...> {
@@ -2213,6 +2285,9 @@ COPY_TRAIT_LD_DEF(XE_2D_U8x1x64_LD_N)
22132285
COPY_TRAIT_LD_DEF(XE_2D_U8x2x64_LD_N)
22142286
COPY_TRAIT_LD_DEF(XE_2D_U8x4x64_LD_N)
22152287
COPY_TRAIT_LD_DEF(XE_2D_U8x8x64_LD_N)
2288+
COPY_TRAIT_LD_DEF(XE_2D_U8x32x8_LD_T)
2289+
COPY_TRAIT_LD_DEF(XE_2D_U8x32x4_LD_T)
2290+
COPY_TRAIT_LD_DEF(XE_2D_U8x16x32_LD_T)
22162291
COPY_TRAIT_LD_DEF(XE_2D_U64x8x1_LD_T)
22172292
COPY_TRAIT_LD_DEF(XE_2D_U64x8x2_LD_T)
22182293
COPY_TRAIT_LD_DEF(XE_2D_U64x8x4_LD_T)
@@ -2233,6 +2308,7 @@ COPY_TRAIT_LD_DEF(XE_2D_TF32x1x8_LD_N)
22332308
COPY_TRAIT_LD_DEF(XE_2D_TF32x2x8_LD_N)
22342309
COPY_TRAIT_LD_DEF(XE_2D_TF32x4x8_LD_N)
22352310
COPY_TRAIT_LD_DEF(XE_2D_TF32x8x8_LD_N)
2311+
COPY_TRAIT_LD_DEF(XE_2D_TF32x8x8_LD_T)
22362312
COPY_TRAIT_LD_DEF(XE_2D_U32x1x16_LD_N)
22372313
COPY_TRAIT_LD_DEF(XE_2D_U32x2x16_LD_N)
22382314
COPY_TRAIT_LD_DEF(XE_2D_U32x4x16_LD_N)

0 commit comments

Comments
 (0)