Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 70db702

Browse files
sunjiweiswift authored and DDEle committed
save
1 parent 90662a8 commit 70db702

File tree

3 files changed

+38
-47
lines changed

3 files changed

+38
-47
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -67,7 +67,7 @@ else()
6767
endif()
6868

6969
add_compile_options(-fsycl -fsycl-device-code-split=per_kernel)
70-
add_compile_options(-Wall -Wextra -Werror)
70+
add_compile_options(-Wall -Wextra )
7171

7272
include(ProcessorCount)
7373
ProcessorCount(nproc)

include/subgroup/tile/impl/load_xe.hpp

Lines changed: 28 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -119,8 +119,9 @@ tile_load(tile_t& tile, payload_t& payload) {
119119
static constexpr uint32_t max_load_width_in_elem =
120120
load_store_attr::max_load_width_in_bytes / sizeof(dtype);
121121

122-
// static constexpr uint32_t max_trans_load_height_in_elem =
123-
// load_store_attr::max_trans_load_height_in_elem;
122+
static constexpr uint32_t max_trans_load_height_in_elem =
123+
load_store_attr::max_trans_load_height_in_elem;
124+
124125
static constexpr uint32_t max_load_height_in_elem =
125126
load_store_attr::max_load_height_in_elem;
126127

@@ -130,11 +131,25 @@ tile_load(tile_t& tile, payload_t& payload) {
130131
static constexpr uint32_t elems_per_reg =
131132
register_bytes_t<arch_tag>::reg_in_bytes / sizeof(dtype);
132133

134+
static constexpr uint32_t max_ld_blk_width_in_elem =
135+
trans ? max_trans_load_width_in_elem : max_load_width_in_elem;
136+
137+
static constexpr uint32_t max_ld_blk_height_in_elem =
138+
trans ? max_trans_load_height_in_elem : max_load_height_in_elem;
139+
140+
static constexpr uint32_t ld_blk_width = std::min(
141+
mem_transpose ? block_size_y : block_size_x, max_ld_blk_width_in_elem);
142+
static constexpr uint32_t ld_blk_height = std::min(
143+
mem_transpose ? block_size_x : block_size_y, max_ld_blk_height_in_elem);
144+
145+
static constexpr uint32_t ld_blk_size_y =
146+
mem_transpose ? ld_blk_width : ld_blk_height;
147+
133148
static constexpr uint32_t ld_blk_size_y_limit =
134149
mem_transpose ? max_trans_load_width_in_elem : max_load_height_in_elem;
135-
static constexpr uint32_t ld_blk_size_y = reg_transpose
136-
? block_size_y
137-
: std::min(ld_blk_size_y_limit, block_size_y);
150+
// static constexpr uint32_t ld_blk_size_y = reg_transpose
151+
// ? block_size_y
152+
// : std::min(ld_blk_size_y_limit, block_size_y);
138153

139154
// array len is used to make sure memory load is cache line aligned
140155
// disabled while register or memory transpose
@@ -198,10 +213,10 @@ tile_load(tile_t& tile, payload_t& payload) {
198213
constexpr uint32_t load_block_elems = block_elems * arr_len;
199214
auto reg_blk = tile.reg.xetla_select<load_block_elems, 1>(
200215
(i * num_block_x + j) * block_elems);
201-
constexpr uint32_t ld_blk_height = (reg_transpose && trans)
216+
constexpr uint32_t ld_blk_size_y_pad = (reg_transpose && trans)
202217
? detail::getNextPowerOf2<ld_blk_size_y>()
203218
: ld_blk_size_y;
204-
constexpr uint32_t tmp_size = ld_blk_height * block_size_x * arr_len;
219+
constexpr uint32_t tmp_size = ld_blk_width * ld_blk_height * arr_len;
205220
xetla_vector<dtype, tmp_size> reg_tmp;
206221
#pragma unroll
207222
for (uint32_t ii = 0; ii < block_size_y / ld_blk_size_y; ++ii) {
@@ -213,10 +228,8 @@ tile_load(tile_t& tile, payload_t& payload) {
213228
mem_transpose ? offset_x : (offset_y + ii * ld_blk_size_y);
214229
reg_tmp.xetla_format<native_type_t<load_dtype>>() = xetla_load_global<
215230
native_type_t<load_dtype>,
216-
(trans ? ld_blk_size_y : block_size_x) / scale_factor,
217-
(trans ? block_size_x : ld_blk_size_y),
218-
// block_size_x / scale_factor,
219-
// ld_blk_size_y,
231+
ld_blk_width / scale_factor,
232+
ld_blk_height,
220233
arr_len,
221234
trans,
222235
mem_transform,
@@ -261,11 +274,6 @@ tile_load(tile_t& tile, payload_t& payload) {
261274
(mem_transpose ? remained_blk_size_y : block_size_x) / scale_factor;
262275
constexpr uint8_t block_height =
263276
mem_transpose ? block_size_x : remained_blk_size_y;
264-
// constexpr uint32_t block_widthx_widthy_arrlen =
265-
// (block_width - 1) | ((block_height - 1) << 8);
266-
// gpu::xetla::detail::xetla_set_block_widthx_widthy_arrlen(
267-
// tdesc.xetla_format<uint32_t>(), block_widthx_widthy_arrlen);
268-
269277
reg_blk.xetla_select<load_elems, 1>(remained_start)
270278
.xetla_format<native_type_t<load_dtype>>() = xetla_load_global<
271279
native_type_t<load_dtype>,
@@ -283,15 +291,6 @@ tile_load(tile_t& tile, payload_t& payload) {
283291
payload.surface_pitch,
284292
payload.offset_x + offset_x / scale_factor,
285293
payload.offset_y + offset_y + remained_start_y);
286-
287-
// xetla_tload_global<
288-
// load_dtype,
289-
// (load_elems / scale_factor),
290-
// L1,
291-
// L2,
292-
// trans,
293-
// mem_transform,
294-
// arch_tag>(tdesc);
295294
}
296295
}
297296
}
@@ -304,24 +303,16 @@ tile_load(tile_t& tile, payload_t& payload) {
304303
(!reg_transpose && (remained_size_y > ld_blk_size_y_limit))
305304
? ld_blk_size_y_limit
306305
: remained_size_y;
307-
// auto payload_row = payload_2d.xetla_select<num_block_x, 1, 16, 1>(
308-
// num_block_y * num_block_x, 0);
309-
// detail::reset_tile_desc_core<
310-
// num_block_x,
311-
// block_size_x,
312-
// remained_ld_blk_size_y,
313-
// scale_factor,
314-
// arr_len,
315-
// mem_transpose>(payload_row);
306+
316307
#pragma unroll
317308
for (uint32_t j = 0; j < num_block_x; j += arr_len) {
318309
int32_t offset_x = j * block_size_x;
319310
// xetla_tdescriptor tdesc = payload_row.row(j);
320311
auto reg_blk = tile.reg.xetla_select<remained_block_elems * arr_len, 1>(
321312
processed_elems + j * remained_block_elems);
322-
constexpr uint32_t ld_blk_height = (reg_transpose && trans)
323-
? detail::getNextPowerOf2<remained_ld_blk_size_y>()
324-
: remained_ld_blk_size_y;
313+
// constexpr uint32_t ld_blk_height = (reg_transpose && trans)
314+
// ? detail::getNextPowerOf2<remained_ld_blk_size_y>()
315+
// : remained_ld_blk_size_y;
325316
constexpr uint32_t tmp_size = ld_blk_height * block_size_x * arr_len;
326317
xetla_vector<dtype, tmp_size> reg_tmp;
327318
#pragma unroll

tests/integration/gemm/fp32/main.cpp

Lines changed: 9 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -34,15 +34,15 @@ TYPED_TEST_P(fp32_gemm_test, esimd) {
3434

3535
REGISTER_TYPED_TEST_SUITE_P(fp32_gemm_test, esimd);
3636
using tests = ::testing::Types<
37-
// Test1,
38-
// Test2,
39-
// Test3,
40-
// Test4,
41-
// Test5,
42-
// Test6,
43-
// Test7,
44-
// Test8,
45-
// Test9,
37+
Test1,
38+
Test2,
39+
Test3,
40+
Test4,
41+
Test5,
42+
Test6,
43+
Test7,
44+
Test8,
45+
Test9,
4646
Test10,
4747
Test11>;
4848
INSTANTIATE_TYPED_TEST_SUITE_P(fp32_gemm_test_suite, fp32_gemm_test, tests);

0 commit comments

Comments
 (0)