Skip to content

Commit

Permalink
Add xetile.convert_layout op (#936)
Browse files Browse the repository at this point in the history
  • Loading branch information
nbpatel authored Oct 22, 2024
1 parent fee4e68 commit 6e634ae
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
16 changes: 16 additions & 0 deletions include/imex/Dialect/XeTile/IR/XeTileOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,22 @@ def XeTile_BroadcastOp: XeTile_Op<"broadcast", []> {
let hasVerifier = 1;
}

def XeTile_ConvertLayoutOp: XeTile_Op<"convert_layout", [AllTypesMatch<["source", "result"]>]> {
let summary = "Convert the sg layout of the input operand";
let description = [{
convert_layout with wg_map attributes remaps the SG layout
into a new layout which shuffles the data between subgroups with a workgroup
}];
let arguments = (ins XeTile_2DOr4DVector: $source,
XeTile_WorkGroupMapAttr: $wg_map_result,
OptionalAttr<XeTile_WorkGroupMapAttr>: $wg_map_source
);
let results = (outs XeTile_2DOr4DVector: $result);
let assemblyFormat = [{
$source attr-dict `:` type($source)
}];
}

def XeTile_LoadGatherOp: XeTile_Op<"load", [AllElementTypesMatch<["tile", "value"]>,
AllShapesMatch<["tile", "value", "mask"]>]> {
let summary = "load a set of scattered data points from memory.";
Expand Down
10 changes: 10 additions & 0 deletions test/Dialect/XeTile/IR/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
#wg_map_b = #xetile.wg_map<sg_layout = [16, 1], sg_data = [16, 1]>
#wg_map_b2 = #xetile.wg_map<sg_layout = [4, 4], sg_data = [64, 64]>

#wg_map_new_layout = #xetile.wg_map<sg_layout = [32, 1], sg_data = [8, 256]>

func.func @test_init_tile_for_slm(%a: memref<1024x1024xf16, 3>) {
//CHECK: xetile.init_tile {{.*}}[8, 16] : memref<1024x1024xf16, 3> -> !xetile.tile<32x64xf16, #xetile.tile_attr<memory_space = 3 : i64>>
%1 = xetile.init_tile %a[8, 16] : memref<1024x1024xf16, 3> -> !xetile.tile<32x64xf16, #xetile.tile_attr<memory_space = 3>>
Expand Down Expand Up @@ -410,3 +412,11 @@ func.func @test_tile_mma_map(%a : vector<256x256xf16>, %b : vector<256x256xf16>,
vector<256x256xf16>, vector<256x256xf16>, vector<256x256xf32> -> vector<256x256xf32>
return
}

func.func @test_convert_layout(%source: vector<256x256xf16>) {
// CHECK: xetile.convert_layout {{.*}} {wg_map_result = #xetile.wg_map<sg_layout = [32, 1], sg_data = [8, 256]>} : vector<256x256xf16>
// CHECK: xetile.convert_layout {{.*}} {wg_map_result = #xetile.wg_map<sg_layout = [32, 1], sg_data = [8, 256]>, wg_map_source = #xetile.wg_map<sg_layout = [8, 4], sg_data = [32, 64]>} : vector<256x256xf16>
%1 = xetile.convert_layout %source {wg_map_result = #wg_map_new_layout} : vector<256x256xf16>
%2 = xetile.convert_layout %source {wg_map_result = #wg_map_new_layout, wg_map_source = #wg_map_mma_b} : vector<256x256xf16>
return
}

0 comments on commit 6e634ae

Please sign in to comment.