Skip to content

Commit

Permalink
Update blocking alingment for loadOp
Browse files Browse the repository at this point in the history
  • Loading branch information
chencha3 committed Dec 13, 2024
1 parent 3021ae2 commit 2556aee
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 40 deletions.
19 changes: 13 additions & 6 deletions lib/Dialect/XeTile/Transforms/BlockingAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -406,14 +406,21 @@ void BlockingAnalysisImpl::visitLoadTileOp(

// adjust according to user's requirements if it is available
if (lattice.isInitialized()) {
// Always align the width dimension.
// NOTE: For transpose usecase, we still align the width dimension. This is
// because loads with transpose cannot have array_length > 1, plus it has HW
// limitations on supported width. If we align the height dimension (for
// reducing reg data movement), it will lead to multiple smaller loads.
for (auto rq : lattice.getRequests())
bool hasTransposeUser = op.getValue().hasOneUse() &&
mlir::isa<xetile::TransposeOp>(*(op->user_begin()));

// To minimize the in-reg data movement, we need to align dim1 for regular
// case and dim0 for transpose case. For transpose case, we also need to
// make sure dim1 such that the following pass can fold the transpose with
// the load.
for (auto rq : lattice.getRequests()) {
if (rq[1] && ((rq[1] * bitWidth) % 32 == 0)) // has to be 32-bit aligned
block[1] = std::min(block[1], rq[1]);

// also aligns the height dimension if user is a transpose op.
if (hasTransposeUser)
block[0] = std::min(block[0], rq[0]);
}
}

if (!block)
Expand Down
12 changes: 6 additions & 6 deletions test/Conversion/XeTileToXeGPU/sg_gemm_transpose_b.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@ gpu.module @test_kernel {
%c_init_tile = xetile.init_tile %C[%m, %n] : memref<1024x1024xf32> -> !xetile.tile<32x32xf32>
%c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<32x32xf32> -> vector<32x32xf32>
%a_init_tile = xetile.init_tile %A[%m, %c0] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16>
// CHECK: %[[T1:.*]] = xegpu.create_nd_tdesc %[[arg1]][%{{.*}}, %{{.*}}] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 2 : i64, boundary_check = true>>
// CHECk-COUNT-2: %{{.*}} = xegpu.create_nd_tdesc %[[arg1]][%{{.*}}, %{{.*}}] : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 2 : i64, boundary_check = true>>
%b_init_tile = xetile.init_tile %B[%c0, %n] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16>
// CHECK: scf.for %{{.*}}= %{{.*}}to %{{.*}}step %{{.*}}iter_args(%{{.*}}= %{{.*}}, %[[ARG5:.*]] = %[[T1]], %{{.*}}= %{{.*}}, %{{.*}}= %{{.*}}, %{{.*}}= %{{.*}}, %{{.*}}= %{{.*}}, %{{.*}}= %{{.*}}, %{{.*}}= %{{.*}}, %{{.*}}= %{{.*}}, %{{.*}} = %{{.*}}) -> (!xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 2 : i64, boundary_check = true>>, !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 2 : i64, boundary_check = true>>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>) {
// CHECK: %{{.*}}:11 = scf.for {{.*}} -> (!xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 2 : i64, boundary_check = true>>, !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 2 : i64, boundary_check = true>>, !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 2 : i64, boundary_check = true>>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>) {
%out:3 = scf.for %k = %c0 to %c1024 step %c32
iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_value = %c_init_value)
-> (!xetile.tile<32x32xf16>, !xetile.tile<32x32xf16>, vector<32x32xf32>) {
%a_value = xetile.load_tile %a_tile : !xetile.tile<32x32xf16> -> vector<32x32xf16>
// CHECK: xegpu.load_nd %[[ARG5]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 2 : i64, boundary_check = true>> -> vector<2x32x16xf16>
// CHECK-COUNT-2: xegpu.load_nd %{{.*}} <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 2 : i64, boundary_check = true>> -> vector<2x16x16xf16>
%b_value = xetile.load_tile %b_tile : !xetile.tile<32x32xf16> -> vector<32x32xf16>
%b_transpose = xetile.transpose %b_value, [1, 0] : vector<32x32xf16> -> vector<32x32xf16>
%c_new_value = xetile.tile_mma %a_value, %b_transpose, %c_value : vector<32x32xf16>, vector<32x32xf16>, vector<32x32xf32> -> vector<32x32xf32>
Expand Down Expand Up @@ -56,14 +56,14 @@ gpu.module @test_kernel {
%c_init_tile = xetile.init_tile %C[%m, %n] : memref<1024x1024xf32> -> !xetile.tile<32x32xf32>
%c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<32x32xf32> -> vector<32x32xf32>
%a_init_tile = xetile.init_tile %A[%m, %c0] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16>
// CHECK: %[[T1:.*]] = xegpu.create_nd_tdesc %[[arg1]][%{{.*}}, %{{.*}}] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 2 : i64, boundary_check = true>>
// CHECK-COUNT-2: %{{.*}} = xegpu.create_nd_tdesc %[[arg1]][%{{.*}}, %{{.*}}] : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 2 : i64, boundary_check = true>>
%b_init_tile = xetile.init_tile %B[%c0, %n] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16>
// CHECK: scf.for %{{.*}}= %{{.*}}to %{{.*}}step %{{.*}}iter_args(%{{.*}}= %{{.*}}, %[[ARG5:.*]] = %[[T1]], %{{.*}}= %{{.*}}, %{{.*}}= %{{.*}}, %{{.*}}= %{{.*}}, %{{.*}}= %{{.*}}, %{{.*}}= %{{.*}}, %{{.*}}= %{{.*}}, %{{.*}}= %{{.*}}, %{{.*}} = %{{.*}}) -> (!xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 2 : i64, boundary_check = true>>, !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 2 : i64, boundary_check = true>>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>) {
// CHECK: %{{.*}}:11 = scf.for {{.*}} -> (!xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 2 : i64, boundary_check = true>>, !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 2 : i64, boundary_check = true>>, !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 2 : i64, boundary_check = true>>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>) {
%out:3 = scf.for %k = %c0 to %c1024 step %c32
iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_value = %c_init_value)
-> (!xetile.tile<32x32xf16>, !xetile.tile<32x32xf16>, vector<32x32xf32>) {
%a_value = xetile.load_tile %a_tile : !xetile.tile<32x32xf16> -> vector<32x32xf16>
// xegpu.load_nd %[[ARG5]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 2 : i64, boundary_check = true>> -> vector<32x16xf16>
// CHECK-COUNT-2: xegpu.load_nd %{{.*}} <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 2 : i64, boundary_check = true>> -> vector<2x16x16xf16>
%b_value = xetile.load_tile %b_tile : !xetile.tile<32x32xf16> -> vector<32x32xf16>
%b_transpose = xetile.transpose %b_value, [1, 0] : vector<32x32xf16> -> vector<32x32xf16>
%preop = math.exp %b_transpose : vector<32x32xf16>
Expand Down
12 changes: 8 additions & 4 deletions test/Conversion/XeTileToXeGPU/test_order.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,14 @@
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[C16:.*]] = arith.constant 16 : index
// CHECK: %[[R_CAST:.*]] = memref.reinterpret_cast %[[ARG1]] to offset: [0], sizes: [128, 64], strides: [64, 1] : memref<64x128xf16, strided<[1, 64]>> to memref<128x64xf16, strided<[64, 1]>>
// CHECK: %[[T1:.*]] = xegpu.create_nd_tdesc %[[R_CAST]][%[[C0]], %[[C0]]] : memref<128x64xf16, strided<[64, 1]>> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>>
// CHECK: %[[T8:.*]] = xegpu.load_nd %[[T1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>> -> vector<32x16xf16>
// CHECK: %[[T19:.*]] = xegpu.update_nd_offset %[[T1]], [%[[C0]], %[[C16]]] : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>>
// CHECK: %[[T26:.*]] = xegpu.load_nd %[[T19]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>> -> vector<32x16xf16>
// CHECK: %[[T1:.*]] = xegpu.create_nd_tdesc %[[R_CAST]][%[[C0]], %[[C0]]] : memref<128x64xf16, strided<[64, 1]>> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>>
// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[R_CAST]][%[[C16]], %[[C0]]] : memref<128x64xf16, strided<[64, 1]>> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>>
// CHECK: %[[T8:.*]] = xegpu.load_nd %[[T1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>> -> vector<16x16xf16>
// CHECK: %[[T9:.*]] = xegpu.load_nd %[[T2]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>> -> vector<16x16xf16>
// CHECK: %[[T19:.*]] = xegpu.update_nd_offset %[[T1]], [%[[C0]], %[[C16]]] : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>>
// CHECK: %[[T20:.*]] = xegpu.update_nd_offset %[[T2]], [%[[C0]], %[[C16]]] : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>>
// CHECK: %[[T26:.*]] = xegpu.load_nd %[[T19]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>> -> vector<16x16xf16>
// CHECK: %[[T27:.*]] = xegpu.load_nd %[[T20]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>> -> vector<16x16xf16>
gpu.module @test_kernel {
func.func @test_func(%A : memref<128x64xf16>, %B : memref<64x128xf16, strided<[1, 64], offset: 0>>) {
%c0 = arith.constant 0 : index
Expand Down
22 changes: 10 additions & 12 deletions test/Dialect/XeTile/Transforms/Blocking/unit_tests.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -486,15 +486,15 @@ gpu.module @test_kernel {
%24 = index.remu %20, %c1
%25 = index.mul %24, %c32

// CHECK: xetile.init_tile %{{.*}} : memref<1536x12288xf16> -> !xetile.tile<64x32xf16, #xetile.tile_attr<inner_blocks = [32, 16]>>
// CHECK: xetile.init_tile %{{.*}} : memref<1536x12288xf16> -> !xetile.tile<64x32xf16, #xetile.tile_attr<inner_blocks = [16, 16]>>
%26 = xetile.init_tile %arg1[%23, %25] : memref<1536x12288xf16> -> !xetile.tile<64x32xf16>
%27:2 = scf.for %arg15 = %c0 to %c2 step %c1 iter_args(%arg16 = %15, %arg17 = %18) -> (!xetile.tile<32x64xf32>, !xetile.tile<32x32xf16>) {
//CHECK: xetile.update_tile_offset %{{.*}}, [%c1024, %c0] : !xetile.tile<32x32xf16, #xetile.tile_attr<inner_blocks = [32, 16]>>
//CHECK: xetile.update_tile_offset %{{.*}}, [%c1024, %c0] : !xetile.tile<32x64xf32, #xetile.tile_attr<inner_blocks = [8, 16]>>
%28 = xetile.update_tile_offset %arg17, [%c1024, %c0] : !xetile.tile<32x32xf16>
%29 = xetile.update_tile_offset %arg16, [%c1024, %c0] : !xetile.tile<32x64xf32>
%30:3 = scf.for %arg18 = %c0 to %c12288 step %c32 iter_args(%arg19 = %cst, %arg20 = %arg17, %arg21 = %26) -> (vector<32x64xf32>, !xetile.tile<32x32xf16>, !xetile.tile<64x32xf16>) {
//CHECK: xetile.update_tile_offset %{{.*}}, [%c0, %c32] : !xetile.tile<64x32xf16, #xetile.tile_attr<inner_blocks = [32, 16]>>
//CHECK: xetile.update_tile_offset %{{.*}}, [%c0, %c32] : !xetile.tile<64x32xf16, #xetile.tile_attr<inner_blocks = [16, 16]>>
//CHECK: xetile.update_tile_offset %{{.*}}, [%c0, %c32] : !xetile.tile<32x32xf16, #xetile.tile_attr<inner_blocks = [32, 16]>>
%32 = xetile.update_tile_offset %arg21, [%c0, %c32] : !xetile.tile<64x32xf16>
%33 = xetile.update_tile_offset %arg20, [%c0, %c32] : !xetile.tile<32x32xf16>
Expand Down Expand Up @@ -524,16 +524,14 @@ gpu.module @test_kernel {
%5 = xetile.init_tile %arg1[0, 0] : memref<256x384xf32> -> !xetile.tile<64x32xf32>
xetile.store_tile %4, %5 : vector<64x32xf32>, !xetile.tile<64x32xf32>

//CHECK: %[[r0:.*]] = xetile.init_tile %{{.*}}[0, 0] : memref<384x1xf32> -> !xetile.tile<32x1xf32, #xetile.tile_attr<inner_blocks = [32, 1]>>
//CHECK: %[[r1:.*]] = xetile.load_tile %[[r0]] {padding = 0.000000e+00 : f32} : !xetile.tile<32x1xf32, #xetile.tile_attr<inner_blocks = [32, 1]>> -> vector<1x1x32x1xf32>
//CHECK: %[[r2:.*]] = xetile.tile_unpack %[[r1]] {inner_blocks = array<i64: 32, 1>} : vector<1x1x32x1xf32> -> vector<32x1xf32>
//CHECK: %[[r3:.*]] = xetile.tile_pack %[[r2]] {inner_blocks = array<i64: 16, 1>} : vector<32x1xf32> -> vector<2x1x16x1xf32>
//CHECK: %[[r4:.*]] = xetile.transpose %[[r3]], [1, 0, 3, 2] : vector<2x1x16x1xf32> -> vector<1x2x1x16xf32>
//CHECK: %[[r5:.*]] = xetile.broadcast %[[r4]] [0, 2] : vector<1x2x1x16xf32> -> vector<64x2x1x16xf32>
//CHECK: %[[r6:.*]] = xetile.tile_unpack %[[r5]] {inner_blocks = array<i64: 1, 16>} : vector<64x2x1x16xf32> -> vector<64x32xf32>
//CHECK: %[[r7:.*]] = xetile.init_tile %{{.*}}[0, 0] : memref<256x384xf32> -> !xetile.tile<64x32xf32, #xetile.tile_attr<inner_blocks = [8, 16]>>
//CHECK: %[[r8:.*]] = xetile.tile_pack %[[r6]] {inner_blocks = array<i64: 8, 16>} : vector<64x32xf32> -> vector<8x2x8x16xf32>
//CHECK: xetile.store_tile %[[r8]], %[[r7]] : vector<8x2x8x16xf32>, !xetile.tile<64x32xf32, #xetile.tile_attr<inner_blocks = [8, 16]>>
//CHECK: %[[r0:.*]] = xetile.init_tile %{{.*}}[0, 0] : memref<384x1xf32> -> !xetile.tile<32x1xf32, #xetile.tile_attr<inner_blocks = [16, 1]>>
//CHECK: %[[r1:.*]] = xetile.load_tile %[[r0]] {padding = 0.000000e+00 : f32} : !xetile.tile<32x1xf32, #xetile.tile_attr<inner_blocks = [16, 1]>> -> vector<2x1x16x1xf32>
//CHECK: %[[r2:.*]] = xetile.transpose %[[r1]], [1, 0, 3, 2] : vector<2x1x16x1xf32> -> vector<1x2x1x16xf32>
//CHECK: %[[r3:.*]] = xetile.broadcast %[[r2]] [0, 2] : vector<1x2x1x16xf32> -> vector<64x2x1x16xf32>
//CHECK: %[[r4:.*]] = xetile.tile_unpack %[[r3]] {inner_blocks = array<i64: 1, 16>} : vector<64x2x1x16xf32> -> vector<64x32xf32>
//CHECK: %[[r5:.*]] = xetile.init_tile %{{.*}}[0, 0] : memref<256x384xf32> -> !xetile.tile<64x32xf32, #xetile.tile_attr<inner_blocks = [8, 16]>>
//CHECK: %[[r6:.*]] = xetile.tile_pack %[[r4]] {inner_blocks = array<i64: 8, 16>} : vector<64x32xf32> -> vector<8x2x8x16xf32>
//CHECK: xetile.store_tile %[[r6]], %[[r5]] : vector<8x2x8x16xf32>, !xetile.tile<64x32xf32, #xetile.tile_attr<inner_blocks = [8, 16]>>
gpu.return
}

Expand Down
Loading

0 comments on commit 2556aee

Please sign in to comment.