diff --git a/include/cute/atom/copy_traits_xe.hpp b/include/cute/atom/copy_traits_xe.hpp index e92d184210..5a250b6156 100644 --- a/include/cute/atom/copy_traits_xe.hpp +++ b/include/cute/atom/copy_traits_xe.hpp @@ -216,6 +216,7 @@ struct XE_2D_LD_Unpack { uint32_t height; uint32_t pitch; uint32_t stride_l = 0; + uint32_t height_offset = 0; @@ -253,11 +254,15 @@ struct XE_2D_LD_Unpack { if constexpr (stride_rank == 3) { stride_l = size<2>(tensor.stride()); } + if(stride_l % pitch != 0){ + CUTE_INVALID_CONTROL_PATH("Incompatible strides in tensor.\n"); + }; + height_offset = stride_l / pitch; } XE_2D_LD_Unpack(Traits_LD_t const &traits) : base_ptr(traits.base_ptr), width(traits.width), height(traits.height), pitch(traits.pitch), - stride_l(traits.stride_l) {} + stride_l(traits.stride_l), height_offset(traits.height_offset){} XE_2D_LD_Unpack() {} @@ -279,11 +284,12 @@ struct XE_2D_LD_Unpack { auto [m, n, l] = src.data().coord_; int x = is_need_reversed ? m : n; int y = is_need_reversed ? n : m; + y += l * traits.height_offset; constexpr auto inst_size_bits = detail::size_of_inst_bits; - CopyOp::copy(base_addr + l * traits.stride_l, - (traits.width * sizeof_bits_v) / sizeof_bits_v, traits.height, + CopyOp::copy(base_addr, (traits.width * sizeof_bits_v) / sizeof_bits_v, + traits.height + l * traits.height_offset, (traits.pitch * sizeof_bits_v) / sizeof_bits_v, intel::coord_t{(int)(x * sizeof_bits_v / inst_size_bits), y}, raw_pointer_cast(&((&*dst.data())[0]))); @@ -305,11 +311,12 @@ struct XE_2D_LD_Unpack { int x = is_need_reversed ? m : n; int y = is_need_reversed ? n : m; + y += l * atom.height_offset; constexpr auto inst_size_bits = detail::size_of_inst_bits; - CopyOp::PREFETCH::copy(base_addr + l * atom.stride_l, - (atom.width * sizeof_bits_v) / sizeof_bits_v, atom.height, + CopyOp::PREFETCH::copy(base_addr, (atom.width * sizeof_bits_v) / sizeof_bits_v, + atom.height + l * atom.height_offset, (atom.pitch * sizeof_bits_v) / sizeof_bits_v, intel::coord_t{(int)(x * sizeof_bits_v / inst_size_bits), y}); } @@ -339,6 +346,7 @@ template (tensor.stride()); } + if(stride_l % pitch != 0){ + CUTE_INVALID_CONTROL_PATH("Incompatible strides in tensor.\n"); + }; + height_offset = stride_l / pitch; } XE_2D_ST_Unpack(Traits_ST_t const &traits) : base_ptr(traits.base_ptr), width(traits.width), height(traits.height), pitch(traits.pitch), - stride_l(traits.stride_l) {} + stride_l(traits.stride_l), height_offset(traits.height_offset) {} XE_2D_ST_Unpack() {} @@ -383,11 +395,14 @@ template