Update XeTile.md
Jianhui-Li authored and silee2 committed Dec 24, 2024
1 parent 5c94209 commit b0d8c48
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions docs/rfcs/XeTile.md
@@ -101,13 +101,13 @@ Attribute `padding` specifies the padding value for the out-of-boundary access.
`store_tile` stores a vector to memory. Padding attributes are not supported.
```mlir
xetile.store_tile %tile_a, %vector_a :
-vector<64x64xbf16> into tile<64x64xbf16>
+vector<64x64xbf16>, tile<64x64xbf16>
```
`store_tile` stores a tile according to the tile's `order` attribute. Regardless of the `order` attribute value, the vector's dimensions must match exactly the Tile's dimensions.
```mlir
#tile_attr = #xetile.tile_attr<order = [0, 1]>
%vector_a = xetile.store_tile %tile_a :
-vector<64x32xb16> to tile<64x32xbf16, #tile_attr>
+vector<64x32xb16>, tile<64x32xbf16, #tile_attr>
```

`prefetch_tile` prefetches the tile to cache. Just like memref.preftech, the locality hint ranges from locality<0> (no locality) to locality<3> (extremely local keep in cache).
@@ -155,7 +155,7 @@ xetile.atomic_rmw reuses the arith dialect attribute, mlir::arith::AtomicRMWKind
```

## support for load_gather and store_scatter (experimental)
-`init_tile` can create a tile with each element's address being explictly specified. The tile is created with a base memref and offsets for all elements to be loaded. The offsets and result tile can be either 1D or 2D. The resule tile has a `scatter` attribute to distinguish it from the regular tile.
+`init_tile` can create a tile with each element's address being explictly specified. The tile is created with a base memref and offsets for all elements to be loaded. The offsets and result tile can be either 1D or 2D. The result tile has a `scatter` attribute to distinguish it from the regular tile.
```mlir
%tile0 = xetile.init_tile %base_memref, %tile_offsets:
memref<?xbf16>, vector<256xindex> into tile<256xbf16, #scatter>
@@ -166,12 +166,12 @@ xetile.atomic_rmw reuses the arith dialect attribute, mlir::arith::AtomicRMWKind
`load_gather` (aka. load) loads data with prepared tile and mask. Attribute `padding` specifies the padding value for the out-of-boundary access. The default value is zero.
```mlir
%vector_a = xetile.load_gather %tile_0, %mask, {padding = 1.0} :
-tile<1x256xbf16, #scatter> into vector<1x256xbf16>
+tile<1x256xbf16, #scatter>, vector<1x256xi1>, vector<1x256xbf16>
```
`store_scatter` stores a 2d vector to a 2D tile with `scatter` attribute.
```mlir
xetile.store_scatter %vector_a, %mask, %tile_0 :
-vector<1x256xbf16> into tile<1x256xbf16, #scatter>
+vector<1x256xbf16>, vector<1x256xi1>, tile<1x256xbf16, #scatter>
```

## Workgroup Level XeTile extension (experimental)
@@ -412,7 +412,7 @@ func.func @test_gemm(%a : memref<4096x4096xf16>,
         
     prefetch_tile %1 : tile<256x32xf16, #mp_a_pfh>               // sg_layout=[32,1]
          prefetch_tile %2  : tile<32x256xf16, #mp_a_pfh>              // sg_layout=[4,8]
-          %6 = tile_mma %4, %5 {#mp_a #mp_b #mp_c} %4, %10 : (vector<256x32xf16>, vector<32x256xf16>) -> vector<256x256xf32> //sg_layout=[8,4]
+          %6 = tile_mma %4, %10 {#mp_a #mp_b #mp_c} : (vector<256x32xf16>, vector<32x256xf16>) -> vector<256x256xf32> //sg_layout=[8,4]
          %1 = update_tile_offset   %1, %c0, %c32 :  tile<256x32xf16, #mp_a>
          %2 = update_tile_offset   %2, %c32, %c0 :  tile<32x256xf16, #mp_b>
          %1p = update_tile_offset   %1p, %c0, %c32 :  tile<256x32xf16, #mp_a_pft>
@@ -575,7 +575,7 @@ func.func @test_gemm(%a : memref<4096x4096xf16>,
%a4_slm’ = update_tile_offset %a4_slm, %c0, %slm_offset: tile<256x32xf16, #mp_a_pft>
%b4_slm’ = update_tile_offset %b4_slm, %slm_offset, %c0 : tile<32x256xf16, #mp_b_pft>
-%c_r = tile_mma %a1_rr, %b1_rr #mp_a #mp_b #mp_c:
+%c_r = tile_mma %a1_rr, %b1_rr {#mp_a #mp_b #mp_c}:
(vector<256x32xf16>, vector<32x256xf16>) -> vector<256x256xf32> // sg_layout=[8,8], sg_data=[32,32]
gpu.barrier
@@ -646,7 +646,7 @@ func.func @test_gemm(%a : memref<4096x4096xf16>,
   %a1_load = init_tile %a[%i, %c0] : memref<4096x4096xf16> -> tile<512x32xf16, #mp_a>
   %b1_load = init_tile %b[%c0, %j] : memref<4096x4096xf16> -> tile<32x256xf16, #mp_b>
-%c = init_tile %c[%i, %j] : memref<4096x4096xf32> -> tile<512x256xf32, #mp_c>
+%c_tile = init_tile %c[%i, %j] : memref<4096x4096xf32> -> tile<512x256xf32, #mp_c>
scf.for %k= %c0 to %c4096 step %c32 {
%a1_r = load_tile %a1_load : tile<256x32xf16 #mp_a > -> vector<512x32xf16>
@@ -667,9 +667,9 @@ func.func @test_gemm(%a : memref<4096x4096xf16>,
%a1_load = update_tile_offset %a1_load, %c0, %c32 : tile<512x32xf16, #mp_a>
%a2_load = update_tile_offset %b1_load, %c32, %c0 : tile<32x256xf16, #mp_b>
-%6 = tile_mma %4, %5 #mp_a #mp_b #mp_c %4, %10 : (vector<512x32xf16>, vector<32x256xf16>) -> vector<512x256xf32>
+%6 = tile_mma %a1_r, %b1_r {#mp_a #mp_b #mp_c} : (vector<512x32xf16>, vector<32x256xf16>) -> vector<512x256xf32>
}
- store_tile %3, %6: (tile<512x256xf32, #mp_c>, vector<512x256xf32>)
+ store_tile %c_tile, %6: (tile<512x256xf32, #mp_c>, vector<512x256xf32>)
}
}
}