Skip to content

add int8/tf32 transpose A copy traits #319

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: sycl-develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions include/cute/arch/xe_copy_1B.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,15 @@ SYCL_DEVICE_BUILTIN(
intptr_t baseoffset, int width_minus_one, int height_minus_one,
int pitch_minus_one, cute::intel::coord_t coord));

// 8bits NO transform transpose
SYCL_DEVICE_BUILTIN(
cute::intel::ushort8 __builtin_IB_subgroup_block_read_cacheopts_transpose_u8_m32k8(
intptr_t baseoffset, int width_minus_one, int height_minus_one,
int pitch_minus_one, cute::intel::coord_t coord, int cache = 0));
SYCL_DEVICE_BUILTIN(
cute::intel::ushort4 __builtin_IB_subgroup_block_read_cacheopts_transpose_u8_m32k4(
intptr_t baseoffset, int width_minus_one, int height_minus_one,
int pitch_minus_one, cute::intel::coord_t coord, int cache = 0));

// 8bits VNNI transform No transpose
SYCL_DEVICE_BUILTIN(
Expand Down Expand Up @@ -443,6 +452,67 @@ struct XE_2D_U8x32x32_LD_N {
}
};

struct XE_2D_U8x32x4_LD_T {
using BlockShape = Shape<_4, _32>;
using inst_dtype = uint8_t;
static constexpr bool is_transpose = true;

template <class T>
CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width,
int height, int pitch, intel::coord_t coord,
T *dst) {
#if defined(SYCL_INTEL_TARGET)
static_assert(sizeof(T) == 1, "Expected T to have size 1");
*reinterpret_cast<intel::ushort4 *>(dst) =
__builtin_IB_subgroup_block_read_cacheopts_transpose_u8_m32k4(
(intptr_t)(baseoffset), width - 1, height - 1, pitch - 1, coord);
#else
CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware");
#endif
}
};

struct XE_2D_U8x32x8_LD_T {
using BlockShape = Shape<_8, _32>;
using inst_dtype = uint8_t;
static constexpr bool is_transpose = true;

template <class T>
CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width,
int height, int pitch, intel::coord_t coord,
T *dst) {
#if defined(SYCL_INTEL_TARGET)
static_assert(sizeof(T) == 1, "Expected T to have size 1");
*reinterpret_cast<intel::ushort8 *>(dst) =
__builtin_IB_subgroup_block_read_cacheopts_transpose_u8_m32k8(
(intptr_t)(baseoffset), width - 1, height - 1, pitch - 1, coord);
#else
CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware");
#endif
}
};

struct XE_2D_U8x16x32_LD_T {
using BlockShape = Shape<_32, _16>;
using inst_dtype = uint32_t;
static constexpr bool is_transpose = true;

template <class T>
CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width,
int height, int pitch, intel::coord_t coord,
T *dst) {
#if defined(SYCL_INTEL_TARGET)
static_assert(sizeof(T) == 1, "Expected T to have size 2");
*reinterpret_cast<intel::uint8 *>(dst) =
__builtin_IB_subgroup_block_read_flat_transpose_u32_k8(
(intptr_t)(baseoffset), width - 1, height - 1, pitch - 1, coord);
#else
CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware");
#endif
}
};


struct XE_2D_U4x16x16_LD_T {
using BlockShape = Shape<_16, _16>;
using inst_dtype = uint32_t;
Expand Down
39 changes: 33 additions & 6 deletions include/cute/arch/xe_copy_4B.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,10 @@ SYCL_DEVICE_BUILTIN(
int pitch_minus_one, cute::intel::coord_t coord));

// 32bits No transform No transpose
SYCL_DEVICE_BUILTIN(cute::intel::uint __builtin_IB_subgroup_block_read_flat_u32_m1k16v1(
intptr_t baseoffset, int width_minus_one, int height_minus_one,
int pitch_minus_one, cute::intel::coord_t coord));
SYCL_DEVICE_BUILTIN(
cute::intel::uint __builtin_IB_subgroup_block_read_flat_u32_m1k16v1(
intptr_t baseoffset, int width_minus_one, int height_minus_one,
int pitch_minus_one, cute::intel::coord_t coord));
SYCL_DEVICE_BUILTIN(
cute::intel::uint2 __builtin_IB_subgroup_block_read_flat_u32_m2k16v1(
intptr_t baseoffset, int width_minus_one, int height_minus_one,
Expand All @@ -142,9 +143,10 @@ SYCL_DEVICE_BUILTIN(
int pitch_minus_one, cute::intel::coord_t coord));

// 32bits No transform Transpose
SYCL_DEVICE_BUILTIN(cute::intel::uint __builtin_IB_subgroup_block_read_flat_transpose_u32_k1(
intptr_t baseoffset, int width_minus_one, int height_minus_one,
int pitch_minus_one, cute::intel::coord_t coord));
SYCL_DEVICE_BUILTIN(
cute::intel::uint __builtin_IB_subgroup_block_read_flat_transpose_u32_k1(
intptr_t baseoffset, int width_minus_one, int height_minus_one,
int pitch_minus_one, cute::intel::coord_t coord));
SYCL_DEVICE_BUILTIN(
cute::intel::uint2 __builtin_IB_subgroup_block_read_flat_transpose_u32_k2(
intptr_t baseoffset, int width_minus_one, int height_minus_one,
Expand All @@ -157,6 +159,10 @@ SYCL_DEVICE_BUILTIN(
cute::intel::uint8 __builtin_IB_subgroup_block_read_flat_transpose_u32_k8(
intptr_t baseoffset, int width_minus_one, int height_minus_one,
int pitch_minus_one, cute::intel::coord_t coord));
SYCL_DEVICE_BUILTIN(
cute::intel::uint4 __builtin_IB_subgroup_block_read_flat_transpose_u32_m8k8(
intptr_t baseoffset, int width_minus_one, int height_minus_one,
int pitch_minus_one, cute::intel::coord_t coord));

// 32bits
SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_write_flat_u32_m1k16v1(
Expand Down Expand Up @@ -710,6 +716,27 @@ struct XE_2D_U32x16x8_LD_T {
};
};

struct XE_2D_TF32x8x8_LD_T {
using BlockShape = Shape<_8, _8>;
using ValueShape = Shape<_4, _16>;

static constexpr bool is_transpose = true;

template <class T>
CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width,
int height, int pitch, intel::coord_t coord,
T *dst) {
#if defined(SYCL_INTEL_TARGET)
static_assert(sizeof(T) == 4, "Expected T to have size 4");
*reinterpret_cast<intel::uint4 *>(dst) =
__builtin_IB_subgroup_block_read_flat_transpose_u32_m8k8(
(intptr_t)(baseoffset), width - 1, height - 1, pitch - 1, coord);
#else
CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware");
#endif
}
};

struct XE_2D_U32x1x16_ST_N {
using BlockShape = Shape<_1, _16>;

Expand Down
76 changes: 76 additions & 0 deletions include/cute/atom/copy_traits_xe.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1398,6 +1398,24 @@ struct Copy_Traits_<XE_2D_TF32x32x16_LD_N, args_t...>
: XE_2D_LD_Unpack<XE_2D_TF32x32x16_LD_N, args_t...>(args...) {}
};

template <class... args_t>
struct Copy_Traits_<XE_2D_TF32x8x8_LD_T, args_t...>
: XE_2D_LD_Unpack<XE_2D_TF32x8x8_LD_T, args_t...> {
using ThrID = Layout<_16>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape <_16, Shape <_4, _32>>,
Stride< _0, Stride<_32, _1>>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape <_16, Shape <_4, _32>>,
Stride< _32, Stride<_32, _1>>>;
// Reference map from (thr,val) to bit
using RefLayout = DstLayout;

template <class... ArgTs>
Copy_Traits_(ArgTs... args)
: XE_2D_LD_Unpack<XE_2D_TF32x8x8_LD_T, args_t...>(args...) {}
};

template <class... args_t>
struct Copy_Traits_<XE_2D_U32x1x16_LD_N, args_t...>
: XE_2D_LD_Unpack<XE_2D_U32x1x16_LD_N, args_t...> {
Expand Down Expand Up @@ -1688,6 +1706,60 @@ struct Copy_Traits_<XE_2D_U16x16x16_LD_T, args_t...>
: XE_2D_LD_Unpack<XE_2D_U16x16x16_LD_T, args_t...>(args...) {}
};

template <class... args_t>
struct Copy_Traits_<XE_2D_U8x16x32_LD_T, args_t...>
: XE_2D_LD_Unpack<XE_2D_U8x16x32_LD_T, args_t...> {
using ThrID = Layout<_16>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape <_16,Shape <_8,_32>>,
Stride< _0, Stride<_1, _8>>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape < _16,Shape <_8,_32>>,
Stride<_256,Stride< _1,_8>>>;
// Reference map from (thr,val) to bit
using RefLayout = DstLayout;

template <class... ArgT>
Copy_Traits_(ArgT... args)
: XE_2D_LD_Unpack<XE_2D_U8x16x32_LD_T, args_t...>(args...) {}
};

template <class... args_t>
struct Copy_Traits_<XE_2D_U8x32x8_LD_T, args_t...>
: XE_2D_LD_Unpack<XE_2D_U8x32x8_LD_T, args_t...> {
using ThrID = Layout<_16>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape <_16,Shape <_8, _2, _8>>,
Stride<_0, Stride<_1, _8, _16>>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape < _16,Shape <_8, _2, _8>>,
Stride<_256,Stride<_1, _8, _16>>>;
// Reference map from (thr,val) to bit
using RefLayout = DstLayout;

template <class... ArgT>
Copy_Traits_(ArgT... args)
: XE_2D_LD_Unpack<XE_2D_U8x32x8_LD_T, args_t...>(args...) {}
};

template <class... args_t>
struct Copy_Traits_<XE_2D_U8x32x4_LD_T, args_t...>
: XE_2D_LD_Unpack<XE_2D_U8x32x4_LD_T, args_t...> {
using ThrID = Layout<_16>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape <_16,Shape <_8, _2, _4>>,
Stride<_0, Stride<_1, _8, _16>>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape < _16,Shape <_8, _2, _4>>,
Stride<_256,Stride<_1, _8, _16>>>;
// Reference map from (thr,val) to bit
using RefLayout = DstLayout;

template <class... ArgT>
Copy_Traits_(ArgT... args)
: XE_2D_LD_Unpack<XE_2D_U8x32x4_LD_T, args_t...>(args...) {}
};

// template<class... args_t>
// struct Copy_Traits<XE_2D_U32x16x1_LD_T, args_t...>
// : XE_2D_LD_Unpack<XE_2D_U32x16x1_LD_T, args_t...> {
Expand Down Expand Up @@ -2232,6 +2304,9 @@ COPY_TRAIT_LD_DEF(XE_2D_U8x16x32_LD_N)
COPY_TRAIT_LD_DEF(XE_2D_U8x32x32_LD_N)
COPY_TRAIT_LD_DEF(XE_2D_U8x16x64_LD_N)
COPY_TRAIT_LD_DEF(XE_2D_U8x32x64_LD_N)
COPY_TRAIT_LD_DEF(XE_2D_U8x32x8_LD_T)
COPY_TRAIT_LD_DEF(XE_2D_U8x32x4_LD_T)
COPY_TRAIT_LD_DEF(XE_2D_U8x16x32_LD_T)
COPY_TRAIT_LD_DEF(XE_2D_U16x1x16_LD_N)
COPY_TRAIT_LD_DEF(XE_2D_U16x2x16_LD_N)
COPY_TRAIT_LD_DEF(XE_2D_U16x4x16_LD_N)
Expand Down Expand Up @@ -2274,6 +2349,7 @@ COPY_TRAIT_LD_DEF(XE_2D_U16x16x32_LD_V)
COPY_TRAIT_LD_DEF(XE_2D_U16x16x16_LD_T)
COPY_TRAIT_LD_DEF(XE_2D_TF32x16x16_LD_N)
COPY_TRAIT_LD_DEF(XE_2D_TF32x32x16_LD_N)
COPY_TRAIT_LD_DEF(XE_2D_TF32x8x8_LD_T)
COPY_TRAIT_LD_DEF(XE_2D_U4x32x64_LD_N)
COPY_TRAIT_LD_DEF(XE_2D_U4x16x64_LD_N)
COPY_TRAIT_LD_DEF(XE_2D_U4x32x16_LD_T)
Expand Down
Loading