Skip to content

add int8/tf32 transpose A copy traits #319

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions include/cute/arch/copy_xe_U32.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,25 @@ struct XE_2D_U32x16x8_LD_T {
};
};

struct XE_2D_TF32x8x8_LD_T {
using BlockShape = Shape<_8, _8>;
using ValueShape = Shape<_4, _16>;

static constexpr bool is_transpose = true;

template <class T>
CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width,
int height, int pitch, intel::coord_t coord,
T *dst) {
#if defined(CUTE_ARCH_COPY_XE_ENABLED)
static_assert(sizeof(T) == 4, "Expected T to have size 4");
detail::XeSubgroup2DBlockLoadTranspose<4, 8, 8, 1>{}(baseoffset, width, height, pitch, coord, dst);
#else
CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware");
#endif
}
};

struct XE_2D_U32x1x16_ST_N {
using BlockShape = Shape<_1, _16>;

Expand Down
54 changes: 54 additions & 0 deletions include/cute/arch/copy_xe_U8.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,60 @@ struct XE_2D_U8x32x64_LD_V {
}
};

struct XE_2D_U8x32x4_LD_T {
using BlockShape = Shape<_4, _32>;
using inst_dtype = uint8_t;
static constexpr bool is_transpose = true;

template <class T>
CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width,
int height, int pitch, intel::coord_t coord,
T *dst) {
#if defined(CUTE_ARCH_COPY_XE_ENABLED)
static_assert(sizeof(T) == 1, "Expected T to have size 1");
detail::XeSubgroup2DBlockLoadTranspose<1, 4, 32, 1>{}(baseoffset, width, height, pitch, coord, dst);
#else
CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware");
#endif
}
};

struct XE_2D_U8x32x8_LD_T {
using BlockShape = Shape<_8, _32>;
using inst_dtype = uint8_t;
static constexpr bool is_transpose = true;

template <class T>
CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width,
int height, int pitch, intel::coord_t coord,
T *dst) {
#if defined(CUTE_ARCH_COPY_XE_ENABLED)
static_assert(sizeof(T) == 1, "Expected T to have size 1");
detail::XeSubgroup2DBlockLoadTranspose<1, 8, 32, 1>{}(baseoffset, width, height, pitch, coord, dst);
#else
CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware");
#endif
}
};

struct XE_2D_U8x16x32_LD_T {
using BlockShape = Shape<_32, _16>;
using inst_dtype = uint32_t;
static constexpr bool is_transpose = true;

template <class T>
CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width,
int height, int pitch, intel::coord_t coord,
T *dst) {
#if defined(CUTE_ARCH_COPY_XE_ENABLED)
static_assert(sizeof(T) == 1, "Expected T to have size 2");
detail::XeSubgroup2DBlockLoadTranspose<4, 8, 16, 1>{}(baseoffset, width, height, pitch, coord, dst);
#else
CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware");
#endif
}
};

struct XE_2D_U8x1x16_ST_N {
using BlockShape = Shape<_1, _16>;

Expand Down
48 changes: 48 additions & 0 deletions include/cute/arch/copy_xe_builtin.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,16 @@ SYCL_DEVICE_BUILTIN(
intptr_t baseoffset, int width_minus_one, int height_minus_one,
int pitch_minus_one, cute::intel::coord_t coord));

// 8bits No transform Transpose
SYCL_DEVICE_BUILTIN(
cute::intel::ushort4 __builtin_IB_subgroup_block_read_cacheopts_transpose_u8_m32k4(
intptr_t baseoffset, int width_minus_one, int height_minus_one,
int pitch_minus_one, cute::intel::coord_t coord, int cacheOpt = 0));
SYCL_DEVICE_BUILTIN(
cute::intel::ushort8 __builtin_IB_subgroup_block_read_cacheopts_transpose_u8_m32k8(
intptr_t baseoffset, int width_minus_one, int height_minus_one,
int pitch_minus_one, cute::intel::coord_t coord, int cacheOpt = 0));

// 8bits No transform No transpose
SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_write_flat_u8_m1k16v1(
intptr_t baseoffset, int width_minus_one, int height_minus_one,
Expand Down Expand Up @@ -191,6 +201,7 @@ SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_read_prefetch_u8_m4k32v1(
SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_read_prefetch_u8_m8k32v1(
long baseoffset, int width_minus_one, int height_minus_one,
int pitch_minus_one, cute::intel::coord_t coord, enum CacheControl cache_control));

// // 2D prefetch
SYCL_DEVICE_OCL(void intel_sub_group_2d_block_prefetch_8b_1r32x2c(
__global void* base_address, int width, int height, int pitch,
Expand Down Expand Up @@ -422,6 +433,10 @@ SYCL_DEVICE_BUILTIN(
cute::intel::uint8 __builtin_IB_subgroup_block_read_flat_transpose_u32_k8(
intptr_t baseoffset, int width_minus_one, int height_minus_one,
int pitch_minus_one, cute::intel::coord_t coord));
SYCL_DEVICE_BUILTIN(
cute::intel::ushort8 __builtin_IB_subgroup_block_read_cacheopts_transpose_u32_m8k8(
intptr_t baseoffset, int width_minus_one, int height_minus_one,
int pitch_minus_one, cute::intel::coord_t coord, int cacheOpt = 0));

// 32bits
SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_write_flat_u32_m1k16v1(
Expand Down Expand Up @@ -622,6 +637,28 @@ struct XeSubgroup2DBlockLoadTransform<1, 16, 32, 4> {
}
};

template<>
struct XeSubgroup2DBlockLoadTranspose<1, 4, 32, 1> {
template<typename T>
CUTE_HOST_DEVICE void
operator()(const void* srcBasePointer, int memoryWidth, int memoryHeight, int memoryPitch,
cute::intel::coord_t coordinate, T* dstPointer) {
*reinterpret_cast<intel::ushort4 *>(dstPointer) = __builtin_IB_subgroup_block_read_cacheopts_transpose_u8_m32k4(
reinterpret_cast<long>(srcBasePointer), memoryWidth - 1, memoryHeight - 1, memoryPitch - 1, coordinate);
}
};

template<>
struct XeSubgroup2DBlockLoadTranspose<1, 8, 32, 1> {
template<typename T>
CUTE_HOST_DEVICE void
operator()(const void* srcBasePointer, int memoryWidth, int memoryHeight, int memoryPitch,
cute::intel::coord_t coordinate, T* dstPointer) {
*reinterpret_cast<intel::ushort8 *>(dstPointer) = __builtin_IB_subgroup_block_read_cacheopts_transpose_u8_m32k8(
reinterpret_cast<long>(srcBasePointer), memoryWidth - 1, memoryHeight - 1, memoryPitch - 1, coordinate);
}
};

template<>
struct XeSubgroup2DBlockStore<1, 16, 1, 1> {
template<typename T>
Expand Down Expand Up @@ -1319,6 +1356,17 @@ struct XeSubgroup2DBlockLoadTranspose<4, 8, 16, 1> {
}
};

template<>
struct XeSubgroup2DBlockLoadTranspose<4, 8, 8, 1> {
template<typename T>
CUTE_HOST_DEVICE void
operator()(const void* srcBasePointer, int memoryWidth, int memoryHeight, int memoryPitch,
cute::intel::coord_t coordinate, T* dstPointer) {
*reinterpret_cast<intel::ushort8 *>(dstPointer) = __builtin_IB_subgroup_block_read_cacheopts_transpose_u32_m8k8(
reinterpret_cast<long>(srcBasePointer), memoryWidth - 1, memoryHeight - 1, memoryPitch - 1, coordinate);
}
};

template<>
struct XeSubgroup2DBlockStore<4, 16, 1, 1> {
template<typename T>
Expand Down
76 changes: 76 additions & 0 deletions include/cute/atom/copy_traits_xe.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -825,6 +825,60 @@ struct Copy_Traits_<XE_2D_U8x32x64_LD_N::PREFETCH, args_t...>
using RefLayout = DstLayout;
};

template <class... args_t>
struct Copy_Traits_<XE_2D_U8x16x32_LD_T, args_t...>
: XE_2D_LD_Unpack<XE_2D_U8x16x32_LD_T, args_t...> {
using ThrID = Layout<_16>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape <_16,Shape <_8,_32>>,
Stride< _0, Stride<_1, _8>>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape < _16,Shape <_8,_32>>,
Stride<_256,Stride< _1,_8>>>;
// Reference map from (thr,val) to bit
using RefLayout = DstLayout;

template <class... ArgT>
Copy_Traits_(ArgT... args)
: XE_2D_LD_Unpack<XE_2D_U8x16x32_LD_T, args_t...>(args...) {}
};

template <class... args_t>
struct Copy_Traits_<XE_2D_U8x32x8_LD_T, args_t...>
: XE_2D_LD_Unpack<XE_2D_U8x32x8_LD_T, args_t...> {
using ThrID = Layout<_16>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape <_16,Shape <_8, _2, _8>>,
Stride<_0, Stride<_1, _8, _16>>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape < _16,Shape <_8, _2, _8>>,
Stride<_256,Stride<_1, _8, _16>>>;
// Reference map from (thr,val) to bit
using RefLayout = DstLayout;

template <class... ArgT>
Copy_Traits_(ArgT... args)
: XE_2D_LD_Unpack<XE_2D_U8x32x8_LD_T, args_t...>(args...) {}
};

template <class... args_t>
struct Copy_Traits_<XE_2D_U8x32x4_LD_T, args_t...>
: XE_2D_LD_Unpack<XE_2D_U8x32x4_LD_T, args_t...> {
using ThrID = Layout<_16>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape <_16,Shape <_8, _2, _4>>,
Stride<_0, Stride<_1, _8, _16>>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape < _16,Shape <_8, _2, _4>>,
Stride<_256,Stride<_1, _8, _16>>>;
// Reference map from (thr,val) to bit
using RefLayout = DstLayout;

template <class... ArgT>
Copy_Traits_(ArgT... args)
: XE_2D_LD_Unpack<XE_2D_U8x32x4_LD_T, args_t...>(args...) {}
};

template <class... args_t>
struct Copy_Traits_<XE_2D_U16x1x16_LD_N, args_t...>
: XE_2D_LD_Unpack<XE_2D_U16x1x16_LD_N, args_t...> {
Expand Down Expand Up @@ -1401,6 +1455,24 @@ struct Copy_Traits_<XE_2D_TF32x32x16_LD_N, args_t...>
: XE_2D_LD_Unpack<XE_2D_TF32x32x16_LD_N, args_t...>(args...) {}
};

template <class... args_t>
struct Copy_Traits_<XE_2D_TF32x8x8_LD_T, args_t...>
: XE_2D_LD_Unpack<XE_2D_TF32x8x8_LD_T, args_t...> {
using ThrID = Layout<_16>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape <_16, Shape <_4, _32>>,
Stride< _0, Stride<_32, _1>>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape <_16, Shape <_4, _32>>,
Stride< _32, Stride<_32, _1>>>;
// Reference map from (thr,val) to bit
using RefLayout = DstLayout;

template <class... ArgTs>
Copy_Traits_(ArgTs... args)
: XE_2D_LD_Unpack<XE_2D_TF32x8x8_LD_T, args_t...>(args...) {}
};

template <class... args_t>
struct Copy_Traits_<XE_2D_U32x1x16_LD_N, args_t...>
: XE_2D_LD_Unpack<XE_2D_U32x1x16_LD_N, args_t...> {
Expand Down Expand Up @@ -2211,6 +2283,9 @@ COPY_TRAIT_LD_DEF(XE_2D_U8x1x64_LD_N)
COPY_TRAIT_LD_DEF(XE_2D_U8x2x64_LD_N)
COPY_TRAIT_LD_DEF(XE_2D_U8x4x64_LD_N)
COPY_TRAIT_LD_DEF(XE_2D_U8x8x64_LD_N)
COPY_TRAIT_LD_DEF(XE_2D_U8x32x8_LD_T)
COPY_TRAIT_LD_DEF(XE_2D_U8x32x4_LD_T)
COPY_TRAIT_LD_DEF(XE_2D_U8x16x32_LD_T)
COPY_TRAIT_LD_DEF(XE_2D_U64x8x1_LD_T)
COPY_TRAIT_LD_DEF(XE_2D_U64x8x2_LD_T)
COPY_TRAIT_LD_DEF(XE_2D_U64x8x4_LD_T)
Expand All @@ -2231,6 +2306,7 @@ COPY_TRAIT_LD_DEF(XE_2D_TF32x1x8_LD_N)
COPY_TRAIT_LD_DEF(XE_2D_TF32x2x8_LD_N)
COPY_TRAIT_LD_DEF(XE_2D_TF32x4x8_LD_N)
COPY_TRAIT_LD_DEF(XE_2D_TF32x8x8_LD_N)
COPY_TRAIT_LD_DEF(XE_2D_TF32x8x8_LD_T)
COPY_TRAIT_LD_DEF(XE_2D_U32x1x16_LD_N)
COPY_TRAIT_LD_DEF(XE_2D_U32x2x16_LD_N)
COPY_TRAIT_LD_DEF(XE_2D_U32x4x16_LD_N)
Expand Down
Loading