Skip to content

Commit 09e7136

Browse files
authored
Add support for prefetch in wg to sg pass (#958)
* Add support for prefetch in wg to sg pass * Fix pre-commit
1 parent 0330284 commit 09e7136

File tree

2 files changed

+118
-8
lines changed

2 files changed

+118
-8
lines changed

lib/Dialect/XeTile/Transforms/WgToSg.cpp

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -750,6 +750,27 @@ class WGToSGVectorBroadcast
750750
}
751751
};
752752

753+
754+
/// Pattern used by the wg-to-sg pass for xetile.prefetch_tile.
///
/// For every value the adaptor supplies for the op's tile operand, a new
/// xetile.prefetch_tile is created at the original op's location with the
/// original L1/L2/L3 cache-hint attributes; the original op is then erased.
class WGToSGPrefetchOpPattern
    : public XeOneToNConversion<xetile::PrefetchTileOp> {
  using XeOneToNConversion<xetile::PrefetchTileOp>::XeOneToNConversion;

  mlir::LogicalResult
  matchAndRewrite(xetile::PrefetchTileOp op, OpAdaptor adaptor,
                  XeOneToNPatternRewriter &rewriter) const override {
    // Read the cache hints once; they are forwarded unchanged to each new op.
    auto l1Hint = op.getL1HintAttr();
    auto l2Hint = op.getL2HintAttr();
    auto l3Hint = op.getL3HintAttr();
    auto loc = op.getLoc();

    // One prefetch per converted tile value.
    auto convertedTiles = adaptor.getTile();
    for (auto convertedTile : convertedTiles)
      rewriter.create<xetile::PrefetchTileOp>(loc, convertedTile, l1Hint,
                                              l2Hint, l3Hint);

    // The prefetch produces no results, so the original can simply be erased.
    rewriter.eraseOp(op);
    return mlir::success();
  }
};
773+
753774
// Helper function to analyze the def-use chain of initTileOps. Currently we
754775
// pattern match the following def-use chain as a candidate for
755776
// load + transpose optimization.
@@ -832,8 +853,10 @@ void populateXeTileWgToSgPatterns(imex::XeOneToNTypeConverter &converter,
832853
patterns.insert<WGToSGInitTileOpPattern, WGToSGLoadTileOpPattern,
833854
WGToSGTileMMAOpPattern, WGToSGStoreTileOpPattern,
834855
WGToSGSCFForOpPattern, WGToSGUpdateTileOffsetOpPattern,
835-
WGToSGSCFYieldOpPattern, WGToSGVectorTranspose, WGToSGVectorBroadcast,
836-
WGToSGXeTileConvertLayout>(patterns.getContext(), converter, analysis);
856+
WGToSGSCFYieldOpPattern, WGToSGVectorTranspose,
857+
WGToSGVectorBroadcast, WGToSGPrefetchOpPattern,
858+
WGToSGXeTileConvertLayout>(patterns.getContext(),
859+
converter, analysis);
837860
patterns.insert<WGToSGElementWiseOpPattern<mlir::math::ExpOp, 1>,
838861
WGToSGElementWiseOpPattern<mlir::arith::AddFOp, 2>,
839862
WGToSGArithConstantOpPattern>(patterns.getContext(),
@@ -922,14 +945,13 @@ class XeTileWgToSgPass
922945

923946
target.addDynamicallyLegalOp<mlir::scf::YieldOp>(
924947
[&](mlir::scf::YieldOp op) -> bool {
925-
for (auto result : op.getResults()) {
948+
// Legal unless a yielded tile still carries a wg_map; this also covers the hidden (operand-less) yield of scf.if.
949+
for (auto result: op.getResults()) {
926950
auto tileTy = mlir::dyn_cast<xetile::TileType>(result.getType());
927-
if (!tileTy)
928-
continue;
929-
else if (!tileTy.getWgMap())
930-
return true;
951+
if (tileTy && tileTy.getWgMap())
952+
return false;
931953
}
932-
return false;
954+
return true;
933955
});
934956

935957
target.addDynamicallyLegalOp<mlir::arith::ConstantOp, mlir::arith::AddFOp,
@@ -944,6 +966,19 @@ class XeTileWgToSgPass
944966
return false;
945967
});
946968

969+
target.addDynamicallyLegalOp<xetile::PrefetchTileOp>(
970+
[&](xetile::PrefetchTileOp op) -> bool {
971+
if (!op.getTile().getType().getWgMap())
972+
return true;
973+
else
974+
return false;
975+
});
976+
977+
target.addDynamicallyLegalOp<mlir::scf::IfOp>(
978+
[&](mlir::scf::IfOp op) -> bool {
979+
return true;
980+
});
981+
947982
target.addIllegalOp<xetile::ConvertLayoutOp>();
948983

949984
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
// RUN: imex-opt --split-input-file --xetile-wg-to-sg --cse %s -verify-diagnostics | FileCheck %s
2+
3+
gpu.module @test_prefetch{
4+
gpu.func @preop_addexp_m512_n256_k4096(%arg0: memref<512x4096xf16>, %arg1: memref<256x4096xf16>, %arg2: memref<512x256xf32>, %arg3: memref<512x256xf32>) attributes {gemm_tiles_b = 1 : i64, gemm_tiles_x = dense<[2, 1, 2, 4]> : vector<4xi64>, gemm_tiles_y = dense<[1, 1, 1, 8]> : vector<4xi64>, habana_runner.num_inputs = 3 : i64, habana_runner.tests = [{inputs = [dense<1.000000e+00> : tensor<512x4096xf16>, dense<0.000000e+00> : tensor<256x4096xf16>, dense<1.900000e+01> : tensor<512x256xf32>], outputs = [dense<8.211000e+03> : tensor<512x256xf32>]}], physical_nd_range = dense<2> : vector<2xi64>, region_partition = 0 : i64, region_size = 2 : i64, syn.fusion_successful, syn.tensor_signature = (tensor<512x4096xf16>, tensor<256x4096xf16>, tensor<512x256xf32>) -> tensor<512x256xf32>, synFusionGenOps = 6 : i64, synFusionRequiredBeamSize = 1 : i64, synFusionTotalCost = 1000081812.83 : f64} {
5+
%c2 = arith.constant 2 : index
6+
%c2_0 = arith.constant 2 : index
7+
%c4 = arith.constant 4 : index
8+
%c8 = arith.constant 8 : index
9+
%c1 = arith.constant 1 : index
10+
gpu.launch blocks(%arg4, %arg5, %arg6) in (%arg10 = %c2, %arg11 = %c2_0, %arg12 = %c1) threads(%arg7, %arg8, %arg9) in (%arg13 = %c4, %arg14 = %c8, %arg15 = %c1) {
11+
%c128 = arith.constant 128 : index
12+
%c256 = arith.constant 256 : index
13+
%c320 = arith.constant 320 : index
14+
%c4096 = arith.constant 4096 : index
15+
%c32 = arith.constant 32 : index
16+
%c0 = arith.constant 0 : index
17+
%cst = arith.constant {map = #xetile.wg_map<sg_layout = [4, 8], sg_data = [32, 32]>} dense<0.000000e+00> : vector<128x256xf32>
18+
%block_id_x = gpu.block_id x
19+
%block_id_y = gpu.block_id y
20+
%0 = arith.muli %block_id_x, %c256 : index
21+
%1 = arith.muli %block_id_y, %c128 : index
22+
%2 = arith.addi %0, %1 : index
23+
%3 = xetile.init_tile %arg0[%2, %c0] : memref<512x4096xf16> -> !xetile.tile<128x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [32, 1], sg_data = [4, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>
24+
%4 = xetile.init_tile %arg1[%c0, %c0] : memref<256x4096xf16> -> !xetile.tile<256x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [32, 1], sg_data = [8, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>
25+
%5:2 = scf.for %arg16 = %c0 to %c320 step %c32 iter_args(%arg17 = %3, %arg18 = %4) -> (!xetile.tile<128x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [32, 1], sg_data = [4, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>, !xetile.tile<256x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [32, 1], sg_data = [8, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>) {
26+
%18 = xetile.update_tile_offset %arg18, [%c0, %c32] : !xetile.tile<256x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [32, 1], sg_data = [8, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>
27+
%19 = xetile.update_tile_offset %arg17, [%c0, %c32] : !xetile.tile<128x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [32, 1], sg_data = [4, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>
28+
//CHECK: xetile.prefetch_tile {{%.*}} {l1_hint = #xetile.cache_hint<uncached>, l2_hint = #xetile.cache_hint<cached>} : !xetile.tile<4x32xf16>
29+
//CHECK: xetile.prefetch_tile {{%.*}} {l1_hint = #xetile.cache_hint<uncached>, l2_hint = #xetile.cache_hint<cached>} : !xetile.tile<8x32xf16>
30+
xetile.prefetch_tile %arg17 {l1_hint = #xetile.cache_hint<uncached>, l2_hint = #xetile.cache_hint<cached>} : !xetile.tile<128x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [32, 1], sg_data = [4, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>
31+
xetile.prefetch_tile %arg18 {l1_hint = #xetile.cache_hint<uncached>, l2_hint = #xetile.cache_hint<cached>} : !xetile.tile<256x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [32, 1], sg_data = [8, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>
32+
scf.yield %19, %18 : !xetile.tile<128x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [32, 1], sg_data = [4, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>, !xetile.tile<256x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [32, 1], sg_data = [8, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>
33+
}
34+
%6 = xetile.init_tile %arg0[%2, %c320] : memref<512x4096xf16> -> !xetile.tile<128x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [32, 1], sg_data = [4, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>
35+
%7 = xetile.init_tile %arg1[%c0, %c320] : memref<256x4096xf16> -> !xetile.tile<256x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [32, 1], sg_data = [8, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>
36+
%8 = xetile.init_tile %arg0[%2, %c0] : memref<512x4096xf16> -> !xetile.tile<128x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [4, 8], sg_data = [32, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>
37+
%9 = xetile.init_tile %arg1[%c0, %c0] : memref<256x4096xf16> -> !xetile.tile<256x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [8, 4], sg_data = [32, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>
38+
%10:5 = scf.for %arg16 = %c0 to %c4096 step %c32 iter_args(%arg17 = %cst, %arg18 = %6, %arg19 = %7, %arg20 = %8, %arg21 = %9) -> (vector<128x256xf32>, !xetile.tile<128x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [32, 1], sg_data = [4, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>, !xetile.tile<256x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [32, 1], sg_data = [8, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>, !xetile.tile<128x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [4, 8], sg_data = [32, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>, !xetile.tile<256x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [8, 4], sg_data = [32, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>) {
39+
%18 = xetile.update_tile_offset %arg21, [%c0, %c32] : !xetile.tile<256x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [8, 4], sg_data = [32, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>
40+
%19 = xetile.update_tile_offset %arg20, [%c0, %c32] : !xetile.tile<128x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [4, 8], sg_data = [32, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>
41+
%20 = xetile.update_tile_offset %arg19, [%c0, %c32] : !xetile.tile<256x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [32, 1], sg_data = [8, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>
42+
%21 = xetile.update_tile_offset %arg18, [%c0, %c32] : !xetile.tile<128x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [32, 1], sg_data = [4, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>
43+
%22 = arith.addi %arg16, %c320 : index
44+
%23 = arith.cmpi sge, %22, %c4096 : index
45+
scf.if %23 {
46+
} else {
47+
//CHECK: xetile.prefetch_tile {{%.*}} {l1_hint = #xetile.cache_hint<uncached>, l2_hint = #xetile.cache_hint<cached>} : !xetile.tile<4x32xf16>
48+
//CHECK: xetile.prefetch_tile {{%.*}} {l1_hint = #xetile.cache_hint<uncached>, l2_hint = #xetile.cache_hint<cached>} : !xetile.tile<8x32xf16>
49+
xetile.prefetch_tile %arg18 {l1_hint = #xetile.cache_hint<uncached>, l2_hint = #xetile.cache_hint<cached>} : !xetile.tile<128x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [32, 1], sg_data = [4, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>
50+
xetile.prefetch_tile %arg19 {l1_hint = #xetile.cache_hint<uncached>, l2_hint = #xetile.cache_hint<cached>} : !xetile.tile<256x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [32, 1], sg_data = [8, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>
51+
}
52+
%24 = xetile.load_tile %arg20 : !xetile.tile<128x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [4, 8], sg_data = [32, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>> -> vector<128x32xf16>
53+
%25 = arith.addf %24, %24 {map = #xetile.wg_map<sg_layout = [4, 8], sg_data = [32, 32]>} : vector<128x32xf16>
54+
%26 = xetile.load_tile %arg21 : !xetile.tile<256x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [8, 4], sg_data = [32, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>> -> vector<256x32xf16>
55+
%27 = vector.transpose %26, [1, 0] {map = #xetile.wg_map<sg_layout = [4, 8], sg_data = [32, 32]>} : vector<256x32xf16> to vector<32x256xf16>
56+
%28 = math.exp %27 {map = #xetile.wg_map<sg_layout = [4, 8], sg_data = [32, 32]>} : vector<32x256xf16>
57+
xegpu.compile_hint
58+
%29 = xetile.tile_mma %25, %28, %cst {wg_map_a = #xetile.wg_map<sg_layout = [4, 8], sg_data = [32, 32]>, wg_map_b = #xetile.wg_map<sg_layout = [4, 8], sg_data = [32, 32]>, wg_map_c = #xetile.wg_map<sg_layout = [4, 8], sg_data = [32, 32]>} : vector<128x32xf16>, vector<32x256xf16>, vector<128x256xf32> -> vector<128x256xf32>
59+
xegpu.compile_hint
60+
%30 = arith.addf %arg17, %29 {map = #xetile.wg_map<sg_layout = [4, 8], sg_data = [32, 32]>} : vector<128x256xf32>
61+
scf.yield %30, %21, %20, %19, %18 : vector<128x256xf32>, !xetile.tile<128x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [32, 1], sg_data = [4, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>, !xetile.tile<256x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [32, 1], sg_data = [8, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>, !xetile.tile<128x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [4, 8], sg_data = [32, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>, !xetile.tile<256x32xf16, #xetile.tile_attr<wg_map = <sg_layout = [8, 4], sg_data = [32, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>
62+
}
63+
%11 = arith.muli %block_id_x, %c256 : index
64+
%12 = arith.muli %block_id_y, %c128 : index
65+
%13 = arith.addi %11, %12 : index
66+
%14 = xetile.init_tile %arg2[%13, %c0] : memref<512x256xf32> -> !xetile.tile<128x256xf32, #xetile.tile_attr<wg_map = <sg_layout = [4, 8], sg_data = [32, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>
67+
%15 = xetile.load_tile %14 : !xetile.tile<128x256xf32, #xetile.tile_attr<wg_map = <sg_layout = [4, 8], sg_data = [32, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>> -> vector<128x256xf32>
68+
%16 = arith.addf %10#0, %15 {map = #xetile.wg_map<sg_layout = [4, 8], sg_data = [32, 32]>} : vector<128x256xf32>
69+
%17 = xetile.init_tile %arg3[%13, %c0] : memref<512x256xf32> -> !xetile.tile<128x256xf32, #xetile.tile_attr<wg_map = <sg_layout = [4, 8], sg_data = [32, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>
70+
xetile.store_tile %16, %17 : vector<128x256xf32>, !xetile.tile<128x256xf32, #xetile.tile_attr<wg_map = <sg_layout = [4, 8], sg_data = [32, 32]>, inner_blocks = [], memory_space = 0 : i32, scattered = false>>
71+
gpu.terminator
72+
}
73+
gpu.return
74+
}
75+
}

0 commit comments

Comments
 (0)