From bee50a07c4c6232b07c111934f986809f8686e51 Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Fri, 13 Dec 2024 10:51:45 -0600 Subject: [PATCH] [XeTile-Blocking] Update blocking alingment for loadOp (#982) --- .../XeTile/Transforms/BlockingAnalysis.cpp | 19 ++++--- .../XeTileToXeGPU/sg_gemm_transpose_b.mlir | 12 ++--- test/Conversion/XeTileToXeGPU/test_order.mlir | 12 +++-- .../Transforms/Blocking/unit_tests.mlir | 22 ++++---- .../Blocking/unit_tests_transform.mlir | 51 ++++++++++++++----- 5 files changed, 76 insertions(+), 40 deletions(-) diff --git a/lib/Dialect/XeTile/Transforms/BlockingAnalysis.cpp b/lib/Dialect/XeTile/Transforms/BlockingAnalysis.cpp index be66dab79..53d2a0f80 100644 --- a/lib/Dialect/XeTile/Transforms/BlockingAnalysis.cpp +++ b/lib/Dialect/XeTile/Transforms/BlockingAnalysis.cpp @@ -406,14 +406,21 @@ void BlockingAnalysisImpl::visitLoadTileOp( // adjust according to user's requirements if it is available if (lattice.isInitialized()) { - // Always align the width dimension. - // NOTE: For transpose usecase, we still align the width dimension. This is - // because loads with transpose cannot have array_length > 1, plus it has HW - // limitations on supported width. If we align the height dimension (for - // reducing reg data movement), it will lead to multiple smaller loads. - for (auto rq : lattice.getRequests()) + bool hasTransposeUser = op.getValue().hasOneUse() && + mlir::isa(*(op->user_begin())); + + // To minimize the in-reg data movement, we need to align dim1 for regular + // case and dim0 for transpose case. For transpose case, we also need to + // make sure dim1 such that the following pass can fold the transpose with + // the load. + for (auto rq : lattice.getRequests()) { if (rq[1] && ((rq[1] * bitWidth) % 32 == 0)) // has to be 32-bit aligned block[1] = std::min(block[1], rq[1]); + + // also aligns the height dimension if user is a transpose op. + if (hasTransposeUser) + block[0] = std::min(block[0], rq[0]); + } } if (!block) diff --git a/test/Conversion/XeTileToXeGPU/sg_gemm_transpose_b.mlir b/test/Conversion/XeTileToXeGPU/sg_gemm_transpose_b.mlir index 66f427992..2553f3946 100644 --- a/test/Conversion/XeTileToXeGPU/sg_gemm_transpose_b.mlir +++ b/test/Conversion/XeTileToXeGPU/sg_gemm_transpose_b.mlir @@ -17,14 +17,14 @@ gpu.module @test_kernel { %c_init_tile = xetile.init_tile %C[%m, %n] : memref<1024x1024xf32> -> !xetile.tile<32x32xf32> %c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<32x32xf32> -> vector<32x32xf32> %a_init_tile = xetile.init_tile %A[%m, %c0] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16> -// CHECK: %[[T1:.*]] = xegpu.create_nd_tdesc %[[arg1]][%{{.*}}, %{{.*}}] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> +// CHECk-COUNT-2: %{{.*}} = xegpu.create_nd_tdesc %[[arg1]][%{{.*}}, %{{.*}}] : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> %b_init_tile = xetile.init_tile %B[%c0, %n] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16> -// CHECK: scf.for %{{.*}}= %{{.*}}to %{{.*}}step %{{.*}}iter_args(%{{.*}}= %{{.*}}, %[[ARG5:.*]] = %[[T1]], %{{.*}}= %{{.*}}, %{{.*}}= %{{.*}}, %{{.*}}= %{{.*}}, %{{.*}}= %{{.*}}, %{{.*}}= %{{.*}}, %{{.*}}= %{{.*}}, %{{.*}}= %{{.*}}, %{{.*}} = %{{.*}}) -> (!xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>) { +// CHECK: %{{.*}}:11 = scf.for {{.*}} -> (!xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>) { %out:3 = scf.for %k = %c0 to %c1024 step %c32 iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_value = %c_init_value) -> (!xetile.tile<32x32xf16>, !xetile.tile<32x32xf16>, vector<32x32xf32>) { %a_value = xetile.load_tile %a_tile : !xetile.tile<32x32xf16> -> vector<32x32xf16> -// CHECK: xegpu.load_nd %[[ARG5]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<2x32x16xf16> +// CHECK-COUNT-2: xegpu.load_nd %{{.*}} <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> -> vector<2x16x16xf16> %b_value = xetile.load_tile %b_tile : !xetile.tile<32x32xf16> -> vector<32x32xf16> %b_transpose = xetile.transpose %b_value, [1, 0] : vector<32x32xf16> -> vector<32x32xf16> %c_new_value = xetile.tile_mma %a_value, %b_transpose, %c_value : vector<32x32xf16>, vector<32x32xf16>, vector<32x32xf32> -> vector<32x32xf32> @@ -56,14 +56,14 @@ gpu.module @test_kernel { %c_init_tile = xetile.init_tile %C[%m, %n] : memref<1024x1024xf32> -> !xetile.tile<32x32xf32> %c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<32x32xf32> -> vector<32x32xf32> %a_init_tile = xetile.init_tile %A[%m, %c0] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16> -// CHECK: %[[T1:.*]] = xegpu.create_nd_tdesc %[[arg1]][%{{.*}}, %{{.*}}] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> +// CHECK-COUNT-2: %{{.*}} = xegpu.create_nd_tdesc %[[arg1]][%{{.*}}, %{{.*}}] : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> %b_init_tile = xetile.init_tile %B[%c0, %n] : memref<1024x1024xf16> -> !xetile.tile<32x32xf16> -// CHECK: scf.for %{{.*}}= %{{.*}}to %{{.*}}step %{{.*}}iter_args(%{{.*}}= %{{.*}}, %[[ARG5:.*]] = %[[T1]], %{{.*}}= %{{.*}}, %{{.*}}= %{{.*}}, %{{.*}}= %{{.*}}, %{{.*}}= %{{.*}}, %{{.*}}= %{{.*}}, %{{.*}}= %{{.*}}, %{{.*}}= %{{.*}}, %{{.*}} = %{{.*}}) -> (!xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>) { +// CHECK: %{{.*}}:11 = scf.for {{.*}} -> (!xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>) { %out:3 = scf.for %k = %c0 to %c1024 step %c32 iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_value = %c_init_value) -> (!xetile.tile<32x32xf16>, !xetile.tile<32x32xf16>, vector<32x32xf32>) { %a_value = xetile.load_tile %a_tile : !xetile.tile<32x32xf16> -> vector<32x32xf16> -// xegpu.load_nd %[[ARG5]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<32x16xf16> +// CHECK-COUNT-2: xegpu.load_nd %{{.*}} <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> -> vector<2x16x16xf16> %b_value = xetile.load_tile %b_tile : !xetile.tile<32x32xf16> -> vector<32x32xf16> %b_transpose = xetile.transpose %b_value, [1, 0] : vector<32x32xf16> -> vector<32x32xf16> %preop = math.exp %b_transpose : vector<32x32xf16> diff --git a/test/Conversion/XeTileToXeGPU/test_order.mlir b/test/Conversion/XeTileToXeGPU/test_order.mlir index e6873a6e4..7bb5b861d 100644 --- a/test/Conversion/XeTileToXeGPU/test_order.mlir +++ b/test/Conversion/XeTileToXeGPU/test_order.mlir @@ -5,10 +5,14 @@ // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[C16:.*]] = arith.constant 16 : index // CHECK: %[[R_CAST:.*]] = memref.reinterpret_cast %[[ARG1]] to offset: [0], sizes: [128, 64], strides: [64, 1] : memref<64x128xf16, strided<[1, 64]>> to memref<128x64xf16, strided<[64, 1]>> -// CHECK: %[[T1:.*]] = xegpu.create_nd_tdesc %[[R_CAST]][%[[C0]], %[[C0]]] : memref<128x64xf16, strided<[64, 1]>> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -// CHECK: %[[T8:.*]] = xegpu.load_nd %[[T1]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<32x16xf16> -// CHECK: %[[T19:.*]] = xegpu.update_nd_offset %[[T1]], [%[[C0]], %[[C16]]] : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -// CHECK: %[[T26:.*]] = xegpu.load_nd %[[T19]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<32x16xf16> +// CHECK: %[[T1:.*]] = xegpu.create_nd_tdesc %[[R_CAST]][%[[C0]], %[[C0]]] : memref<128x64xf16, strided<[64, 1]>> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> +// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[R_CAST]][%[[C16]], %[[C0]]] : memref<128x64xf16, strided<[64, 1]>> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> +// CHECK: %[[T8:.*]] = xegpu.load_nd %[[T1]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> -> vector<16x16xf16> +// CHECK: %[[T9:.*]] = xegpu.load_nd %[[T2]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> -> vector<16x16xf16> +// CHECK: %[[T19:.*]] = xegpu.update_nd_offset %[[T1]], [%[[C0]], %[[C16]]] : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> +// CHECK: %[[T20:.*]] = xegpu.update_nd_offset %[[T2]], [%[[C0]], %[[C16]]] : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> +// CHECK: %[[T26:.*]] = xegpu.load_nd %[[T19]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> -> vector<16x16xf16> +// CHECK: %[[T27:.*]] = xegpu.load_nd %[[T20]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> -> vector<16x16xf16> gpu.module @test_kernel { func.func @test_func(%A : memref<128x64xf16>, %B : memref<64x128xf16, strided<[1, 64], offset: 0>>) { %c0 = arith.constant 0 : index diff --git a/test/Dialect/XeTile/Transforms/Blocking/unit_tests.mlir b/test/Dialect/XeTile/Transforms/Blocking/unit_tests.mlir index 0f057a197..3525254cf 100644 --- a/test/Dialect/XeTile/Transforms/Blocking/unit_tests.mlir +++ b/test/Dialect/XeTile/Transforms/Blocking/unit_tests.mlir @@ -486,7 +486,7 @@ gpu.module @test_kernel { %24 = index.remu %20, %c1 %25 = index.mul %24, %c32 - // CHECK: xetile.init_tile %{{.*}} : memref<1536x12288xf16> -> !xetile.tile<64x32xf16, #xetile.tile_attr> + // CHECK: xetile.init_tile %{{.*}} : memref<1536x12288xf16> -> !xetile.tile<64x32xf16, #xetile.tile_attr> %26 = xetile.init_tile %arg1[%23, %25] : memref<1536x12288xf16> -> !xetile.tile<64x32xf16> %27:2 = scf.for %arg15 = %c0 to %c2 step %c1 iter_args(%arg16 = %15, %arg17 = %18) -> (!xetile.tile<32x64xf32>, !xetile.tile<32x32xf16>) { //CHECK: xetile.update_tile_offset %{{.*}}, [%c1024, %c0] : !xetile.tile<32x32xf16, #xetile.tile_attr> @@ -494,7 +494,7 @@ gpu.module @test_kernel { %28 = xetile.update_tile_offset %arg17, [%c1024, %c0] : !xetile.tile<32x32xf16> %29 = xetile.update_tile_offset %arg16, [%c1024, %c0] : !xetile.tile<32x64xf32> %30:3 = scf.for %arg18 = %c0 to %c12288 step %c32 iter_args(%arg19 = %cst, %arg20 = %arg17, %arg21 = %26) -> (vector<32x64xf32>, !xetile.tile<32x32xf16>, !xetile.tile<64x32xf16>) { - //CHECK: xetile.update_tile_offset %{{.*}}, [%c0, %c32] : !xetile.tile<64x32xf16, #xetile.tile_attr> + //CHECK: xetile.update_tile_offset %{{.*}}, [%c0, %c32] : !xetile.tile<64x32xf16, #xetile.tile_attr> //CHECK: xetile.update_tile_offset %{{.*}}, [%c0, %c32] : !xetile.tile<32x32xf16, #xetile.tile_attr> %32 = xetile.update_tile_offset %arg21, [%c0, %c32] : !xetile.tile<64x32xf16> %33 = xetile.update_tile_offset %arg20, [%c0, %c32] : !xetile.tile<32x32xf16> @@ -524,16 +524,14 @@ gpu.module @test_kernel { %5 = xetile.init_tile %arg1[0, 0] : memref<256x384xf32> -> !xetile.tile<64x32xf32> xetile.store_tile %4, %5 : vector<64x32xf32>, !xetile.tile<64x32xf32> - //CHECK: %[[r0:.*]] = xetile.init_tile %{{.*}}[0, 0] : memref<384x1xf32> -> !xetile.tile<32x1xf32, #xetile.tile_attr> - //CHECK: %[[r1:.*]] = xetile.load_tile %[[r0]] {padding = 0.000000e+00 : f32} : !xetile.tile<32x1xf32, #xetile.tile_attr> -> vector<1x1x32x1xf32> - //CHECK: %[[r2:.*]] = xetile.tile_unpack %[[r1]] {inner_blocks = array} : vector<1x1x32x1xf32> -> vector<32x1xf32> - //CHECK: %[[r3:.*]] = xetile.tile_pack %[[r2]] {inner_blocks = array} : vector<32x1xf32> -> vector<2x1x16x1xf32> - //CHECK: %[[r4:.*]] = xetile.transpose %[[r3]], [1, 0, 3, 2] : vector<2x1x16x1xf32> -> vector<1x2x1x16xf32> - //CHECK: %[[r5:.*]] = xetile.broadcast %[[r4]] [0, 2] : vector<1x2x1x16xf32> -> vector<64x2x1x16xf32> - //CHECK: %[[r6:.*]] = xetile.tile_unpack %[[r5]] {inner_blocks = array} : vector<64x2x1x16xf32> -> vector<64x32xf32> - //CHECK: %[[r7:.*]] = xetile.init_tile %{{.*}}[0, 0] : memref<256x384xf32> -> !xetile.tile<64x32xf32, #xetile.tile_attr> - //CHECK: %[[r8:.*]] = xetile.tile_pack %[[r6]] {inner_blocks = array} : vector<64x32xf32> -> vector<8x2x8x16xf32> - //CHECK: xetile.store_tile %[[r8]], %[[r7]] : vector<8x2x8x16xf32>, !xetile.tile<64x32xf32, #xetile.tile_attr> + //CHECK: %[[r0:.*]] = xetile.init_tile %{{.*}}[0, 0] : memref<384x1xf32> -> !xetile.tile<32x1xf32, #xetile.tile_attr> + //CHECK: %[[r1:.*]] = xetile.load_tile %[[r0]] {padding = 0.000000e+00 : f32} : !xetile.tile<32x1xf32, #xetile.tile_attr> -> vector<2x1x16x1xf32> + //CHECK: %[[r2:.*]] = xetile.transpose %[[r1]], [1, 0, 3, 2] : vector<2x1x16x1xf32> -> vector<1x2x1x16xf32> + //CHECK: %[[r3:.*]] = xetile.broadcast %[[r2]] [0, 2] : vector<1x2x1x16xf32> -> vector<64x2x1x16xf32> + //CHECK: %[[r4:.*]] = xetile.tile_unpack %[[r3]] {inner_blocks = array} : vector<64x2x1x16xf32> -> vector<64x32xf32> + //CHECK: %[[r5:.*]] = xetile.init_tile %{{.*}}[0, 0] : memref<256x384xf32> -> !xetile.tile<64x32xf32, #xetile.tile_attr> + //CHECK: %[[r6:.*]] = xetile.tile_pack %[[r4]] {inner_blocks = array} : vector<64x32xf32> -> vector<8x2x8x16xf32> + //CHECK: xetile.store_tile %[[r6]], %[[r5]] : vector<8x2x8x16xf32>, !xetile.tile<64x32xf32, #xetile.tile_attr> gpu.return } diff --git a/test/Dialect/XeTile/Transforms/Blocking/unit_tests_transform.mlir b/test/Dialect/XeTile/Transforms/Blocking/unit_tests_transform.mlir index 3b52485ff..ec137fb71 100644 --- a/test/Dialect/XeTile/Transforms/Blocking/unit_tests_transform.mlir +++ b/test/Dialect/XeTile/Transforms/Blocking/unit_tests_transform.mlir @@ -111,6 +111,29 @@ gpu.module @test_kernel { gpu.return } + //CHECK: gpu.func @sg_tile_mma_b_transpose(%[[arg0:.*]]: memref<64x32xf16>, %[[arg1:.*]]: memref<64x32xf16>, %[[arg2:.*]]: memref<64x64xf32>) + gpu.func @sg_tile_mma_b_transpose(%a: memref<64x32xf16>, %b: memref<64x32xf16>, %c: memref<64x64xf32>) { + //CHECK-COUNT-4: %{{.*}} = xetile.init_tile %[[arg0]][%{{.*}}] : memref<64x32xf16> -> !xetile.tile<32x16xf16> + %0 = xetile.init_tile %a[0, 0] : memref<64x32xf16> -> !xetile.tile<64x32xf16> + //CHECK-COUNT-4: %{{.*}} = xetile.load_tile %{{.*}} : !xetile.tile<32x16xf16> -> vector<32x16xf16> + //CHECK-COUNT-16: %{{.*}} = vector.extract_strided_slice %{{.*}} {offsets = [{{.*}}], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf16> to vector<8x16xf16> + %1 = xetile.load_tile %0 : !xetile.tile<64x32xf16> -> vector<64x32xf16> + + //CHECK-COUNT-8: %{{.*}} = xetile.init_tile %[[arg1]][{{.*}}] : memref<64x32xf16> -> !xetile.tile<16x16xf16> + %2 = xetile.init_tile %b[0, 0] : memref<64x32xf16> -> !xetile.tile<64x32xf16> + //CHECK-COUNT-8: %{{.*}} = xetile.load_tile %{{.*}} : !xetile.tile<16x16xf16> -> vector<16x16xf16> + %3 = xetile.load_tile %2 : !xetile.tile<64x32xf16> -> vector<64x32xf16> + //CHECK-COUNT-8: %{{.*}} = xetile.transpose %{{.*}}, [1, 0] : vector<16x16xf16> -> vector<16x16xf16> + %4 = xetile.transpose %3, [1, 0] : vector<64x32xf16> -> vector<32x64xf16> + //CHECK-COUNT-64: %{{.*}} = xetile.tile_mma {{.*}} : vector<8x16xf16>, vector<16x16xf16>{{.*}}-> vector<8x16xf32> + %5 = xetile.tile_mma %1, %4: vector<64x32xf16>, vector<32x64xf16> -> vector<64x64xf32> + //CHECK-COUNT-32: xetile.init_tile %[[arg2]][{{.*}}] : memref<64x64xf32> -> !xetile.tile<8x16xf32> + %6 = xetile.init_tile %c[0, 0] : memref<64x64xf32> -> !xetile.tile<64x64xf32> + //CHECK-COUNT-32: xetile.store_tile %{{.*}}, %{{.*}} : vector<8x16xf32>, !xetile.tile<8x16xf32> + xetile.store_tile %5, %6: vector<64x64xf32>, !xetile.tile<64x64xf32> + gpu.return + } + // CHECK-LABEL: gpu.func @inner_reduction // CHECK-SAME: (%[[arg0:.*]]: memref<128x256xf16>, %[[arg1:.*]]: memref<128x256xf16>) gpu.func @inner_reduction(%a: memref<128x256xf16>, %b: memref<128x256xf16>) { @@ -1547,8 +1570,8 @@ gpu.module @test_kernel { %15 = xetile.init_tile %arg2[%11, %14] : memref<16384x1536xf32> -> !xetile.tile<32x64xf32> %16 = index.remu %8, %c1 %17 = index.mul %16, %c32 - //CHECK: %{{.*}} = xetile.init_tile %[[arg1]][%{{.*}}, %{{.*}}] : memref<1536x12288xf16> -> !xetile.tile<32x16xf16> - //CHECK: %{{.*}} = xetile.init_tile %[[arg1]][%{{.*}}, %{{.*}}] : memref<1536x12288xf16> -> !xetile.tile<32x16xf16> + //CHECK: %{{.*}} = xetile.init_tile %[[arg0]][%{{.*}}, %{{.*}}] : memref<16384x12288xf16> -> !xetile.tile<32x16xf16> + //CHECK: %{{.*}} = xetile.init_tile %[[arg0]][%{{.*}}, %{{.*}}] : memref<16384x12288xf16> -> !xetile.tile<32x16xf16> %18 = xetile.init_tile %arg0[%11, %17] : memref<16384x12288xf16> -> !xetile.tile<32x32xf16> %19 = index.floordivs %6, %c8 %20 = index.remu %6, %c8 @@ -1557,26 +1580,30 @@ gpu.module @test_kernel { %23 = index.add %2, %22 %24 = index.remu %20, %c1 %25 = index.mul %24, %c32 - //CHECK-COUNT-2: %{{.*}} = xetile.init_tile %[[arg1]][%{{.*}}, %{{.*}}] : memref<1536x12288xf16> -> !xetile.tile<32x16xf16> + //CHECK: %{{.*}} = xetile.init_tile %[[arg1]][%{{.*}}, %{{.*}}] : memref<1536x12288xf16> -> !xetile.tile<16x16xf16> + //CHECK: %{{.*}} = xetile.init_tile %[[arg1]][%{{.*}}, %{{.*}}] : memref<1536x12288xf16> -> !xetile.tile<16x16xf16> + //CHECK-COUNT-2: %{{.*}} = xetile.init_tile %[[arg1]][%{{.*}}, %{{.*}}] : memref<1536x12288xf16> -> !xetile.tile<16x16xf16> + //CHECK-COUNT-2: %{{.*}} = xetile.init_tile %[[arg1]][%{{.*}}, %{{.*}}] : memref<1536x12288xf16> -> !xetile.tile<16x16xf16> + //CHECK-COUNT-2: %{{.*}} = xetile.init_tile %[[arg1]][%{{.*}}, %{{.*}}] : memref<1536x12288xf16> -> !xetile.tile<16x16xf16> %26 = xetile.init_tile %arg1[%23, %25] : memref<1536x12288xf16> -> !xetile.tile<64x32xf16> %27:2 = scf.for %arg15 = %c0 to %c2 step %c1 iter_args(%arg16 = %15, %arg17 = %18) -> (!xetile.tile<32x64xf32>, !xetile.tile<32x32xf16>) { //CHECK-COUNT-2: %{{.*}} = xetile.update_tile_offset %{{.*}}, [%{{.*}}, %{{.*}}] : !xetile.tile<32x16xf16> - //CHECK-COUNT-16: %{{.*}} = xetile.update_tile_offset %{{.*}}, [%{{.*}}, %{{.*}}] : !xetile.tile<8x16xf32> %28 = xetile.update_tile_offset %arg17, [%c1024, %c0] : !xetile.tile<32x32xf16> + //CHECK-COUNT-16: %{{.*}} = xetile.update_tile_offset %{{.*}}, [%{{.*}}, %{{.*}}] : !xetile.tile<8x16xf32> %29 = xetile.update_tile_offset %arg16, [%c1024, %c0] : !xetile.tile<32x64xf32> - //CHECK: %{{.*}}:22 = scf.for %[[arg22:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args({{.*}}) -> (vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, !xetile.tile<32x16xf16>, !xetile.tile<32x16xf16>, !xetile.tile<32x16xf16>, !xetile.tile<32x16xf16>, !xetile.tile<32x16xf16>, !xetile.tile<32x16xf16>) { + //CHECK: %{{.*}}:26 = scf.for %[[arg22:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args({{.*}}) -> (vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, !xetile.tile<32x16xf16>, !xetile.tile<32x16xf16>, !xetile.tile<16x16xf16>, !xetile.tile<16x16xf16>, !xetile.tile<16x16xf16>, !xetile.tile<16x16xf16>, !xetile.tile<16x16xf16>, !xetile.tile<16x16xf16>, !xetile.tile<16x16xf16>, !xetile.tile<16x16xf16>) { %30:3 = scf.for %arg18 = %c0 to %c12288 step %c32 iter_args(%arg19 = %cst, %arg20 = %arg17, %arg21 = %26) -> (vector<32x64xf32>, !xetile.tile<32x32xf16>, !xetile.tile<64x32xf16>) { - //CHECK-COUNT-6: %{{.*}} = xetile.update_tile_offset %{{.*}}, [%{{.*}}, %{{.*}}] : !xetile.tile<32x16xf16> + //CHECK-COUNT-8: %{{.*}} = xetile.update_tile_offset %{{.*}}, [%{{.*}}, %{{.*}}] : !xetile.tile<16x16xf16> %32 = xetile.update_tile_offset %arg21, [%c0, %c32] : !xetile.tile<64x32xf16> + //CHECK-COUNT-2: %{{.*}} = xetile.update_tile_offset %{{.*}}, [%{{.*}}, %{{.*}}] : !xetile.tile<32x16xf16> %33 = xetile.update_tile_offset %arg20, [%c0, %c32] : !xetile.tile<32x32xf16> //CHECK-COUNT-2: %{{.*}} = xetile.load_tile %{{.*}} {padding = 0.000000e+00 : f32} : !xetile.tile<32x16xf16> -> vector<32x16xf16> //CHECK-COUNT-8: %{{.*}} = vector.extract_strided_slice %{{.*}} {offsets = [{{.*}}], sizes = [8, 16], strides = [1, 1]} : vector<32x16xf16> to vector<8x16xf16> %34 = xetile.load_tile %arg20 {padding = 0.000000e+00 : f32} : !xetile.tile<32x32xf16> -> vector<32x32xf16> //CHECK-COUNT-8: %{{.*}} = math.exp %{{.*}} : vector<8x16xf16> %35 = math.exp %34 : vector<32x32xf16> - //CHECK-COUNT-4: %{{.*}} = xetile.load_tile %{{.*}} {padding = 0.000000e+00 : f32} : !xetile.tile<32x16xf16> -> vector<32x16xf16> + //CHECK-COUNT-8: %{{.*}} = xetile.load_tile %{{.*}} {padding = 0.000000e+00 : f32} : !xetile.tile<16x16xf16> -> vector<16x16xf16> %36 = xetile.load_tile %arg21 {padding = 0.000000e+00 : f32} : !xetile.tile<64x32xf16> -> vector<64x32xf16> - //CHECK-COUNT-8: %{{.*}} = vector.extract_strided_slice %{{.*}} {offsets = [{{.*}}], sizes = [16, 16], strides = [1, 1]} : vector<32x16xf16> to vector<16x16xf16> //CHECK-COUNT-8: %{{.*}} = xetile.transpose %{{.*}}, [1, 0] : vector<16x16xf16> -> vector<16x16xf16> %37 = xetile.transpose %36, [1, 0] : vector<64x32xf16> -> vector<32x64xf16> //CHECK-COUNT-8: %{{.*}} = math.exp %{{.*}} : vector<16x16xf16> @@ -1603,12 +1630,12 @@ gpu.module @test_kernel { //CHECK-LABEL: gpu.func @sglevel_transpose_broadcast_dim_0 //CHECK-SAME(%[[arg0:.*]]: memref<384x1xf32>, %[[arg1:.*]]: memref<256x384xf32>) gpu.func @sglevel_transpose_broadcast_dim_0(%arg0: memref<384x1xf32>, %arg1: memref<256x384xf32>) { - //CHECK: %[[r0:.*]] = xetile.init_tile %[[arg0]][0, 0] : memref<384x1xf32> -> !xetile.tile<32x1xf32> + //CHECK: %[[r0:.*]] = xetile.init_tile %[[arg0]][{{.*}}] : memref<384x1xf32> -> !xetile.tile<16x1xf32> + //CHECK: %[[r1:.*]] = xetile.init_tile %[[arg0]][{{.*}}] : memref<384x1xf32> -> !xetile.tile<16x1xf32> %1 = xetile.init_tile %arg0[0, 0] : memref<384x1xf32> -> !xetile.tile<32x1xf32> - //CHECK: %[[r1:.*]] = xetile.load_tile %[[r0]] {padding = 0.000000e+00 : f32} : !xetile.tile<32x1xf32> -> vector<32x1xf32> + //CHECK: %[[r2:.*]] = xetile.load_tile %[[r0]] {padding = 0.000000e+00 : f32} : !xetile.tile<16x1xf32> -> vector<16x1xf32> + //CHECK: %[[r3:.*]] = xetile.load_tile %[[r1]] {padding = 0.000000e+00 : f32} : !xetile.tile<16x1xf32> -> vector<16x1xf32> %2 = xetile.load_tile %1 {padding = 0.000000e+00 : f32} : !xetile.tile<32x1xf32> -> vector<32x1xf32> - //CHECK: %[[r2:.*]] = vector.extract_strided_slice %[[r1]] {offsets = [0, 0], sizes = [16, 1], strides = [1, 1]} : vector<32x1xf32> to vector<16x1xf32> - //CHECK: %[[r3:.*]] = vector.extract_strided_slice %[[r1]] {offsets = [16, 0], sizes = [16, 1], strides = [1, 1]} : vector<32x1xf32> to vector<16x1xf32> //CHECK: %[[r4:.*]] = xetile.transpose %[[r2]], [1, 0] : vector<16x1xf32> -> vector<1x16xf32> //CHECK: %[[r5:.*]] = xetile.transpose %[[r3]], [1, 0] : vector<16x1xf32> -> vector<1x16xf32> //CHECK: %[[r6:.*]] = vector.shape_cast %[[r4]] : vector<1x16xf32> to vector<16xf32>