Skip to content
This repository has been archived by the owner on Dec 18, 2024. It is now read-only.

Commit

Permalink
fix reg_blk and reg_tmp size misaligned issue (#559)
Browse files Browse the repository at this point in the history
Co-authored-by: Jacky, Deng <[email protected]>
  • Loading branch information
2 people authored and taozha2 committed Jun 14, 2023
1 parent 971a2ee commit 76c593f
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 5 deletions.
12 changes: 8 additions & 4 deletions include/subgroup/tile/impl/load_xe.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,9 @@ tile_load(tile_t &tile, payload_t &payload) {
constexpr uint32_t ld_blk_height = reg_transpose
? detail::getNextPowerOf2<ld_blk_size_y>()
: ld_blk_size_y;
xetla_vector<dtype, ld_blk_height * block_size_x * arr_len> reg_tmp;
constexpr uint32_t tmp_size
= ld_blk_height * block_size_x * arr_len;
xetla_vector<dtype, tmp_size> reg_tmp;
#pragma unroll
for (int ii = 0; ii < block_size_y / ld_blk_size_y; ++ii) {
constexpr uint32_t load_elems
Expand All @@ -267,7 +269,7 @@ tile_load(tile_t &tile, payload_t &payload) {
.xetla_select<block_size_x, 1,
ld_blk_size_y, 1>(0, 0);
} else {
reg_blk = reg_tmp;
reg_blk.xetla_select<tmp_size, 1>(ii * tmp_size) = reg_tmp;
}

if constexpr (mem_transpose) {
Expand Down Expand Up @@ -332,7 +334,9 @@ tile_load(tile_t &tile, payload_t &payload) {
constexpr uint32_t ld_blk_height = reg_transpose
? detail::getNextPowerOf2<remained_ld_blk_size_y>()
: remained_ld_blk_size_y;
xetla_vector<dtype, ld_blk_height * block_size_x * arr_len> reg_tmp;
constexpr uint32_t tmp_size
= ld_blk_height * block_size_x * arr_len;
xetla_vector<dtype, tmp_size> reg_tmp;
#pragma unroll
for (int ii = 0; ii < remained_size_y / remained_ld_blk_size_y;
++ii) {
Expand All @@ -353,7 +357,7 @@ tile_load(tile_t &tile, payload_t &payload) {
.xetla_select<block_size_x, 1,
remained_ld_blk_size_y, 1>(0, 0);
} else {
reg_blk = reg_tmp;
reg_blk.xetla_select<tmp_size, 1>(ii * tmp_size) = reg_tmp;
}
if constexpr (mem_transpose) {
xetla_update_tdesc_offsetx(tdesc.xetla_format<uint32_t>(),
Expand Down
12 changes: 11 additions & 1 deletion tests/unit/tile_load_store/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ TEST(tile_load_store, esimd) {
nd_range, result_validate);
}

TEST(tile_load_transpose_store, esimd) {
TEST(tile_load_transpose_store_1, esimd) {
cl::sycl::nd_range<1> nd_range({1}, {1});
auto result_validate
= std::bind(tile_load_store_result_validate<int, false, true>, _1,
Expand All @@ -175,6 +175,16 @@ TEST(tile_load_transpose_store, esimd) {
64>>(nd_range, result_validate);
}

TEST(tile_load_transpose_store_2, esimd) {
cl::sycl::nd_range<1> nd_range({1}, {1});
auto result_validate
= std::bind(tile_load_store_result_validate<int, false, true>, _1,
_2, _3, 128, 64, 32, 32);
kernel_run<int,
tile_load_store_func<int, 128, 64, 128, 32, 32, 8, 16, false, true,
64>>(nd_range, result_validate);
}

TEST(tile_load_transform_store, esimd) {
cl::sycl::nd_range<1> nd_range({1}, {1});
auto result_validate
Expand Down

0 comments on commit 76c593f

Please sign in to comment.