Skip to content

Commit

Permalink
Add support for prefetch in wg to sg pass (#958)
Browse files Browse the repository at this point in the history
* Add support for prefetch in wg to sg pass

* Fix pre-commit
  • Loading branch information
nbpatel authored Nov 12, 2024
1 parent 0330284 commit 09e7136
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 8 deletions.
51 changes: 43 additions & 8 deletions lib/Dialect/XeTile/Transforms/WgToSg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -750,6 +750,27 @@ class WGToSGVectorBroadcast
}
};


class WGToSGPrefetchOpPattern : public XeOneToNConversion<xetile::PrefetchTileOp> {
using XeOneToNConversion<xetile::PrefetchTileOp>::XeOneToNConversion;

mlir::LogicalResult
matchAndRewrite(xetile::PrefetchTileOp op, OpAdaptor adaptor,
XeOneToNPatternRewriter &rewriter) const override {

auto L1 = op.getL1HintAttr();
auto L2 = op.getL2HintAttr();
auto L3 = op.getL3HintAttr();

for(auto tile : adaptor.getTile()) {
rewriter.create<xetile::PrefetchTileOp>(op.getLoc(), tile, L1, L2, L3);
}

rewriter.eraseOp(op);
return mlir::success();
}
};

// Helper function to analyze the def-use chain of initTileOps. Currently we
// pattern match the following def-use chain as a candidate for
// load + tranpose optimization.
Expand Down Expand Up @@ -832,8 +853,10 @@ void populateXeTileWgToSgPatterns(imex::XeOneToNTypeConverter &converter,
patterns.insert<WGToSGInitTileOpPattern, WGToSGLoadTileOpPattern,
WGToSGTileMMAOpPattern, WGToSGStoreTileOpPattern,
WGToSGSCFForOpPattern, WGToSGUpdateTileOffsetOpPattern,
WGToSGSCFYieldOpPattern, WGToSGVectorTranspose, WGToSGVectorBroadcast,
WGToSGXeTileConvertLayout>(patterns.getContext(), converter, analysis);
WGToSGSCFYieldOpPattern, WGToSGVectorTranspose,
WGToSGVectorBroadcast, WGToSGPrefetchOpPattern,
WGToSGXeTileConvertLayout>(patterns.getContext(),
converter, analysis);
patterns.insert<WGToSGElementWiseOpPattern<mlir::math::ExpOp, 1>,
WGToSGElementWiseOpPattern<mlir::arith::AddFOp, 2>,
WGToSGArithConstantOpPattern>(patterns.getContext(),
Expand Down Expand Up @@ -922,14 +945,13 @@ class XeTileWgToSgPass

target.addDynamicallyLegalOp<mlir::scf::YieldOp>(
[&](mlir::scf::YieldOp op) -> bool {
for (auto result : op.getResults()) {
// For cases with scf.if having hidden yield
for (auto result: op.getResults()) {
auto tileTy = mlir::dyn_cast<xetile::TileType>(result.getType());
if (!tileTy)
continue;
else if (!tileTy.getWgMap())
return true;
if (tileTy && tileTy.getWgMap())
return false;
}
return false;
return true;
});

target.addDynamicallyLegalOp<mlir::arith::ConstantOp, mlir::arith::AddFOp,
Expand All @@ -944,6 +966,19 @@ class XeTileWgToSgPass
return false;
});

target.addDynamicallyLegalOp<xetile::PrefetchTileOp>(
[&](xetile::PrefetchTileOp op) -> bool {
if (!op.getTile().getType().getWgMap())
return true;
else
return false;
});

target.addDynamicallyLegalOp<mlir::scf::IfOp>(
[&](mlir::scf::IfOp op) -> bool {
return true;
});

target.addIllegalOp<xetile::ConvertLayoutOp>();

target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
Expand Down
75 changes: 75 additions & 0 deletions test/Dialect/XeTile/Transforms/wg_to_sg_prefetch.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// RUN: imex-opt --split-input-file --xetile-wg-to-sg --cse %s -verify-diagnostics | FileCheck %s

gpu.module @test_prefetch{
gpu.func @preop_addexp_m512_n256_k4096(%arg0: memref<512x4096xf16>, %arg1: memref<256x4096xf16>, %arg2: memref<512x256xf32>, %arg3: memref<512x256xf32>) attributes {gemm_tiles_b = 1 : i64, gemm_tiles_x = dense<[2, 1, 2, 4]> : vector<4xi64>, gemm_tiles_y = dense<[1, 1, 1, 8]> : vector<4xi64>, habana_runner.num_inputs = 3 : i64, habana_runner.tests = [{inputs = [dense<1.000000e+00> : tensor<512x4096xf16>, dense<0.000000e+00> : tensor<256x4096xf16>, dense<1.900000e+01> : tensor<512x256xf32>], outputs = [dense<8.211000e+03> : tensor<512x256xf32>]}], physical_nd_range = dense<2> : vector<2xi64>, region_partition = 0 : i64, region_size = 2 : i64, syn.fusion_successful, syn.tensor_signature = (tensor<512x4096xf16>, tensor<256x4096xf16>, tensor<512x256xf32>) -> tensor<512x256xf32>, synFusionGenOps = 6 : i64, synFusionRequiredBeamSize = 1 : i64, synFusionTotalCost = 1000081812.83 : f64} {
%c2 = arith.constant 2 : index
%c2_0 = arith.constant 2 : index
%c4 = arith.constant 4 : index
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
gpu.launch blocks(%arg4, %arg5, %arg6) in (%arg10 = %c2, %arg11 = %c2_0, %arg12 = %c1) threads(%arg7, %arg8, %arg9) in (%arg13 = %c4, %arg14 = %c8, %arg15 = %c1) {
%c128 = arith.constant 128 : index
%c256 = arith.constant 256 : index
%c320 = arith.constant 320 : index
%c4096 = arith.constant 4096 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%cst = arith.constant {map = #xetile.wg_map<sg_layout = [4, 8], sg_data = [32, 32]>} dense<0.000000e+00> : vector<128x256xf32>
%block_id_x = gpu.block_id x
%block_id_y = gpu.block_id y
%0 = arith.muli %block_id_x, %c256 : index
%1 = arith.muli %block_id_y, %c128 : index
%2 = arith.addi %0, %1 : index
%3 = xetile.init_tile %arg0[%2, %c0] : memref<512x4096xf16> -> !xetile.tile<128x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [32, 1], sg_data = [4, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>
%4 = xetile.init_tile %arg1[%c0, %c0] : memref<256x4096xf16> -> !xetile.tile<256x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [32, 1], sg_data = [8, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>
%5:2 = scf.for %arg16 = %c0 to %c320 step %c32 iter_args(%arg17 = %3, %arg18 = %4) -> (!xetile.tile<128x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [32, 1], sg_data = [4, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>, !xetile.tile<256x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [32, 1], sg_data = [8, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>) {
%18 = xetile.update_tile_offset %arg18, [%c0, %c32] : !xetile.tile<256x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [32, 1], sg_data = [8, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>
%19 = xetile.update_tile_offset %arg17, [%c0, %c32] : !xetile.tile<128x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [32, 1], sg_data = [4, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>
//CHECK: xetile.prefetch_tile {{%.*}} {l1_hint = #xetile.cache_hint<uncached>, l2_hint = #xetile.cache_hint<cached>} : !xetile.tile<4x32xf16>
//CHECK: xetile.prefetch_tile {{%.*}} {l1_hint = #xetile.cache_hint<uncached>, l2_hint = #xetile.cache_hint<cached>} : !xetile.tile<8x32xf16>
xetile.prefetch_tile %arg17 {l1_hint = #xetile.cache_hint<uncached>, l2_hint = #xetile.cache_hint<cached>} : !xetile.tile<128x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [32, 1], sg_data = [4, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>
xetile.prefetch_tile %arg18 {l1_hint = #xetile.cache_hint<uncached>, l2_hint = #xetile.cache_hint<cached>} : !xetile.tile<256x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [32, 1], sg_data = [8, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>
scf.yield %19, %18 : !xetile.tile<128x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [32, 1], sg_data = [4, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>, !xetile.tile<256x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [32, 1], sg_data = [8, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>
}
%6 = xetile.init_tile %arg0[%2, %c320] : memref<512x4096xf16> -> !xetile.tile<128x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [32, 1], sg_data = [4, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>
%7 = xetile.init_tile %arg1[%c0, %c320] : memref<256x4096xf16> -> !xetile.tile<256x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [32, 1], sg_data = [8, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>
%8 = xetile.init_tile %arg0[%2, %c0] : memref<512x4096xf16> -> !xetile.tile<128x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [4, 8], sg_data = [32, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>
%9 = xetile.init_tile %arg1[%c0, %c0] : memref<256x4096xf16> -> !xetile.tile<256x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [8, 4], sg_data = [32, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>
%10:5 = scf.for %arg16 = %c0 to %c4096 step %c32 iter_args(%arg17 = %cst, %arg18 = %6, %arg19 = %7, %arg20 = %8, %arg21 = %9) -> (vector<128x256xf32>, !xetile.tile<128x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [32, 1], sg_data = [4, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>, !xetile.tile<256x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [32, 1], sg_data = [8, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>, !xetile.tile<128x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [4, 8], sg_data = [32, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>, !xetile.tile<256x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [8, 4], sg_data = [32, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>) {
%18 = xetile.update_tile_offset %arg21, [%c0, %c32] : !xetile.tile<256x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [8, 4], sg_data = [32, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>
%19 = xetile.update_tile_offset %arg20, [%c0, %c32] : !xetile.tile<128x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [4, 8], sg_data = [32, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>
%20 = xetile.update_tile_offset %arg19, [%c0, %c32] : !xetile.tile<256x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [32, 1], sg_data = [8, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>
%21 = xetile.update_tile_offset %arg18, [%c0, %c32] : !xetile.tile<128x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [32, 1], sg_data = [4, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>
%22 = arith.addi %arg16, %c320 : index
%23 = arith.cmpi sge, %22, %c4096 : index
scf.if %23 {
} else {
//CHECK: xetile.prefetch_tile {{%.*}} {l1_hint = #xetile.cache_hint<uncached>, l2_hint = #xetile.cache_hint<cached>} : !xetile.tile<4x32xf16>
//CHECK: xetile.prefetch_tile {{%.*}} {l1_hint = #xetile.cache_hint<uncached>, l2_hint = #xetile.cache_hint<cached>} : !xetile.tile<8x32xf16>
xetile.prefetch_tile %arg18 {l1_hint = #xetile.cache_hint<uncached>, l2_hint = #xetile.cache_hint<cached>} : !xetile.tile<128x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [32, 1], sg_data = [4, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>
xetile.prefetch_tile %arg19 {l1_hint = #xetile.cache_hint<uncached>, l2_hint = #xetile.cache_hint<cached>} : !xetile.tile<256x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [32, 1], sg_data = [8, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>
}
%24 = xetile.load_tile %arg20 : !xetile.tile<128x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [4, 8], sg_data = [32, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>> -> vector<128x32xf16>
%25 = arith.addf %24, %24 {map = #xetile.wg_map<sg_layout = [4, 8], sg_data = [32, 32]>} : vector<128x32xf16>
%26 = xetile.load_tile %arg21 : !xetile.tile<256x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [8, 4], sg_data = [32, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>> -> vector<256x32xf16>
%27 = vector.transpose %26, [1, 0] {map = #xetile.wg_map<sg_layout = [4, 8], sg_data = [32, 32]>} : vector<256x32xf16> to vector<32x256xf16>
%28 = math.exp %27 {map = #xetile.wg_map<sg_layout = [4, 8], sg_data = [32, 32]>} : vector<32x256xf16>
xegpu.compile_hint
%29 = xetile.tile_mma %25, %28, %cst {wg_map_a = #xetile.wg_map<sg_layout = [4, 8], sg_data = [32, 32]>, wg_map_b = #xetile.wg_map<sg_layout = [4, 8], sg_data = [32, 32]>, wg_map_c = #xetile.wg_map<sg_layout = [4, 8], sg_data = [32, 32]>} : vector<128x32xf16>, vector<32x256xf16>, vector<128x256xf32> -> vector<128x256xf32>
xegpu.compile_hint
%30 = arith.addf %arg17, %29 {map = #xetile.wg_map<sg_layout = [4, 8], sg_data = [32, 32]>} : vector<128x256xf32>
scf.yield %30, %21, %20, %19, %18 : vector<128x256xf32>, !xetile.tile<128x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [32, 1], sg_data = [4, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>, !xetile.tile<256x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [32, 1], sg_data = [8, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>, !xetile.tile<128x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [4, 8], sg_data = [32, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>, !xetile.tile<256x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [8, 4], sg_data = [32, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>
}
%11 = arith.muli %block_id_x, %c256 : index
%12 = arith.muli %block_id_y, %c128 : index
%13 = arith.addi %11, %12 : index
%14 = xetile.init_tile %arg2[%13, %c0] : memref<512x256xf32> -> !xetile.tile<128x256xf32, #xetile.tile_attr<wg_map = <sg_layout = [4, 8], sg_data = [32, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>
%15 = xetile.load_tile %14 : !xetile.tile<128x256xf32, #xetile.tile_attr<wg_map = <sg_layout = [4, 8], sg_data = [32, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>> -> vector<128x256xf32>
%16 = arith.addf %10#0, %15 {map = #xetile.wg_map<sg_layout = [4, 8], sg_data = [32, 32]>} : vector<128x256xf32>
%17 = xetile.init_tile %arg3[%13, %c0] : memref<512x256xf32> -> !xetile.tile<128x256xf32, #xetile.tile_attr<wg_map = <sg_layout = [4, 8], sg_data = [32, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>
xetile.store_tile %16, %17 : vector<128x256xf32>, !xetile.tile<128x256xf32, #xetile.tile_attr<wg_map = <sg_layout = [4, 8], sg_data = [32, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>
gpu.terminator
}
gpu.return
}
}

0 comments on commit 09e7136

Please sign in to comment.