From d8a52ff41794253ce4de400d05b5e38867636a42 Mon Sep 17 00:00:00 2001 From: Nishant Patel Date: Thu, 14 Nov 2024 13:11:04 -0800 Subject: [PATCH] Support init_tile with nD memrefs in wg to sg pass (#961) * Support init_tile with nD memrefs in wg to sg pass * Fix pre-commit --- lib/Dialect/XeTile/Transforms/WgToSg.cpp | 19 ++++-- .../broadcast.mlir} | 0 .../btranspose.mlir} | 0 .../convert_layout.mlir} | 0 .../gemm_1k.mlir} | 0 .../gemm_4k.mlir} | 0 .../XeTile/Transforms/WgToSg/gemm_batch.mlir | 63 +++++++++++++++++ .../Transforms/WgToSg/gemm_batch_oob.mlir | 67 +++++++++++++++++++ .../gemm_postop.mlir} | 0 .../prefetch.mlir} | 0 .../round_robin.mlir} | 0 .../Transforms/{ => WgToSg}/unit_tests.mlir | 0 12 files changed, 144 insertions(+), 5 deletions(-) rename test/Dialect/XeTile/Transforms/{wg_to_sg_broadcast.mlir => WgToSg/broadcast.mlir} (100%) rename test/Dialect/XeTile/Transforms/{wg_to_sg_btranspose.mlir => WgToSg/btranspose.mlir} (100%) rename test/Dialect/XeTile/Transforms/{wg_to_sg_convert_layout.mlir => WgToSg/convert_layout.mlir} (100%) rename test/Dialect/XeTile/Transforms/{wg_to_sg_1k_gemm.mlir => WgToSg/gemm_1k.mlir} (100%) rename test/Dialect/XeTile/Transforms/{wg_to_sg_4k_gemm.mlir => WgToSg/gemm_4k.mlir} (100%) create mode 100644 test/Dialect/XeTile/Transforms/WgToSg/gemm_batch.mlir create mode 100644 test/Dialect/XeTile/Transforms/WgToSg/gemm_batch_oob.mlir rename test/Dialect/XeTile/Transforms/{wg_to_sg_gemm_postop.mlir => WgToSg/gemm_postop.mlir} (100%) rename test/Dialect/XeTile/Transforms/{wg_to_sg_prefetch.mlir => WgToSg/prefetch.mlir} (100%) rename test/Dialect/XeTile/Transforms/{wg_to_sg_round_robin.mlir => WgToSg/round_robin.mlir} (100%) rename test/Dialect/XeTile/Transforms/{ => WgToSg}/unit_tests.mlir (100%) diff --git a/lib/Dialect/XeTile/Transforms/WgToSg.cpp b/lib/Dialect/XeTile/Transforms/WgToSg.cpp index 7fed049e5..1b1656783 100644 --- a/lib/Dialect/XeTile/Transforms/WgToSg.cpp +++ b/lib/Dialect/XeTile/Transforms/WgToSg.cpp @@ -178,15 +178,15 @@ class WGToSGInitTileOpPattern : public XeOneToNConversion { if (it != opSgLayoutMap.end()){ assert((opSgLayoutMap[op->getResult(0)] == std::array{0, 1})); calculateGlobalOffsets(globalOffsetsY, wgTileShape[0], sgTileShape[0], - sgLayout[0], sgDataDimYConst, sgIdX, offsets[0]); + sgLayout[0], sgDataDimYConst, sgIdX, offsets[offsets.size() - 2]); calculateGlobalOffsets(globalOffsetsX, wgTileShape[1], sgTileShape[1], - sgLayout[1], sgDataDimXConst, sgIdY, offsets[1]); + sgLayout[1], sgDataDimXConst, sgIdY, offsets[offsets.size() - 1]); } else { calculateGlobalOffsets(globalOffsetsY, wgTileShape[0], sgTileShape[0], - sgLayout[0], sgDataDimYConst, sgIdY, offsets[0]); + sgLayout[0], sgDataDimYConst, sgIdY, offsets[offsets.size() - 2]); calculateGlobalOffsets(globalOffsetsX, wgTileShape[1], sgTileShape[1], - sgLayout[1], sgDataDimXConst, sgIdX, offsets[1]); + sgLayout[1], sgDataDimXConst, sgIdX, offsets[offsets.size() - 1]); } // TODO: check for how to broadcast for (auto y : globalOffsetsY) { @@ -197,9 +197,16 @@ class WGToSGInitTileOpPattern : public XeOneToNConversion { mlir::SmallVector newInitTileOps; llvm::SmallVector newResultTypes; + llvm::SmallVector newOffsets; + for (size_t j = 0; j < offsets.size() - 2; ++j) { + newOffsets.push_back(offsets[j]); + } for (size_t i = 0; i < offsetPermutations.size(); i++) { + newOffsets.push_back(offsetPermutations[i][0]); + newOffsets.push_back(offsetPermutations[i][1]); auto newInitTileOp = rewriter.create( - loc, newTileTy, source, offsetPermutations[i]); + loc, newTileTy, source, newOffsets); + newOffsets.clear(); newInitTileOps.push_back(newInitTileOp); newResultTypes.push_back(newTileTy); } @@ -988,6 +995,8 @@ class XeTileWgToSgPass target.addDynamicallyLegalOp( [&](mlir::scf::ForOp op) -> bool { + if(op.getInitArgs().empty()) + return true; for (auto arg : op.getInitArgs()) { auto tileTy = mlir::dyn_cast(arg.getType()); if (!tileTy) diff --git a/test/Dialect/XeTile/Transforms/wg_to_sg_broadcast.mlir b/test/Dialect/XeTile/Transforms/WgToSg/broadcast.mlir similarity index 100% rename from test/Dialect/XeTile/Transforms/wg_to_sg_broadcast.mlir rename to test/Dialect/XeTile/Transforms/WgToSg/broadcast.mlir diff --git a/test/Dialect/XeTile/Transforms/wg_to_sg_btranspose.mlir b/test/Dialect/XeTile/Transforms/WgToSg/btranspose.mlir similarity index 100% rename from test/Dialect/XeTile/Transforms/wg_to_sg_btranspose.mlir rename to test/Dialect/XeTile/Transforms/WgToSg/btranspose.mlir diff --git a/test/Dialect/XeTile/Transforms/wg_to_sg_convert_layout.mlir b/test/Dialect/XeTile/Transforms/WgToSg/convert_layout.mlir similarity index 100% rename from test/Dialect/XeTile/Transforms/wg_to_sg_convert_layout.mlir rename to test/Dialect/XeTile/Transforms/WgToSg/convert_layout.mlir diff --git a/test/Dialect/XeTile/Transforms/wg_to_sg_1k_gemm.mlir b/test/Dialect/XeTile/Transforms/WgToSg/gemm_1k.mlir similarity index 100% rename from test/Dialect/XeTile/Transforms/wg_to_sg_1k_gemm.mlir rename to test/Dialect/XeTile/Transforms/WgToSg/gemm_1k.mlir diff --git a/test/Dialect/XeTile/Transforms/wg_to_sg_4k_gemm.mlir b/test/Dialect/XeTile/Transforms/WgToSg/gemm_4k.mlir similarity index 100% rename from test/Dialect/XeTile/Transforms/wg_to_sg_4k_gemm.mlir rename to test/Dialect/XeTile/Transforms/WgToSg/gemm_4k.mlir diff --git a/test/Dialect/XeTile/Transforms/WgToSg/gemm_batch.mlir b/test/Dialect/XeTile/Transforms/WgToSg/gemm_batch.mlir new file mode 100644 index 000000000..6bb1df0ec --- /dev/null +++ b/test/Dialect/XeTile/Transforms/WgToSg/gemm_batch.mlir @@ -0,0 +1,63 @@ +// RUN: imex-opt --split-input-file --xetile-wg-to-sg --cse %s -verify-diagnostics | FileCheck %s + +module attributes {gpu.container_module} { + func.func @tiles_b_2_entry(%arg0: memref<4x3x2x128x96xf16>, %arg1: memref<4x3x2x64x96xf16>, %arg2: memref<4x3x2x128x64xf32>) attributes {gemm_tiles_b = 2 : i64, gemm_tiles_x = dense<[2, 1, 1, 2]> : vector<4xi64>, gemm_tiles_y = dense<[2, 1, 1, 1]> : vector<4xi64>, habana_runner.num_inputs = 2 : i64, habana_runner.tests = [{inputs = [dense<1.000000e+00> : tensor<4x3x2x128x96xf16>, dense<1.000000e+00> : tensor<4x3x2x64x96xf16>], outputs = [dense<9.600000e+01> : tensor<4x3x2x128x64xf32>]}], physical_nd_range = dense<[8, 1]> : vector<2xi64>, region_partition = 0 : i64, region_size = 1 : i64, syn.fusion_successful, syn.tensor_signature = (tensor<4x3x2x128x96xf16>, tensor<4x3x2x64x96xf16>) -> tensor<4x3x2x128x64xf32>, synFusionGenOps = 9 : i64, synFusionRequiredBeamSize = 1 : i64, synFusionTotalCost = 1000021695.96 : f64} { + %c8 = arith.constant 8 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + gpu.launch_func @tiles_b_2::@tiles_b_2 blocks in (%c8, %c1, %c1) threads in (%c2, %c1, %c1) args(%arg0 : memref<4x3x2x128x96xf16>, %arg1 : memref<4x3x2x64x96xf16>, %arg2 : memref<4x3x2x128x64xf32>) + return + } + gpu.module @tiles_b_2 attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.func @tiles_b_2(%arg0: memref<4x3x2x128x96xf16>, %arg1: memref<4x3x2x64x96xf16>, %arg2: memref<4x3x2x128x64xf32>) kernel attributes {VectorComputeFunctionINTEL, known_block_size = array, known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %c96 = arith.constant 96 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c32 = arith.constant 32 : index + %c0 = arith.constant 0 : index + %cst = arith.constant {map = #xetile.wg_map} dense<0.000000e+00> : vector<64x32xf32> + %c3 = arith.constant 3 : index + %c6 = arith.constant 6 : index + %c12 = arith.constant 12 : index + %c8 = arith.constant 8 : index + %c4 = arith.constant 4 : index + %c2 = arith.constant 2 : index + %block_id_x = gpu.block_id x + %0 = arith.remsi %block_id_x, %c4 : index + %1 = arith.divsi %0, %c2 : index + %2 = arith.remsi %0, %c2 : index + %3 = arith.muli %block_id_x, %c2 : index + %4 = arith.divsi %3, %c8 : index + %5 = arith.muli %4, %c12 : index + %6 = arith.muli %1, %c64 : index + %7 = arith.muli %2, %c32 : index + scf.for %arg3 = %c0 to %c12 step %c1 { + %8 = arith.addi %5, %arg3 : index + %9 = arith.divsi %8, %c6 : index + %10 = arith.remsi %9, %c4 : index + %11 = arith.divsi %8, %c2 : index + %12 = arith.remsi %11, %c3 : index + %13 = arith.remsi %8, %c2 : index + //CHECK: %[[INITTILE:.*]] = xetile.init_tile {{%.*}}[{{%.*}}, {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}}] : memref<4x3x2x128x96xf16> -> !xetile.tile<32x32xf16> + //CHECK: %[[INITTILE:.*]] = xetile.init_tile {{%.*}}[{{%.*}}, {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}}] : memref<4x3x2x64x96xf16> -> !xetile.tile<32x32xf16> + %14 = xetile.init_tile %arg0[%10, %12, %13, %6, %c0] : memref<4x3x2x128x96xf16> -> !xetile.tile<64x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> + %15 = xetile.init_tile %arg1[%10, %12, %13, %7, %c0] : memref<4x3x2x64x96xf16> -> !xetile.tile<32x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> + %16:3 = scf.for %arg4 = %c0 to %c96 step %c32 iter_args(%arg5 = %cst, %arg6 = %14, %arg7 = %15) -> (vector<64x32xf32>, !xetile.tile<64x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>>, !xetile.tile<32x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>>) { + %18 = xetile.update_tile_offset %arg7, [%c0, %c32] : !xetile.tile<32x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> + %19 = xetile.update_tile_offset %arg6, [%c0, %c32] : !xetile.tile<64x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> + %20 = xetile.load_tile %arg6 : !xetile.tile<64x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> -> vector<64x32xf16> + %21 = xetile.load_tile %arg7 : !xetile.tile<32x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> -> vector<32x32xf16> + %22 = vector.transpose %21, [1, 0] {map = #xetile.wg_map} : vector<32x32xf16> to vector<32x32xf16> + xegpu.compile_hint + %23 = xetile.tile_mma %20, %22, %cst {wg_map_a = #xetile.wg_map, wg_map_b = #xetile.wg_map, wg_map_c = #xetile.wg_map} : vector<64x32xf16>, vector<32x32xf16>, vector<64x32xf32> -> vector<64x32xf32> + xegpu.compile_hint + %24 = arith.addf %arg5, %23 {map = #xetile.wg_map} : vector<64x32xf32> + scf.yield %24, %19, %18 : vector<64x32xf32>, !xetile.tile<64x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>>, !xetile.tile<32x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> + } + %17 = xetile.init_tile %arg2[%10, %12, %13, %6, %7] : memref<4x3x2x128x64xf32> -> !xetile.tile<64x32xf32, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> + xetile.store_tile %16#0, %17 : vector<64x32xf32>, !xetile.tile<64x32xf32, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> + } + gpu.return + } + } +} diff --git a/test/Dialect/XeTile/Transforms/WgToSg/gemm_batch_oob.mlir b/test/Dialect/XeTile/Transforms/WgToSg/gemm_batch_oob.mlir new file mode 100644 index 000000000..1d15cea31 --- /dev/null +++ b/test/Dialect/XeTile/Transforms/WgToSg/gemm_batch_oob.mlir @@ -0,0 +1,67 @@ +// RUN: imex-opt --split-input-file --xetile-wg-to-sg --cse %s -verify-diagnostics | FileCheck %s + +module attributes {gpu.container_module} { + func.func @tiles_b_4_oob_entry(%arg0: memref<2x3x3x128x96xf16>, %arg1: memref<2x3x3x64x96xf16>, %arg2: memref<2x3x3x128x64xf32>) attributes {gemm_tiles_b = 4 : i64, gemm_tiles_x = dense<[2, 1, 1, 2]> : vector<4xi64>, gemm_tiles_y = dense<[1, 1, 2, 1]> : vector<4xi64>, habana_runner.num_inputs = 2 : i64, habana_runner.tests = [{inputs = [dense<1.000000e+00> : tensor<2x3x3x128x96xf16>, dense<1.000000e+00> : tensor<2x3x3x64x96xf16>], outputs = [dense<9.600000e+01> : tensor<2x3x3x128x64xf32>]}], physical_nd_range = dense<[8, 2]> : vector<2xi64>, region_partition = 0 : i64, region_size = 2 : i64, syn.fusion_successful, syn.tensor_signature = (tensor<2x3x3x128x96xf16>, tensor<2x3x3x64x96xf16>) -> tensor<2x3x3x128x64xf32>, synFusionGenOps = 9 : i64, synFusionRequiredBeamSize = 1 : i64, synFusionTotalCost = 1000016310.36 : f64} { + %c8 = arith.constant 8 : index + %c2 = arith.constant 2 : index + %c1 = arith.constant 1 : index + gpu.launch_func @tiles_b_4_oob::@tiles_b_4_oob blocks in (%c8, %c2, %c1) threads in (%c2, %c1, %c1) args(%arg0 : memref<2x3x3x128x96xf16>, %arg1 : memref<2x3x3x64x96xf16>, %arg2 : memref<2x3x3x128x64xf32>) + return + } + gpu.module @tiles_b_4_oob attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.func @tiles_b_4_oob(%arg0: memref<2x3x3x128x96xf16>, %arg1: memref<2x3x3x64x96xf16>, %arg2: memref<2x3x3x128x64xf32>) kernel attributes {VectorComputeFunctionINTEL, known_block_size = array, known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %c96 = arith.constant 96 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c32 = arith.constant 32 : index + %c0 = arith.constant 0 : index + %cst = arith.constant {map = #xetile.wg_map} dense<0.000000e+00> : vector<64x32xf32> + %c3 = arith.constant 3 : index + %c9 = arith.constant 9 : index + %c18 = arith.constant 18 : index + %c5 = arith.constant 5 : index + %c8 = arith.constant 8 : index + %c4 = arith.constant 4 : index + %c2 = arith.constant 2 : index + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %0 = arith.remsi %block_id_x, %c2 : index + %1 = arith.remsi %block_id_y, %c2 : index + %2 = arith.muli %block_id_x, %c4 : index + %3 = arith.divsi %2, %c8 : index + %4 = arith.muli %3, %c5 : index + %5 = arith.subi %c18, %4 : index + %6 = arith.cmpi sgt, %5, %c5 : index + %7 = arith.select %6, %c5, %5 : index + %8 = arith.muli %0, %c64 : index + %9 = arith.muli %1, %c32 : index + scf.for %arg3 = %c0 to %7 step %c1 { + %10 = arith.addi %4, %arg3 : index + %11 = arith.divsi %10, %c9 : index + %12 = arith.remsi %11, %c2 : index + %13 = arith.divsi %10, %c3 : index + %14 = arith.remsi %13, %c3 : index + %15 = arith.remsi %10, %c3 : index + //CHECK: %[[INITTILE:.*]] = xetile.init_tile {{%.*}}[{{%.*}}, {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}}] : memref<2x3x3x128x96xf16> -> !xetile.tile<32x32xf16> + //CHECK: %[[INITTILE:.*]] = xetile.init_tile {{%.*}}[{{%.*}}, {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}}] : memref<2x3x3x64x96xf16> -> !xetile.tile<32x32xf16> + %16 = xetile.init_tile %arg0[%12, %14, %15, %8, %c0] : memref<2x3x3x128x96xf16> -> !xetile.tile<64x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> + %17 = xetile.init_tile %arg1[%12, %14, %15, %9, %c0] : memref<2x3x3x64x96xf16> -> !xetile.tile<32x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> + %18:3 = scf.for %arg4 = %c0 to %c96 step %c32 iter_args(%arg5 = %cst, %arg6 = %16, %arg7 = %17) -> (vector<64x32xf32>, !xetile.tile<64x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>>, !xetile.tile<32x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>>) { + %20 = xetile.update_tile_offset %arg7, [%c0, %c32] : !xetile.tile<32x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> + %21 = xetile.update_tile_offset %arg6, [%c0, %c32] : !xetile.tile<64x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> + %22 = xetile.load_tile %arg6 : !xetile.tile<64x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> -> vector<64x32xf16> + %23 = xetile.load_tile %arg7 : !xetile.tile<32x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> -> vector<32x32xf16> + %24 = vector.transpose %23, [1, 0] {map = #xetile.wg_map} : vector<32x32xf16> to vector<32x32xf16> + xegpu.compile_hint + %25 = xetile.tile_mma %22, %24, %cst {wg_map_a = #xetile.wg_map, wg_map_b = #xetile.wg_map, wg_map_c = #xetile.wg_map} : vector<64x32xf16>, vector<32x32xf16>, vector<64x32xf32> -> vector<64x32xf32> + xegpu.compile_hint + %26 = arith.addf %arg5, %25 {map = #xetile.wg_map} : vector<64x32xf32> + scf.yield %26, %21, %20 : vector<64x32xf32>, !xetile.tile<64x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>>, !xetile.tile<32x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> + } + %19 = xetile.init_tile %arg2[%12, %14, %15, %8, %9] : memref<2x3x3x128x64xf32> -> !xetile.tile<64x32xf32, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> + xetile.store_tile %18#0, %19 : vector<64x32xf32>, !xetile.tile<64x32xf32, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> + } + gpu.return + } + } +} diff --git a/test/Dialect/XeTile/Transforms/wg_to_sg_gemm_postop.mlir b/test/Dialect/XeTile/Transforms/WgToSg/gemm_postop.mlir similarity index 100% rename from test/Dialect/XeTile/Transforms/wg_to_sg_gemm_postop.mlir rename to test/Dialect/XeTile/Transforms/WgToSg/gemm_postop.mlir diff --git a/test/Dialect/XeTile/Transforms/wg_to_sg_prefetch.mlir b/test/Dialect/XeTile/Transforms/WgToSg/prefetch.mlir similarity index 100% rename from test/Dialect/XeTile/Transforms/wg_to_sg_prefetch.mlir rename to test/Dialect/XeTile/Transforms/WgToSg/prefetch.mlir diff --git a/test/Dialect/XeTile/Transforms/wg_to_sg_round_robin.mlir b/test/Dialect/XeTile/Transforms/WgToSg/round_robin.mlir similarity index 100% rename from test/Dialect/XeTile/Transforms/wg_to_sg_round_robin.mlir rename to test/Dialect/XeTile/Transforms/WgToSg/round_robin.mlir diff --git a/test/Dialect/XeTile/Transforms/unit_tests.mlir b/test/Dialect/XeTile/Transforms/WgToSg/unit_tests.mlir similarity index 100% rename from test/Dialect/XeTile/Transforms/unit_tests.mlir rename to test/Dialect/XeTile/Transforms/WgToSg/unit_tests.mlir