Skip to content

Commit

Permalink
Rewrite XeTileToXeGPU pass to take 2D shape as inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
chencha3 committed Dec 13, 2024
1 parent d40c423 commit 1deae85
Show file tree
Hide file tree
Showing 7 changed files with 667 additions and 106 deletions.
15 changes: 7 additions & 8 deletions include/imex/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -368,8 +368,8 @@ def ConvertXeTileToXeGPU: Pass<"convert-xetile-to-xegpu", "::mlir::gpu::GPUModul
func.func @sglevel_tiled_load_tile(%a: memref<1024x1024xf16>, %b: memref<1024x1024xf16>, %c: memref<1024x1024xf32>) {
%c0 = arith.constant 0 : index
%c64 = arith.constant 64 : index
%1 = xetile.init_tile %a[%c0, %c64] : memref<1024x1024xf16> -> !xetile.tile<2x1x8x16xf16>
%2 = xetile.load_tile %1 : !xetile.tile<2x1x8x16xf16> -> vector<2x1x8x16xf16>
%1 = xetile.init_tile %a[%c0, %c64] : memref<1024x1024xf16> -> !xetile.tile<8x16xf16>
%2 = xetile.load_tile %1 : !xetile.tile<8x16xf16> -> vector<8x16xf16>
return
}

Expand All @@ -379,11 +379,7 @@ def ConvertXeTileToXeGPU: Pass<"convert-xetile-to-xegpu", "::mlir::gpu::GPUModul
%c0 = arith.constant 0 : index
%c64 = arith.constant 64 : index
%0 = xegpu.create_nd_tdesc %arg0[%c0, %c64] {boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16>
%c8 = arith.constant 8 : index
%c64_0 = arith.constant 64 : index
%1 = xegpu.create_nd_tdesc %arg0[%c8, %c64_0] {boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16>
%2 = xegpu.load_nd %0 {l1_hint = #xegpu.cache_hint<uncached>, l2_hint = #xegpu.cache_hint<uncached>, l3_hint = #xegpu.cache_hint<uncached>} : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
%3 = xegpu.load_nd %1 {l1_hint = #xegpu.cache_hint<uncached>, l2_hint = #xegpu.cache_hint<uncached>, l3_hint = #xegpu.cache_hint<uncached>} : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
return
}
}];
Expand All @@ -396,9 +392,12 @@ def ConvertXeTileToXeGPU: Pass<"convert-xetile-to-xegpu", "::mlir::gpu::GPUModul
"::mlir::arith::ArithDialect",
];
let options = [
Option<"device", "device", "std::string",
Option<"device", "device", "std::string",
/*default=*/"\"pvc\"",
"gpu platform architecture where these ops are running">
"gpu platform architecture where these ops are running">,
Option<"EnableTransform", "enable-2d-transform", "bool",
/*default=*/"false",
"Using 2D transform or 4D Conversion.">
];
}

Expand Down
Loading

0 comments on commit 1deae85

Please sign in to comment.