From 8a5bcf5df2334c4842ce64ca86aa2eb9791b0481 Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Tue, 17 Dec 2024 14:41:23 -0600 Subject: [PATCH] Remove unused definition and code (#987) * Remove Unpack and Pack * Remove inner_block attribute * Remove 4D vector support * and some others --- docs/rfcs/XeTile.md | 69 +--- include/imex/Dialect/XeTile/IR/XeTileAttrs.td | 17 - include/imex/Dialect/XeTile/IR/XeTileOps.td | 128 ++----- include/imex/Dialect/XeTile/IR/XeTileTypes.td | 18 +- .../imex/Dialect/XeTile/Transforms/Passes.h | 6 - .../Transforms/XeTileOneToNConversion.h | 19 +- include/imex/Utils/XeCommon.h | 341 +----------------- lib/Dialect/XeTile/IR/XeTileDialect.cpp | 16 +- lib/Dialect/XeTile/IR/XeTileOps.cpp | 196 +--------- .../XeTile/Transforms/Canonicalization.cpp | 3 +- lib/Dialect/XeTile/Transforms/WgToSg.cpp | 14 +- .../Transforms/XeTileOneToNConversion.cpp | 28 +- lib/Utils/XeCommon.cpp | 31 +- test/Dialect/XeTile/IR/canonicalize.mlir | 20 - test/Dialect/XeTile/IR/invalid.mlir | 83 +---- test/Dialect/XeTile/IR/ops.mlir | 69 +--- .../XeTile/Transforms/WgToSg/broadcast.mlir | 20 +- .../XeTile/Transforms/WgToSg/btranspose.mlir | 28 +- .../XeTile/Transforms/WgToSg/gemm_batch.mlir | 20 +- .../Transforms/WgToSg/gemm_batch_oob.mlir | 20 +- .../XeTile/Transforms/WgToSg/gemm_postop.mlir | 6 +- .../XeTile/Transforms/WgToSg/prefetch.mlir | 52 +-- .../XeTile/Transforms/WgToSg/unit_tests.mlir | 4 +- 23 files changed, 142 insertions(+), 1066 deletions(-) delete mode 100644 test/Dialect/XeTile/IR/canonicalize.mlir diff --git a/docs/rfcs/XeTile.md b/docs/rfcs/XeTile.md index a1b47310e..048bcbc96 100644 --- a/docs/rfcs/XeTile.md +++ b/docs/rfcs/XeTile.md @@ -29,8 +29,7 @@ XeTile provides a middle-level abstraction for matmul operation and sits between |tile_transpose | operation ::=xetile.tile_transpose $vec $permuation_dims attr_dict: type($vec) -> type($res) | %vector_a = xetile.tile_transpose %vector_b [1, 0]: vector<64x32xfloat> into vector<32x64xfloat> | |tile_reduce | operation ::=xetile.tile_reduce \<$kind\> $src $reduction_dims attr_dict: type($value) -> type($res) | %vector_a = xetile.tile_reduce \ %vector_b [1]: vector<64x32xfloat> into vector<64x1xfloat> | |tile_broadcast | operation ::=xetile.tile_broadcast $src $broadcast_dims attr_dict: type($value) -> type($res) | %vector_a = xetile.tile_broadcast %vector_b[0]: vector<1x32xfloat> into vector<64x32xfloat> | -|tile_pack* | operation ::=xetile.tile_pack $matA attr_dict: type($value) -> type($res) | %vector_a = xetile.tile_pack %vector_b {inner_blocks=array} : vector<64x32xfloat> into vector<4x2x16x16xfloat> | -|tile_unpack* | operation ::=xetile.tile_upack $matA attr_dict: type($value) -> type($res) | %vector_a = xetile.tile_unpack %vector_b {inner_blocks=array} : vector<1x2x64x16xfloat> into vector<64x32xbf16> | + *Operations only used to support internal lowering. @@ -155,70 +154,6 @@ xetile.atomic_rmw reuses the arith dialect attribute, mlir::arith::AtomicRMWKind %vector_a = xetile.tile_broadcast %vector_b [0]: vector<1x32xfloat> into vector<64x32xfloat> ``` -## Internal Operations to support gradual lowering -The 2D XeTile IR needs to be lowered in an intermediate form to support `blocking` optimization. The `blocking` optimization loads the tile in blocks and feed the block to matrix hardware. Since the load block size and matrix hardware size are not necessary same, we need to represent the data block in some form to assist the optimization. Conceptually, when a 2D tile data being loaded with a specified block size, the vector represents the 2D tile in 4D block layout. So we uses 4D dimension vector to describe the data being loaded with the block size. - -`init_tile` with an `inner_block` for 2D block access of the base matrix. The `inner_blocks` attribute describes the block size for each memory load and store operation when the tile is being loaded. The block size for load may be larger than the block size for MMA operation. The output tile carries the `inner_block` attribute in its attribute set. - -```mlir - #tile_attr = #xetile.tile_attr - %tile0 = xetile.init_tile %base_memref, [%tile_offset:2]: - memref<128x128xbf16> into tile<64x32xbf16, #tile_attr> -``` - -`load_tile` loads a 2D tile with an `inner_block` attribute to 4D vector. -```mlir - #tile_attr = #xetile.tile_attr - %vector_a = xetile.load_tile %tile_a : - tile<64x32xbf16, #tile_attr> into vector<4x2x16x16xb16> -``` -`store_tile` stores a 4D vector to a 2D tile with an `inner_block`. -```mlir - #tile_attr = #xetile.tile_attr - xetile.store_tile %vector_a, %tile_a : - vector<4x2x16x16xb16> into tile<64x32xbf16, #tile_attr> -``` -`atomic_rmw_tile` performs atomic operation on 4D vectors. -```mlir -#tile_attr = #xetile.tile_attr -%vector_a = atomic_rmw_tile %value, %tile: vector<8x48x16xbf16>, tile<64x64xbf16, #tile_attr> to vector<8x4x8x16xbf16> -``` - -With the data being presented as 4D vector, all the vector based XeTile operations are required to support blocking. -`tile_mma` works on 4D vectors. Since dimension 1 is split into dimensions 1 and 3, the reduction of matrix multiplication is along these two dimensions. -```mlir - %vector_c = xetile.tile_mma %vector_a, %vector_b, %vector_c : - vector<8x4x8x8xbf16>, vector<4x8x8x16xbf16>, vector<8x8x8x16xfloat> - into vector<8x8x8x16xfloat> -``` -`tile_reduce` follows the vector.multi-reduction semantics and can be applied to 4D vector. The tile_reduce on 4D vector is an internal operation and only used in the transformation passes to support gradual lowering. -```mlir - %vector_a = xetile.tile_reduce %vector_b [1, 3]: vector<8x4x8x16xfloat> into vector<8x1x8x1float> -``` - -`tile_broadcast` broadcast 4D vector. The input is expected to be first reshaped from 1D vector to 2D vector, and then blocked to 4D. -```mlir - %vector_a = xetile.tile_broadcast %vector_b [1, 3]: vector<8x1x8x1xfloat> into vector<8x4x8x16xfloat> -``` - -`tile_transpose` doesn't have support 4D vector. The transpose is usually implemented by saving and restoring from the share local memory. To support this, we relax the restriction of tile_load and tile_store so that they can load 2D from share local memory. - -`tile_pack` and `tile_unpack` are introduced to support the gradual lowering. It allows the XeTile IR to be blocked with different block size, and then try to find a good blocking strategy with minimum tile_pack and tile_unpack overhead. - -`tile_pack` packs a 2D vector, representing the loaded value from 2D tile, to a 4D vector with an inner block size. The 4D vector was introduced to support blocking to fit the hardware matrix operation sizes.  The blocking follows an implicit rule: out_dim[0] = in_dim[0]/inner_blocks[0] , out_dim[1] = in_dim[1]/inner_blocks[1], out_dim[2] = inner_blocks[0], and out_dim[3] = inner_blocks[1]. The dim[2] and dim[3] of result 4D vector must be same as the size of `inner_blocks` attribute. - -```mlir - %0 = xetile.tile_pack %1 {inner_blocks = array} - : vector<64x32xf32> -> vector<4x2x16x16xf32> -``` -`tile_unpack` unpacks a 4D blocked vector back to original unpacked 2D vector. -`tile_unpack` -```mlir - %0 = xetile.tile_unpack %1 {inner_blocks = array} - : vector<1x2x64x16xf32> -> vector<64x32xf32> -``` -The tile_pack and tile_unpack operation is similar to pack and unpack operation of tensor dialect. The source vector must be a 2D dimension vector, and no permutation is allowed for the result 4D vector, so effectively the blocking effect is identical to tensor pack/unpack operation with inner_dims_pos = [0,1] inner_dims_pos = [0, 1]. - ## support for load_gather and store_scatter (experimental) `init_tile` can create a tile with each element's address being explictly specified. The tile is created with a base memref and offsets for all elements to be loaded. The offsets and result tile can be either 1D or 2D. The resule tile has a `scatter` attribute to distinguish it from the regular tile. ```mlir @@ -245,7 +180,7 @@ The tile_pack and tile_unpack operation is similar to pack and unpack operation Below is an example. ```mlir #wg_map_a = #xetile.wg_map - #tile_attr = #xetile.tile_attr > + #tile_attr = #xetile.tile_attr %wg_tile = xetile.init_tile %A[%m, %c0] : memref<1024x1024xf16> -> !xetile.tile<128x128xf16, #tile_attr> ``` diff --git a/include/imex/Dialect/XeTile/IR/XeTileAttrs.td b/include/imex/Dialect/XeTile/IR/XeTileAttrs.td index 6d1e5a092..3203fa92c 100644 --- a/include/imex/Dialect/XeTile/IR/XeTileAttrs.td +++ b/include/imex/Dialect/XeTile/IR/XeTileAttrs.td @@ -63,27 +63,12 @@ def XeTile_TileAttr : XeTile_Attr<"XeTile", "tile_attr"> { OptionalParameter<"xetile::SubGroupMapAttr">:$sg_map, OptionalParameter<"xetile::WorkGroupMapAttr">:$wg_map, DefaultValuedParameter<"mlir::DenseI32ArrayAttr", "mlir::DenseI32ArrayAttr::get($_ctxt, {1, 0})">:$order, - OptionalParameter<"mlir::DenseI64ArrayAttr">:$inner_blocks, OptionalParameter<"mlir::Attribute">:$memory_space, OptionalParameter<"mlir::BoolAttr">:$scattered ); let assemblyFormat = "`<` struct(params) `>`"; let genVerifyDecl = true; let builders = [ - AttrBuilder<(ins CArg<"xetile::SubGroupMapAttr", "{}">:$sg_map, - CArg<"xetile::WorkGroupMapAttr", "{}">:$wg_map, - CArg<"llvm::ArrayRef", "{1, 0}">:$order, - CArg<"llvm::ArrayRef", "{}">:$inner_blocks, - CArg<"int", "0">:$memory_space, - CArg<"bool", "false">:$scattered), - [{ - mlir::Type intType = mlir::IntegerType::get($_ctxt, 32); - mlir::BoolAttr scatteredAttr = mlir::BoolAttr::get($_ctxt, scattered); - mlir::DenseI64ArrayAttr blkAttr = inner_blocks.empty()? mlir::DenseI64ArrayAttr(): - mlir::DenseI64ArrayAttr::get($_ctxt, inner_blocks); - return $_get($_ctxt, sg_map, wg_map, mlir::DenseI32ArrayAttr::get($_ctxt, order), - blkAttr, mlir::IntegerAttr::get(intType, memory_space), scatteredAttr); - }]>, AttrBuilder<(ins CArg<"llvm::ArrayRef", "{1, 0}">:$order, CArg<"int", "0">:$memory_space, CArg<"bool", "false">:$scattered), [{ @@ -91,7 +76,6 @@ def XeTile_TileAttr : XeTile_Attr<"XeTile", "tile_attr"> { mlir::BoolAttr scatteredAttr = mlir::BoolAttr::get($_ctxt, scattered); return $_get($_ctxt, xetile::SubGroupMapAttr(), xetile::WorkGroupMapAttr(), mlir::DenseI32ArrayAttr::get($_ctxt, order), - mlir::DenseI64ArrayAttr(), mlir::IntegerAttr::get(intType, memory_space), scatteredAttr); }]>, AttrBuilder<(ins CArg<"xetile::SubGroupMapAttr", "{}">:$sg_map, @@ -102,7 +86,6 @@ def XeTile_TileAttr : XeTile_Attr<"XeTile", "tile_attr"> { mlir::Type intType = mlir::IntegerType::get($_ctxt, 32); mlir::BoolAttr scatteredAttr = mlir::BoolAttr::get($_ctxt, scattered); return $_get($_ctxt, sg_map, wg_map, mlir::DenseI32ArrayAttr::get($_ctxt, order), - mlir::DenseI64ArrayAttr(), mlir::IntegerAttr::get(intType, memory_space), scatteredAttr); }]> ]; diff --git a/include/imex/Dialect/XeTile/IR/XeTileOps.td b/include/imex/Dialect/XeTile/IR/XeTileOps.td index 44e446a42..b55a03f8c 100644 --- a/include/imex/Dialect/XeTile/IR/XeTileOps.td +++ b/include/imex/Dialect/XeTile/IR/XeTileOps.td @@ -38,10 +38,6 @@ def XeTile_InitTileOp : XeTile_Op<"init_tile", [Pure, AttrSizedOperandSegments, memref or an address is used as the base, it is required to specify the shape and strides of the memory region described by the tile. - Optionally, the tile can be described in blocked layout as well. This is done by specifying - an "inner_blocks" attribute which describes the size (rows and cols) of the block. This attribute - is used by later lowering passes to detremine the 2D block load/store sizes. - The operation takes in the following arguments: * source: Source can be static/dynamic shaped memref or an address (i64) * offsets: offsets into the "source" memref or address at which to @@ -112,7 +108,7 @@ def XeTile_InitTileOp : XeTile_Op<"init_tile", [Pure, AttrSizedOperandSegments, OptionalAttr: $const_offsets, OptionalAttr: $const_sizes, OptionalAttr: $const_strides, - Optional>: $indices); + Optional>: $indices); let results = (outs XeTile: $tile); @@ -285,7 +281,9 @@ def XeTile_InitTileOp : XeTile_Op<"init_tile", [Pure, AttrSizedOperandSegments, } -def XeTile_LoadTileOp : XeTile_Op<"load_tile", []> { +def XeTile_LoadTileOp : XeTile_Op<"load_tile", [ + AllElementTypesMatch<["source", "value"]>, + AllShapesMatch<["source", "value"]>]> { let summary = "Loads a tile into a register region"; let description = [{ "load_tile" operation loads the values of a tile into a register region with 2D or 4D layout. @@ -309,12 +307,6 @@ def XeTile_LoadTileOp : XeTile_Op<"load_tile", []> { %4 = xetile.load_tile %src { padding = 1.0 : f32} : !xetile.tile<64x32xf32> -> vector<32x64xf32> ``` - - Example 3: loading into a 4D register region. - ```mlir - %4 = xetile.load_tile %src : !xetile.tile<64x32xf32, #xetile.tile_attr> - -> vector<8x2x8x16xf32> - ``` }]; let arguments = (ins XeTile: $source, @@ -323,10 +315,9 @@ def XeTile_LoadTileOp : XeTile_Op<"load_tile", []> { OptionalAttr: $l2_hint, OptionalAttr: $l3_hint); - let results = (outs XeTile_2DOr4DVector: $value); + let results = (outs XeTile_2DVector: $value); let assemblyFormat = "$source attr-dict `:` qualified(type($source)) `->` type($value)"; - let hasVerifier = true; let extraClassDeclaration = [{ // padding value defaults to zero in the appropriate type if its not specified @@ -342,7 +333,9 @@ def XeTile_LoadTileOp : XeTile_Op<"load_tile", []> { } -def XeTile_StoreTileOp : XeTile_Op<"store_tile", []> { +def XeTile_StoreTileOp : XeTile_Op<"store_tile", [ + AllElementTypesMatch<["value", "tile"]>, + AllShapesMatch<["value", "tile"]>]> { let summary = "stores a register region into memory"; let description = [{ "store_tile" operation can be used to store a register region into a 2D memory region @@ -357,16 +350,10 @@ def XeTile_StoreTileOp : XeTile_Op<"store_tile", []> { ```mlir xetile.store_tile %value, %dst : vector<64x32xf32>, !tile<64x32xf32> ``` - - Example 2: storing a 4D register region - ```mlir - xetile.store_tile %value, %dst : vector<8x2x8x16xf32>, - !tile<64x32xf32, #xetile.tile_attr> - ``` }]; let arguments = (ins - XeTile_2DOr4DVector: $value, + XeTile_2DVector: $value, XeTile: $tile, OptionalAttr: $l1_hint, OptionalAttr: $l2_hint, @@ -375,7 +362,6 @@ def XeTile_StoreTileOp : XeTile_Op<"store_tile", []> { let assemblyFormat = [{ $value`,`` `$tile attr-dict `:` qualified(type($value)) `,` qualified(type($tile)) }]; - let hasVerifier = true; } def XeTile_PrefetchTileOp : XeTile_Op<"prefetch_tile", []> { @@ -433,26 +419,18 @@ def XeTile_TileMMAOp : XeTile_Op<"tile_mma", []> { %c_new = xetile.tile_mma %a_vec, %b_vec, %c_vec : vector<64x32xf32>, vector<32x128xf32>, vector<64x128xf32> -> vector<64x128xf32> ``` - - Example 3: tile_mma on 4D vectors of A, B and, C - ```mlir - %c_new = xetile.tile_mma %a_vec, %b_vec, %c_vec - : vector<8x4x8x8xf32>, vector<4x8x8x16xf32>, vector<8x8x8x16xf32> -> vector<8x8x8x16xf32> - ``` - - }]; let arguments = (ins - XeTile_2DOr4DVector: $a, - XeTile_2DOr4DVector: $b, - Optional: $c, + XeTile_2DVector: $a, + XeTile_2DVector: $b, + Optional: $c, OptionalAttr: $wg_map_a, OptionalAttr: $wg_map_b, OptionalAttr: $wg_map_c ); - let results = (outs XeTile_2DOr4DVector: $output); + let results = (outs XeTile_2DVector: $output); let assemblyFormat = [{ $a `,` $b (`,` $c^)? attr-dict `:` type($a)`,` type($b) (`,` type($c)^)? `->` type($output) }]; @@ -497,7 +475,7 @@ def XeTile_UpdateTileOffsetOp : XeTile_Op<"update_tile_offset", [AttrSizedOperan XeTile: $tile, Optional: $offset_x, Optional: $offset_y, - Optional>:$indices); + Optional>:$indices); let results = (outs XeTile: $result @@ -508,60 +486,6 @@ def XeTile_UpdateTileOffsetOp : XeTile_Op<"update_tile_offset", [AttrSizedOperan }]; } -def XeTile_TilePackOp : XeTile_Op<"tile_pack", [Pure]> { - let summary = "pack 2D vector into a 4D blocked vector"; - let description = [{ - "tile_pack" operation is used for converting a 2D vector into a blocked 4D vector. The - block size is specified by the `inner_blocks` atrribute. The outermost 2 dimensions of - the output specify the number of blocks along row and col dimensions, and innermost 2 - dimensions of the output is equal to the dimensions of the `inner_blocks` i.e. blocking - follows this rule: - out_vec_shape[0] = in_vec_shape[0]/inner_blocks[0] - out_vec_shape[1] = in_vec_shape[1]/inner_blocks[1] - out_vec_shape[2] = inner_blocks[0] - out_vec_shape[3] = inner_blocks[1] - - Example 1: - ```mlir - %1 = xetile.tile_pack %0 {inner_blocks = array} - : vector<64x32xf32> -> vector<4x2x16x16xf32> - ``` - }]; - - let arguments = (ins XeTile_2DVector: $in_vec, DenseI64ArrayAttr: $inner_blocks); - let results = (outs XeTile_4DVector: $out_vec); - - let assemblyFormat = "$in_vec attr-dict `:` qualified(type($in_vec)) `->` type($out_vec)"; - let hasVerifier = true; - let hasFolder = 1; -} - -def XeTile_TileUnpackOp : XeTile_Op<"tile_unpack", [Pure]> { - let summary = "unpack a blocked 4D vector into a 2D vector"; - let description = [{ - "tile_unpack" operation performs the reverse of "tile_pack". Given a 4D vector and - an `inner_block` attribute, this operation unpacks the blocked vector into a 2D vector. - Similar to "tile_pack" this operation follows the rule: - out_vec_shape[0] == in_vec_shape[0]*inner_blocks[0] - out_vec_shape[1] == in_vec_shape[1]*inner_blocks[1] - inner_blocks[0] == in_vec_shape[2] - inner_blocks[1] == in_vec_shape[3] - - Example 1: - ```mlir - %1 = xetile.tile_unpack %0 {inner_blocks = array} - : vector<4x2x16x16xf16> -> vector<64x32xf16> - ``` - }]; - - let arguments = (ins XeTile_4DVector: $in_vec, DenseI64ArrayAttr: $inner_blocks); - let results = (outs XeTile_2DVector: $out_vec); - - let assemblyFormat = "$in_vec attr-dict `:` qualified(type($in_vec)) `->` type($out_vec)"; - let hasVerifier = true; - let hasFolder = 1; -} - def XeTile_AtomicRMWOp : XeTile_Op<"atomic_rmw", []> { let summary = "performs a read modify write operation that is free from data races."; let description = [{ @@ -576,9 +500,9 @@ def XeTile_AtomicRMWOp : XeTile_Op<"atomic_rmw", []> { ``` }]; let arguments = (ins XeTile_AtomicRMWKindAttr:$kind, - XeTile_2DOr4DVector:$value, + XeTile_2DVector:$value, XeTile:$tile); - let results = (outs XeTile_2DOr4DVector:$result); + let results = (outs XeTile_2DVector:$result); let assemblyFormat = [{ $kind $value `,` $tile attr-dict `:` qualified(type($value)) `,` qualified(type($tile)) `->` qualified(type($result)) @@ -591,9 +515,9 @@ def XeTile_TransposeOp: XeTile_Op<"transpose", []> { It has the same semantic with `vector.transpose`, but limits the vector to be 2D. }]; - let arguments = (ins XeTile_2DOr4DVector: $vector, + let arguments = (ins XeTile_2DVector: $vector, DenseI64ArrayAttr:$permutation); - let results = (outs XeTile_2DOr4DVector: $result); + let results = (outs XeTile_2DVector: $result); let assemblyFormat = [{ $vector `,` $permutation attr-dict `:` type($vector) `->` type($result) }]; @@ -609,9 +533,9 @@ def XeTile_ReductionOp: XeTile_Op<"reduction", []> { }]; let arguments = (ins Vector_CombiningKindAttr: $kind, - XeTile_2DOr4DVector: $source, + XeTile_2DVector: $source, DenseI64ArrayAttr: $reduction_dims); - let results = (outs XeTile_2DOr4DVector: $result); + let results = (outs XeTile_2DVector: $result); let assemblyFormat = [{ $kind `,` $source $reduction_dims attr-dict `:` type($source) `->` type($result) }]; @@ -622,9 +546,9 @@ def XeTile_ReductionOp: XeTile_Op<"reduction", []> { def XeTile_BroadcastOp: XeTile_Op<"broadcast", []> { let summary = "broadcast a vector from 1D to 2D."; - let arguments = (ins XeTile_2DOr4DVector: $source, + let arguments = (ins XeTile_2DVector: $source, DenseI64ArrayAttr: $broadcast_dim); - let results = (outs XeTile_2DOr4DVector: $result); + let results = (outs XeTile_2DVector: $result); let assemblyFormat = [{ $source $broadcast_dim attr-dict `:` type($source) `->` type($result) }]; @@ -637,11 +561,11 @@ def XeTile_ConvertLayoutOp: XeTile_Op<"convert_layout", [AllTypesMatch<["source" convert_layout with wg_map attributes remaps the SG layout into a new layout which shuffles the data between subgroups with a workgroup }]; - let arguments = (ins XeTile_2DOr4DVector: $source, + let arguments = (ins XeTile_2DVector: $source, XeTile_WorkGroupMapAttr: $wg_map_result, OptionalAttr: $wg_map_source ); - let results = (outs XeTile_2DOr4DVector: $result); + let results = (outs XeTile_2DVector: $result); let assemblyFormat = [{ $source attr-dict `:` type($source) }]; @@ -663,7 +587,7 @@ def XeTile_LoadGatherOp: XeTile_Op<"load", [AllElementTypesMatch<["tile", "value OptionalAttr: $l1_hint, OptionalAttr: $l2_hint, OptionalAttr: $l3_hint); - let results = (outs XeTile_1DOr2DOr4DVector: $value); + let results = (outs XeTile_1DOr2DVector: $value); let assemblyFormat = [{ $tile `` `,` $mask attr-dict `:` qualified(type($tile)) `` `,` type($mask) `->` type($value) }]; @@ -678,7 +602,7 @@ def XeTile_StoreScatterOp: XeTile_Op<"store", [AllElementTypesMatch<["value", "t memory access so that it is safe to pass out-of-boundary addresses/offsets as long as they are masked. }]; - let arguments = (ins XeTile_1DOr2DOr4DVector: $value, + let arguments = (ins XeTile_1DOr2DVector: $value, XeTile: $tile, XeTile_MaskType: $mask, OptionalAttr: $l1_hint, diff --git a/include/imex/Dialect/XeTile/IR/XeTileTypes.td b/include/imex/Dialect/XeTile/IR/XeTileTypes.td index 7cd2dc4df..3e4ba173b 100644 --- a/include/imex/Dialect/XeTile/IR/XeTileTypes.td +++ b/include/imex/Dialect/XeTile/IR/XeTileTypes.td @@ -110,13 +110,6 @@ def XeTile : XeTile_Type<"Tile", "tile", [ShapedTypeInterface], return xetile::WorkGroupMapAttr(); } - mlir::DenseI64ArrayAttr getInnerBlocks() { - auto encoding = llvm::dyn_cast_if_present(getEncoding()); - if (encoding) - return encoding.getInnerBlocks(); - return mlir::DenseI64ArrayAttr(); - } - mlir::DenseI32ArrayAttr getOrder() { auto encoding = llvm::dyn_cast_if_present(getEncoding()); if (encoding) @@ -168,18 +161,13 @@ def XeTile_ScalarType : AnyTypeOf<[XeTile_IntType, XeTile_FloatType, Index]>; // define the source type for XeTile init_tile def XeTile_BaseAddrType : AnyTypeOf<[MemRefOf<[XeTile_ScalarType]>, UI64, UI32, I64, I32]>; -// input and output types needed for pack and unpack ops -// def XeTile_1DVector : VectorOfRankAndType<[1], [XeTile_ScalarType]>; +// define the value type for XeTile load_tile and store_tile op def XeTile_2DVector : VectorOfRankAndType<[2], [XeTile_ScalarType]>; -def XeTile_4DVector : VectorOfRankAndType<[4], [XeTile_ScalarType]>; // define the value type for XeTile load_gather and store_scatter op -def XeTile_1DOr2DOr4DVector: VectorOfRankAndType<[1, 2, 4], [XeTile_ScalarType]>; - -// define the value type for XeTile load_tile and store_tile op -def XeTile_2DOr4DVector: VectorOfRankAndType<[2, 4], [XeTile_ScalarType]>; +def XeTile_1DOr2DVector: VectorOfRankAndType<[1, 2], [XeTile_ScalarType]>; -def XeTile_MaskType: VectorOfRankAndType<[1, 2, 4], [I1]>; +def XeTile_MaskType: VectorOfRankAndType<[1, 2], [I1]>; // define the attribute type allowed for padding values for load op def XeTile_PaddingValueAttr : AnyAttrOf<[I32Attr, F32Attr]>; diff --git a/include/imex/Dialect/XeTile/Transforms/Passes.h b/include/imex/Dialect/XeTile/Transforms/Passes.h index 17d26ad36..869732f3b 100644 --- a/include/imex/Dialect/XeTile/Transforms/Passes.h +++ b/include/imex/Dialect/XeTile/Transforms/Passes.h @@ -30,8 +30,6 @@ class RewritePatternSet; namespace imex { -class XeTypeConverter; - //===----------------------------------------------------------------------===// /// XeTile passes. //===----------------------------------------------------------------------===// @@ -43,10 +41,6 @@ createXeTileBlockingPass(const std::string &device = "pvc"); std::unique_ptr createXeTileWgToSgPass(); std::unique_ptr createXeTileCanonicalizationPass(); -/// -void populateXeTileInitDuplicatePatterns(imex::XeTypeConverter &converter, - mlir::RewritePatternSet &patterns); - #define GEN_PASS_DECL_XETILEBLOCKING #define GEN_PASS_DECL_XETILECANONICALIZATION #define GEN_PASS_DECL_XETILEINITDUPLICATE diff --git a/include/imex/Dialect/XeTile/Transforms/XeTileOneToNConversion.h b/include/imex/Dialect/XeTile/Transforms/XeTileOneToNConversion.h index 71d04f4ce..961d97aac 100644 --- a/include/imex/Dialect/XeTile/Transforms/XeTileOneToNConversion.h +++ b/include/imex/Dialect/XeTile/Transforms/XeTileOneToNConversion.h @@ -36,17 +36,11 @@ namespace imex { -class XeOneToNTypeConverter : public imex::XeTypeConverter { +class XeOneToNTypeConverter : public mlir::TypeConverter { public: - XeOneToNTypeConverter(mlir::MLIRContext &context); - - std::optional - convertTileType(xetile::TileType tileTy, - llvm::SmallVectorImpl &resultTypes) override; + using mlir::TypeConverter::convertType; - std::optional - convertVectorType(mlir::VectorType vectorTy, - llvm::SmallVectorImpl &resultTypes) override; + XeOneToNTypeConverter(mlir::MLIRContext &context); mlir::LogicalResult computeTypeMapping(mlir::ValueRange original, mlir::ValueRange converted, @@ -116,14 +110,13 @@ class XeOneToNPatternRewriter : public mlir::PatternRewriter, }; template -class XeOneToNConversion : public XeConversionPattern { +class XeOneToNConversion : public XeConversionPattern { public: XeOneToNConversion(mlir::MLIRContext *context, XeOneToNTypeConverter &typeConverter, - TileUsageAnalysis &analysis, mlir::PatternBenefit benefit = 1) - : XeConversionPattern(typeConverter, analysis, - SourceOp::getOperationName(), benefit, context) {} + : XeConversionPattern(typeConverter, SourceOp::getOperationName(), + benefit, context) {} using RangeT = llvm::ArrayRef; using OpAdaptor = typename SourceOp::template GenericAdaptor; diff --git a/include/imex/Utils/XeCommon.h b/include/imex/Utils/XeCommon.h index c24b6b525..cc4ce9961 100644 --- a/include/imex/Utils/XeCommon.h +++ b/include/imex/Utils/XeCommon.h @@ -136,13 +136,6 @@ class TileUsageAnalysis { Usage[op] |= (uint)UsageType::DPAS_C; else op->emitOpError() << "unknown usage: " << idx; - } else if (auto unpack = - llvm::dyn_cast_if_present( - user)) { - q.push_back(unpack); - } else if (auto pack = - llvm::dyn_cast_if_present(user)) { - q.push_back(pack); } } } @@ -270,256 +263,24 @@ class TileUsageAnalysis { llvm::DenseMap Usage; }; -// This analysis is used to propagate the inner block size of an operator -// to its uses or users. Current implementation is to propagate the MMA -// size used by an MMA operator to the definition (InitTileOp) for its operands. -// TODO: This analysis can be extended to propagate the block size for other ops -// such that it can be used as a general analysis for other block size -// optimizations. -class PropagateAnalysis { -private: - llvm::DenseMap OpAttrMap; - -public: - PropagateAnalysis(mlir::Operation *op) { - op->walk([&](xetile::TileMMAOp op) { - mlir::Operation *operation = op.getOperation(); - for (auto value : operation->getOperands()) { - auto packOp = value.getDefiningOp(); - if (packOp) { - auto blkSZ = packOp.getInnerBlocksAttr(); - propagate(value, blkSZ); - } - } - }); - } - - bool maybeUpdated(mlir::Operation *op) const { - assert(op->getNumResults() == 1); - auto v = op->getResult(0); - return OpAttrMap.count(v); - } - - mlir::DenseI64ArrayAttr getValue(mlir::Value value) const { - auto it = OpAttrMap.find(value); - if (it != OpAttrMap.end()) - return it->second; - return {}; - } - - mlir::DenseI64ArrayAttr getValue(mlir::Operation *op) const { - assert(op->getNumResults() == 1); - auto v = op->getResult(0); - auto it = OpAttrMap.find(v); - if (it != OpAttrMap.end()) - return it->second; - return {}; - } - -private: - mlir::Operation *getDefineOrParentOp(mlir::Value value) { - if (llvm::isa(value)) - return value.getDefiningOp(); - if (auto arg = llvm::dyn_cast_or_null(value)) - return arg.getOwner()->getParentOp(); - return nullptr; - }; - - mlir::Value getOperandForArg(mlir::scf::ForOp &forOp, mlir::Value &value) { - auto arg = llvm::dyn_cast(value); - if (arg && arg.getArgNumber() >= forOp.getNumInductionVars()) { - auto &iterOperand = *forOp.getTiedLoopInit(arg); - auto numCtrlOperands = forOp.getNumControlOperands(); - auto operandIdx = iterOperand.getOperandNumber(); - return forOp.getInitArgs()[operandIdx - numCtrlOperands]; - } - return mlir::Value(); - }; - - void propagate(mlir::Value start, mlir::DenseI64ArrayAttr attr) { - llvm::SmallVector queue; - if (bool(start)) - queue.push_back(start); - - while (queue.size()) { - auto value = queue.pop_back_val(); - if (!bool(value)) - continue; - - auto *op = getDefineOrParentOp(value); - - // stop when meet a function or ops, e.g., arith.truncf. - // since their source and results could have different bitwidth, - // in which case the block size cannot be propagated. - if (!op || llvm::isa(op) || - llvm::isa(op)) - continue; - - OpAttrMap[value] = attr; - - if (auto forOp = llvm::dyn_cast(op)) { - auto opr = getOperandForArg(forOp, value); - if (bool(opr)) - queue.push_back(opr); - } else if (op->getNumOperands() == 1) { - queue.push_back(op->getOperand(0)); - } - } - } -}; - std::pair encodeVectorType(mlir::ConversionPatternRewriter &rewriter, mlir::VectorType type, bool use64bitData = false, bool enforceInteger = false, bool keepF16 = false); -mlir::VectorType encodeVectorTypeTo(mlir::VectorType currentVecType, - mlir::Type toElemType); - unsigned encodeDataum(mlir::Type type); unsigned encodeOpcode(mlir::arith::AtomicRMWKind kind); -// L1 and L3 Cache Policies for Load Operation -// L1 Cache Policies: Uncached (UC), Cached (C), Cache Streaming (S), -// Invalidate-After-Read (IAR) L3 Cache Policies: Uncached (UC), Cached (C) -#define L1UC_L3UC 1 -#define L1UC_L3C 2 -#define L1C_L3UC 3 -#define L1C_L3C 4 -#define L1S_L3UC 5 -#define L1S_L3C 6 -#define L1IAR_L3C 7 - -// L1 and L3 Cache Policies for Store operation -// L1 Cache Policies: Uncached (UC), Write-Through (WT), Write-Back (WB), -// Streaming (S) L3 Cache Policies: Uncached (UC), Cached (WB) -#define L1UC_L3WB 2 -#define L1WT_L3UC 3 -#define L1WT_L3WB 4 -#define L1S_L3UC 5 -#define L1S_L3WB 6 -#define L1WB_L3WB 7 - -template unsigned encodeCacheHint(OpType op) { - auto l1hint = op.getL1Hint(); - auto l3hint = op.getL3Hint(); - - constexpr bool isStore = std::is_same_v || - std::is_same_v; - unsigned cacheHint = L1UC_L3UC; - -#define SET_CACHEVALUE(hint, cacheHintVal) \ - hint.has_value() ? hint.value() : cacheHintVal - - if constexpr (!isStore) { - - auto l1CacheValue = SET_CACHEVALUE(l1hint, CachePolicy::UNCACHED); - auto l3CacheValue = SET_CACHEVALUE(l3hint, CachePolicy::UNCACHED); - -// Setting Cache policy override based on L3 Uncached/Cached value for Load -// operation -#define SET_L1L3_CACHEREADHINT(cacheHint, l3CacheValue, uncachedVal, \ - cachedVal) \ - if (l3CacheValue == CachePolicy::UNCACHED) \ - cacheHint = uncachedVal; \ - else if (l3CacheValue == CachePolicy::CACHED) \ - cacheHint = cachedVal; - - switch (l1CacheValue) { - case CachePolicy::UNCACHED: - SET_L1L3_CACHEREADHINT(cacheHint, l3CacheValue, L1UC_L3UC, L1UC_L3C); - break; - case CachePolicy::CACHED: - SET_L1L3_CACHEREADHINT(cacheHint, l3CacheValue, L1C_L3UC, L1C_L3C); - break; - case CachePolicy::STREAMING: - SET_L1L3_CACHEREADHINT(cacheHint, l3CacheValue, L1S_L3UC, L1S_L3C); - break; - case CachePolicy::READ_INVALIDATE: - if (l3CacheValue == CachePolicy::CACHED) - cacheHint = L1IAR_L3C; - break; - default: - llvm_unreachable("Invalid Cache Policy for Read.\n"); - } - - } else { - auto l1CacheValue = SET_CACHEVALUE(l1hint, CachePolicy::UNCACHED); - auto l3CacheValue = SET_CACHEVALUE(l3hint, CachePolicy::UNCACHED); - -// Setting Cache policy override based on L3 Uncached/Write-Back value for Store -// operation -#define SET_L1L3_CACHEWRITEHINT(cacheHint, l3CacheValue, uncachedVal, \ - cachedVal) \ - if (l3CacheValue == CachePolicy::UNCACHED) \ - cacheHint = uncachedVal; \ - else if (l3CacheValue == CachePolicy::WRITE_BACK) \ - cacheHint = cachedVal; - - switch (l1CacheValue) { - case CachePolicy::UNCACHED: - SET_L1L3_CACHEWRITEHINT(cacheHint, l3CacheValue, L1UC_L3UC, L1UC_L3WB); - break; - case CachePolicy::WRITE_THROUGH: - SET_L1L3_CACHEWRITEHINT(cacheHint, l3CacheValue, L1WT_L3UC, L1WT_L3WB); - break; - case CachePolicy::STREAMING: - SET_L1L3_CACHEWRITEHINT(cacheHint, l3CacheValue, L1S_L3UC, L1S_L3WB); - break; - case CachePolicy::WRITE_BACK: - if (l3CacheValue == CachePolicy::WRITE_BACK) - cacheHint = L1WB_L3WB; - break; - default: - llvm_unreachable("Invalid Cache Policy for Write.\n"); - } - } - return cacheHint; -} -class XeTypeConverter : public mlir::TypeConverter { -public: - // friend class XeConversionPattern; - using mlir::TypeConverter::convertType; - - XeTypeConverter(mlir::MLIRContext &context) { - addConversion([&](xetile::TileType tileTy, - llvm::SmallVectorImpl &resultTypes) - -> std::optional { - return convertTileType(tileTy, resultTypes); - }); - - addConversion([&](mlir::VectorType vectorTy, - llvm::SmallVectorImpl &resultTypes) - -> std::optional { - return convertVectorType(vectorTy, resultTypes); - }); - } - - virtual std::optional - convertTileType(xetile::TileType tileTy, - llvm::SmallVectorImpl &resultTypes) { - llvm_unreachable("Pending Implementation for convertTileType."); - } - - virtual std::optional - convertVectorType(mlir::VectorType vectorTy, - llvm::SmallVectorImpl &resultTypes) { - llvm_unreachable("Pending Implementation for convertVectorType."); - } -}; - // A simple mlir::RewritePattern wrapper with methods for accessing UsageType -template class XeConversionPattern : public mlir::RewritePattern { public: using mlir::RewritePattern::RewritePattern; template - XeConversionPattern(imex::XeTypeConverter &typeConverter, AnalysisT &analysis, - Args &&...args) + XeConversionPattern(mlir::TypeConverter &typeConverter, Args &&...args) : mlir::RewritePattern(std::forward(args)...), - typeConverter(typeConverter), analysis(analysis) {} + typeConverter(typeConverter) {} virtual mlir::LogicalResult matchAndRewrite(mlir::Operation *op, @@ -527,7 +288,7 @@ class XeConversionPattern : public mlir::RewritePattern { llvm_unreachable("must override matchAndRewrite or a rewrite method"); }; - imex::XeTypeConverter &getTypeConverter() const { return typeConverter; } + mlir::TypeConverter &getTypeConverter() const { return typeConverter; } template std::enable_if_t::value, @@ -537,103 +298,9 @@ class XeConversionPattern : public mlir::RewritePattern { } protected: - imex::XeTypeConverter &typeConverter; - AnalysisT &analysis; - - template >> - mlir::DenseI64ArrayAttr getValue(mlir::Operation *op) const { - if (op) - return llvm::cast(analysis).getValue(op); - return {}; - } - - mlir::DenseI64ArrayAttr getValue(mlir::Value value) const { - return llvm::cast(analysis).getValue(value); - } - - template >> - bool isForDPASA(imex::xetile::LoadTileOp op) const { - return llvm::cast(analysis).isForDPASA(op); - } - - template >> - bool isForDPASB(imex::xetile::LoadTileOp op) const { - return llvm::cast(analysis).isForDPASB(op); - } - - template >> - bool isForDPASC(imex::xetile::LoadTileOp op) const { - return llvm::cast(analysis).isForDPASC(op); - } - - template >> - bool isForLoad(imex::xetile::InitTileOp op) const { - return llvm::cast(analysis).isForLoad(op); - } - - template >> - bool isForStore(imex::xetile::InitTileOp op) const { - return llvm::cast(analysis).isForStore(op); - } - - template >> - bool isForPrefetch(imex::xetile::InitTileOp op) const { - return llvm::cast(analysis).isForPrefetch(op); - } - - template >> - bool isForAtomicRMW(imex::xetile::InitTileOp op) const { - return llvm::cast(analysis).isForAtomicRMW(op); - } - - template >> - bool isForLoadAndPrefetch(imex::xetile::InitTileOp op) const { - return llvm::cast(analysis).isForLoadAndPrefetch(op); - } - - template >> - bool isForLoadAndStore(imex::xetile::InitTileOp op) const { - return llvm::cast(analysis).isForLoadAndStore(op); - } - - template >> - bool isForLoadAndAtomicRMW(imex::xetile::InitTileOp op) const { - return llvm::cast(analysis).isForLoadAndAtomicRMW(op); - } - - template >> - bool isForAtomicRMWAndStore(imex::xetile::InitTileOp op) const { - return llvm::cast(analysis).isForAtomicRMWAndStore(op); - } + mlir::TypeConverter &typeConverter; }; -/// Clone `shape` with the last two elements swapped. -template -llvm::SmallVector swapLastTwoElements(llvm::ArrayRef shape) { - assert(shape.size() >= 2 && "shape must be at least 2D"); - llvm::SmallVector result(shape.begin(), shape.end()); - auto size = result.size(); - std::swap(result[size - 1], result[size - 2]); - return result; -} - -/// Creates the default strides for the given `shape`. Example: -/// input shape = 2x3x4x5 -/// output strides = 60x20x5x1 -llvm::SmallVector defaultStrides(llvm::ArrayRef shape); - /// Checks if the given `type` is a 1-D vector type that requires VectorAnyINTEL /// capability. In other words, the vector size is not supported by SPIR-V. /// SPIR-V only supports 2, 3, 4, 8, 16 elements (8 and 16 with Vector16 diff --git a/lib/Dialect/XeTile/IR/XeTileDialect.cpp b/lib/Dialect/XeTile/IR/XeTileDialect.cpp index e3f46895a..53d627e7d 100644 --- a/lib/Dialect/XeTile/IR/XeTileDialect.cpp +++ b/lib/Dialect/XeTile/IR/XeTileDialect.cpp @@ -110,18 +110,14 @@ mlir::LogicalResult WorkGroupMapAttr::verify( return mlir::success(); } -mlir::LogicalResult XeTileAttr::verify( - ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, - ::imex::xetile::SubGroupMapAttr sg_map, xetile::WorkGroupMapAttr wg_map, - mlir::DenseI32ArrayAttr order, mlir::DenseI64ArrayAttr inner_blocks, - mlir::Attribute MemorySpace, mlir::BoolAttr scattered) { - +mlir::LogicalResult +XeTileAttr::verify(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, + ::imex::xetile::SubGroupMapAttr sg_map, + xetile::WorkGroupMapAttr wg_map, + mlir::DenseI32ArrayAttr order, mlir::Attribute MemorySpace, + mlir::BoolAttr scattered) { if (order != mlir::DenseI32ArrayAttr() && order.size() != 2) emitError() << "expect integer array of size 2 for order"; - if (inner_blocks != mlir::DenseI64ArrayAttr() && - (inner_blocks.size() > 0 && inner_blocks.size() != 2)) - emitError() << "expect integer array of size 2 for non empty inner_blocks " - "attribute"; return mlir::success(); } diff --git a/lib/Dialect/XeTile/IR/XeTileOps.cpp b/lib/Dialect/XeTile/IR/XeTileOps.cpp index 78afb899f..bb19bea4d 100644 --- a/lib/Dialect/XeTile/IR/XeTileOps.cpp +++ b/lib/Dialect/XeTile/IR/XeTileOps.cpp @@ -248,89 +248,6 @@ void InitTileOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, {} /* static strides */, indices); } -bool verifyInnerBlocksWithVecShape(mlir::DenseI64ArrayAttr &innerBlocks, - llvm::ArrayRef &vecShape, - llvm::ArrayRef &tileShape) { - if (!(vecShape[2] == innerBlocks[0] && vecShape[3] == innerBlocks[1] && - ((tileShape[0] / innerBlocks[0]) == vecShape[0]) && - ((tileShape[1] / innerBlocks[1]) == vecShape[1]))) - return false; - - return true; -} - -mlir::LogicalResult LoadTileOp::verify() { - auto encoding = getSource().getType().getEncoding(); - auto tileShape = getSource().getType().getShape(); - auto vecShape = getResult().getType().getShape(); - - // inner_blocks may or maynot be present in this op. - auto innerBlocks = mlir::DenseI64ArrayAttr(); - if (encoding) - innerBlocks = mlir::dyn_cast(encoding).getInnerBlocks(); - - // if inner_blocks is not present in the tile_attr, the output of the load - // must be 2D and tile shape and vector output shape must match - if (innerBlocks == mlir::DenseI64ArrayAttr()) - if (!vecShape.equals(tileShape)) - return emitOpError("Output shape must match the tile shape."); - - if (innerBlocks != mlir::DenseI64ArrayAttr() && innerBlocks.size() > 0) { - // if inner_blocks is present in the tile_attr, the output of the load - // must be 4D - if (vecShape.size() != 4) - return emitOpError( - "output must be a 4D vector if inner_blocks is used in tile_attr."); - // and, tile shape, output vector shape must be consistent with inner_blocks - if (!verifyInnerBlocksWithVecShape(innerBlocks, vecShape, tileShape)) - return emitOpError( - "shapes of the source tile, output value and inner_blocks must " - "satisfy : " - "valueShape[0] == tileShape[0]/innerBlocks[0] && valueShape[1] == " - "tileShape[1]/innerBlocks[1] && " - "valueShape[2] == innerBlocks[0] && valueShape[3] == " - "innerBlocks[1]."); - } - return mlir::success(); -} - -mlir::LogicalResult StoreTileOp::verify() { - auto encoding = getTile().getType().getEncoding(); - if (!encoding) - return mlir::success(); - - auto tileAttr = mlir::dyn_cast(encoding); - auto innerBlocks = tileAttr.getInnerBlocks(); - auto tileShape = getTile().getType().getShape(); - - // if inner_blocks is not present in the tile_attr, the stored value - // must be 2D - if (innerBlocks == mlir::DenseI32ArrayAttr() && - getValue().getType().getShape().size() != 2) - return emitOpError( - "value must be a 2D vector if inner_blocks is not used in tile_attr."); - - if (innerBlocks != mlir::DenseI32ArrayAttr() && innerBlocks.size() > 0) { - auto vecShape = getValue().getType().getShape(); - // if inner_blocks is present in the tile_attr, the stored value - // must be 4D - if (vecShape.size() != 4) - return emitOpError( - "value must be a 4D vector if inner_blocks is used in tile_attr."); - // and, tile shape, input vector shape must be consistent with inner_blocks - if (!verifyInnerBlocksWithVecShape(innerBlocks, vecShape, tileShape)) - return emitOpError( - "shapes of the destination tile, value and inner_blocks must " - "satisfy : " - "valueShape[0] == tileShape[0]/innerBlocks[0] && valueShape[1] == " - "tileShape[1]/innerBlocks[1] && " - "valueShape[2] == innerBlocks[0] && valueShape[3] == " - "innerBlocks[1]."); - } - - return mlir::success(); -} - mlir::LogicalResult TileMMAOp::verify() { int64_t aRank = getAType().getRank(); int64_t bRank = getBType().getRank(); @@ -356,27 +273,14 @@ mlir::LogicalResult TileMMAOp::verify() { outElemType)) return emitOpError("C and output vector must have the same type."); - auto check4DMmaShapes = [](llvm::ArrayRef &A, - llvm::ArrayRef &B, - llvm::ArrayRef &Out) -> bool { - return A[1] == B[0] && A[3] == B[2] && Out[0] == A[0] && Out[1] == B[1] && - Out[2] == A[2] && Out[3] == B[3]; - }; - - auto check2DMmaShapes = [](llvm::ArrayRef &A, - llvm::ArrayRef &B, - llvm::ArrayRef &Out) -> bool { + auto checkMmaShapes = [](llvm::ArrayRef &A, + llvm::ArrayRef &B, + llvm::ArrayRef &Out) -> bool { return A[1] == B[0] && Out[0] == A[0] && Out[1] == B[1]; }; - // check mma shapes for 4D case - if (aRank == 4 && !check4DMmaShapes(aShape, bShape, outShape)) - return emitOpError("incompatible A, B and output sizes for 4D tile mma op. " - "4D tile mma should have the shape (m x k x Bm x Bk) x " - "(k x n x Bk x Bn) = (m x n x Bm x Bn)."); - - // check mma shape for 2D case - if (aRank == 2 && !check2DMmaShapes(aShape, bShape, outShape)) + // check mma shape + if (aRank == 2 && !checkMmaShapes(aShape, bShape, outShape)) return emitOpError( "incompatible A, B and output sizes for 2D tile mma op. " "2D tile mma should have the shape (m x k) x (k x n) = (m x n)."); @@ -389,96 +293,6 @@ mlir::LogicalResult TileMMAOp::verify() { return mlir::success(); } -mlir::LogicalResult TilePackOp::verify() { - auto inVecShape = getInVec().getType().getShape(); - auto outVecShape = getOutVec().getType().getShape(); - auto innerBlocks = getInnerBlocks(); - auto inElemTy = getInVec().getType().getElementType(); - auto outElemTy = getOutVec().getType().getElementType(); - - // input and output vector element types must match - if (inElemTy != outElemTy) - return emitOpError("input and output vector element type mismatch."); - - // innermost 2 dimensions of the output vector must satisfy: - // outVecShape[2] == innerBlocks[0] - // outVecShape[3] == innerBlocks[1] - if (!(outVecShape[2] == innerBlocks[0] && outVecShape[3] == innerBlocks[1])) - return emitOpError( - "innermost 2 dimensions of output vector must satisfy : " - "outVecShape[2] == innerBlocks[0] && outVecShape[3] == innerBlocks[1]"); - - // outermost 2 dimensions of the output vector must satisfy: - // outVecShape[0] == inVecShape[0]/innerBlocks[0] - // outVecShape[1] == inVecShape[1]/innerBlocks[1] - if (!(outVecShape[0] == inVecShape[0] / innerBlocks[0] && - outVecShape[1] == inVecShape[1] / innerBlocks[1])) - return emitOpError( - "outermost 2 dimensions of the output vector must satisfy : " - "outVecShape[0] == inVecShape[0]/innerBlocks[0] && " - "outVecShape[1] == inVecShape[1]/innerBlocks[1]"); - - return mlir::success(); -} - -mlir::OpFoldResult TilePackOp::fold(FoldAdaptor /*adaptor*/) { - mlir::Value in = this->getInVec(); - if (auto unpack = in.getDefiningOp()) { - mlir::Value src = unpack.getInVec(); - if (src.getType() != this->getType() || - unpack.getInnerBlocks() != this->getInnerBlocks()) - return nullptr; - - return src; - } - return nullptr; -} - -mlir::LogicalResult TileUnpackOp::verify() { - auto inVecShape = getInVec().getType().getShape(); - auto outVecShape = getOutVec().getType().getShape(); - auto innerBlocks = getInnerBlocks(); - auto inElemTy = getInVec().getType().getElementType(); - auto outElemTy = getOutVec().getType().getElementType(); - - // input and output vector element types must match - if (inElemTy != outElemTy) - return emitOpError("input and output vector element type mismatch."); - - // innermost 2 dimensions of the input vector must satisfy - // outVecShape[2] == innerBlocks[0] - // outVecShape[3] == innerBlocks[1] - if (!(inVecShape[2] == innerBlocks[0] && inVecShape[3] == innerBlocks[1])) - return emitOpError( - "innermost 2 dimensions of the input vector must satisfy : " - "inVecShape[2] == innerBlocks[0] && " - "inVecShape[3] == innerBlocks[1]"); - - // output vector must satisfy : - // outVecShape[0] == inVecShape[0] * innerBlocks[0] - // outVecShape[1] == inVecShape[1] * innerBlocks[1] && - if (!(outVecShape[0] == inVecShape[0] * innerBlocks[0] && - outVecShape[1] == inVecShape[1] * innerBlocks[1])) - return emitOpError("output vector must satisfy : " - "outVecShape[0] == inVecShape[0] * innerBlocks[0] && " - "outVecShape[1] == inVecShape[1] * innerBlocks[1]"); - - return mlir::success(); -} - -mlir::OpFoldResult TileUnpackOp::fold(FoldAdaptor /*adaptor*/) { - mlir::Value in = this->getInVec(); - if (auto pack = in.getDefiningOp()) { - mlir::Value src = pack.getInVec(); - if (src.getType() != this->getType() || - pack.getInnerBlocks() != this->getInnerBlocks()) - return nullptr; - - return src; - } - return nullptr; -} - mlir::LogicalResult TransposeOp::verify() { auto srcShape = getVector().getType().getShape(); auto resShape = getResult().getType().getShape(); diff --git a/lib/Dialect/XeTile/Transforms/Canonicalization.cpp b/lib/Dialect/XeTile/Transforms/Canonicalization.cpp index 67086951b..202b977ba 100644 --- a/lib/Dialect/XeTile/Transforms/Canonicalization.cpp +++ b/lib/Dialect/XeTile/Transforms/Canonicalization.cpp @@ -404,8 +404,7 @@ struct XeTileCanonicalizationPass final auto newAttr = imex::xetile::XeTileAttr::get( tileTy.getContext(), tileTy.getSgMap(), tileTy.getWgMap(), mlir::DenseI32ArrayAttr::get(tileTy.getContext(), {1, 0}), - tileTy.getInnerBlocks(), tileTy.getMemorySpace(), - tileTy.getScatterAttr()); + tileTy.getMemorySpace(), tileTy.getScatterAttr()); return imex::xetile::TileType::get( swapLastTwoElems(tileTy.getShape()), tileTy.getElementType(), diff --git a/lib/Dialect/XeTile/Transforms/WgToSg.cpp b/lib/Dialect/XeTile/Transforms/WgToSg.cpp index 6441b1813..ce225a53a 100644 --- a/lib/Dialect/XeTile/Transforms/WgToSg.cpp +++ b/lib/Dialect/XeTile/Transforms/WgToSg.cpp @@ -733,8 +733,7 @@ class WGToSGXeTileConvertLayout auto order = mlir::DenseI32ArrayAttr::get(op.getContext(), {1, 0}); auto attr = imex::xetile::XeTileAttr::get( op.getContext(), nullptr /*sgMap*/, nullptr /*wgMap*/, - order /*order*/, nullptr /*innerblocks*/, memoryScopeAttr /*memoryscope*/, - nullptr /*scatterAttr*/); + order /*order*/, memoryScopeAttr /*memoryscope*/, nullptr /*scatterAttr*/); xetile::TileType srcTileTy = imex::xetile::TileType::get({srcMapSgData[0], srcMapSgData[1]}, elemTy, attr); @@ -912,18 +911,16 @@ void analyzeInitTileOps(mlir::Operation *op) { } void populateXeTileWgToSgPatterns(imex::XeOneToNTypeConverter &converter, - mlir::RewritePatternSet &patterns, - TileUsageAnalysis &analysis) { + mlir::RewritePatternSet &patterns) { patterns.insert(patterns.getContext(), converter, analysis); + WGToSGArithTruncFOpPattern>(patterns.getContext(), converter); patterns.insert, WGToSGElementWiseOpPattern, - WGToSGArithConstantOpPattern>(patterns.getContext(), - converter, analysis); + WGToSGArithConstantOpPattern>(patterns.getContext(), converter); } // Transforms WG XeTile IR to SG XeTile @@ -944,7 +941,6 @@ class XeTileWgToSgPass return signalPassFailure(); } - auto &analysis = getAnalysis(); mlir::Operation *op = getOperation(); // Run the analysis to find the candidates for the transformation analyzeInitTileOps(op); @@ -1049,7 +1045,7 @@ class XeTileWgToSgPass target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); - populateXeTileWgToSgPatterns(typeConverter, patterns, analysis); + populateXeTileWgToSgPatterns(typeConverter, patterns); if (mlir::failed( mlir::applyPartialConversion(mod, target, std::move(patterns)))) return signalPassFailure(); diff --git a/lib/Dialect/XeTile/Transforms/XeTileOneToNConversion.cpp b/lib/Dialect/XeTile/Transforms/XeTileOneToNConversion.cpp index e185156ce..6413fcdce 100644 --- a/lib/Dialect/XeTile/Transforms/XeTileOneToNConversion.cpp +++ b/lib/Dialect/XeTile/Transforms/XeTileOneToNConversion.cpp @@ -79,8 +79,7 @@ buildUnrealizedBackwardsCasts(mlir::ValueRange convertedValues, return recastValues; } -XeOneToNTypeConverter::XeOneToNTypeConverter(mlir::MLIRContext &context) - : XeTypeConverter(context) { +XeOneToNTypeConverter::XeOneToNTypeConverter(mlir::MLIRContext &context) { targetOp = nullptr; addConversion( @@ -106,31 +105,6 @@ XeOneToNTypeConverter::XeOneToNTypeConverter(mlir::MLIRContext &context) }); } -std::optional XeOneToNTypeConverter::convertTileType( - xetile::TileType tileTy, llvm::SmallVectorImpl &resultTypes) { - llvm::dbgs() - << "convertTileType is disabled, since there is no unique " - << "way to convert an XeTile::TileType into mlir::xegpu::TensorDescType " - << "becasue of array_length selection.\n"; - return std::nullopt; -} - -std::optional XeOneToNTypeConverter::convertVectorType( - mlir::VectorType vectorTy, llvm::SmallVectorImpl &resultTypes) { - if (vectorTy.getRank() == 4) { - auto shape = vectorTy.getShape(); - auto vecTy = - mlir::VectorType::get({shape[2], shape[3]}, vectorTy.getElementType()); - auto numElements = shape[0] * shape[1]; - resultTypes.assign(numElements, vecTy); - return mlir::success(); - } else if (vectorTy.getRank() == 2) { - resultTypes.push_back(vectorTy); - return mlir::success(); - } - return std::nullopt; -} - // It computes the mapping between types orginal values and // converted values. The standard type conversion method doesn't // work here because a TileType could have multiple decomposions diff --git a/lib/Utils/XeCommon.cpp b/lib/Utils/XeCommon.cpp index 190802a26..08e7e8695 100644 --- a/lib/Utils/XeCommon.cpp +++ b/lib/Utils/XeCommon.cpp @@ -8,8 +8,7 @@ //===----------------------------------------------------------------------===// /// /// \file -/// This file implements XeTypeConverter and some other -/// routines used by Xe related dialects. +/// This file implements some routines used by Xe related dialects. /// //===----------------------------------------------------------------------===// #include @@ -116,20 +115,6 @@ encodeVectorType(mlir::ConversionPatternRewriter &rewriter, return {resStr, resVecType}; } -/// @brief -/// We have to use i32 for intrinsic calls like llvm_genx_raw_send2_*, if we -/// want to get the original element type (e.g., f16) as the result of a load, -/// we have to encode the resulting i32 vector back to it. -mlir::VectorType encodeVectorTypeTo(mlir::VectorType currentVecType, - mlir::Type toElemType) { - auto elemType = currentVecType.getElementType(); - auto currentbitWidth = elemType.getIntOrFloatBitWidth(); - auto newBitwidth = toElemType.getIntOrFloatBitWidth(); - const int size = - currentVecType.getNumElements() * currentbitWidth / newBitwidth; - return mlir::VectorType::get(size, toElemType); -} - unsigned encodeDataum(mlir::Type type) { switch (type.getIntOrFloatBitWidth()) { case 8: @@ -193,20 +178,6 @@ unsigned encodeOpcode(mlir::arith::AtomicRMWKind kind) { return encode; } -/// Creates the default strides for the given `shape`. Example: -/// input shape = 2x3x4x5 -/// output strides = 60x20x5x1 -llvm::SmallVector defaultStrides(llvm::ArrayRef shape) { - int64_t stride = 1; - llvm::SmallVector strides; - for (int64_t size : llvm::reverse(shape)) { - strides.push_back(stride); - stride *= size; - } - std::reverse(strides.begin(), strides.end()); - return strides; -} - mlir::TypedValue stack(mlir::Value vecUp, mlir::Value vecDown, mlir::Location loc, mlir::OpBuilder &builder) { diff --git a/test/Dialect/XeTile/IR/canonicalize.mlir b/test/Dialect/XeTile/IR/canonicalize.mlir deleted file mode 100644 index 9b53d7ed9..000000000 --- a/test/Dialect/XeTile/IR/canonicalize.mlir +++ /dev/null @@ -1,20 +0,0 @@ -// RUN: imex-opt %s -canonicalize="test-convergence" --split-input-file | FileCheck %s - - -// CHECK-LABEL: func @test_pack_unpack_chain -// CHECK-SAME: (%[[SRC:.*]]: vector<32x64xf16>) -// CHECK: return %[[SRC]] : vector<32x64xf16> -func.func @test_pack_unpack_chain(%source : vector<32x64xf16>) -> vector<32x64xf16> { - %1 = xetile.tile_pack %source {inner_blocks = array} : vector<32x64xf16> -> vector<2x4x16x16xf16> - %2 = xetile.tile_unpack %1 {inner_blocks = array} : vector<2x4x16x16xf16> -> vector<32x64xf16> - return %2 : vector<32x64xf16> -} - -// CHECK-LABEL: func @test_unpack_pack_chain -// CHECK-SAME: (%[[SRC:.*]]: vector<2x4x16x16xf16>) -// CHECK: return %[[SRC]] : vector<2x4x16x16xf16> -func.func @test_unpack_pack_chain(%source : vector<2x4x16x16xf16>) -> vector<2x4x16x16xf16> { - %1 = xetile.tile_unpack %source {inner_blocks = array} : vector<2x4x16x16xf16> -> vector<32x64xf16> - %2 = xetile.tile_pack %1 {inner_blocks = array} : vector<32x64xf16> -> vector<2x4x16x16xf16> - return %2 : vector<2x4x16x16xf16> -} diff --git a/test/Dialect/XeTile/IR/invalid.mlir b/test/Dialect/XeTile/IR/invalid.mlir index f58fb45ea..1717b39d7 100644 --- a/test/Dialect/XeTile/IR/invalid.mlir +++ b/test/Dialect/XeTile/IR/invalid.mlir @@ -61,32 +61,6 @@ func.func @init_tile_static_memref_with_invalid_dynamic_shape(%source : memref<1 : memref<1024x1024xf32> -> !xetile.tile<64x64xf32> } -// ----- -func.func @load_tile_incompatible_inner_blocks(%src : !xetile.tile<64x64xf16, - #xetile.tile_attr>) { - // shapes of source tile and output value of load must be consistent with inner_blocks - // expected-error@+1 {{shapes of the source tile, output value and inner_blocks must satisfy : valueShape[0] == tileShape[0]/innerBlocks[0] && valueShape[1] == tileShape[1]/innerBlocks[1] && valueShape[2] == innerBlocks[0] && valueShape[3] == innerBlocks[1].}} - %1 = xetile.load_tile %src : !xetile.tile<64x64xf16, #xetile.tile_attr> - -> vector<8x2x8x16xf16> -} - -// ----- -func.func @store_tile_incompatible_inner_blocks(%dst : !xetile.tile<64x64xf16, - #xetile.tile_attr>, %value : vector<8x4x8x8xf16>) { - // shapes od destination tile and input value of store must be consistent with inner_blocks - // expected-error@+1 {{shapes of the destination tile, value and inner_blocks must satisfy : valueShape[0] == tileShape[0]/innerBlocks[0] && valueShape[1] == tileShape[1]/innerBlocks[1] && valueShape[2] == innerBlocks[0] && valueShape[3] == innerBlocks[1].}} - xetile.store_tile %value, %dst : vector<8x4x8x8xf16>, !xetile.tile<64x64xf16, #xetile.tile_attr> -} - -// ----- -func.func @tile_mma_incompatible_ranks(%a_vec : vector<8x8x8x8xf32>, - %b_vec : vector<8x8xf32>, %c_vec : vector<8x8x8x8xf32>) { - // the two input vectors must have the same rank - // expected-error@+1 {{A and B inputs must have the same rank.}} - %c_new = xetile.tile_mma %a_vec, %b_vec, %c_vec : vector<8x8x8x8xf32>, vector<8x8xf32>, vector<8x8x8x8xf32> - -> vector<8x8x8x8xf32> -} - // ----- func.func @tile_mma_input_elem_type_mismatch(%a_vec : vector<8x8xf32>, %b_vec : vector<8x8xf16>, %c_vec : vector<8x8xf32>) { @@ -104,16 +78,7 @@ func.func @tile_mma_output_elem_type_mismatch(%a_vec : vector<8x8xf32>, } // ----- -func.func @tile_mma_incompatible_mma_shapes_4d(%a_vec : vector<8x16x8x32xf16>, - %b_vec : vector<16x8x8x8xf16>, %c_vec : vector<8x8x8x8xf32>) { - // the two input vectors must have the same element type - // expected-error@+1 {{incompatible A, B and output sizes for 4D tile mma op. 4D tile mma should have the shape (m x k x Bm x Bk) x (k x n x Bk x Bn) = (m x n x Bm x Bn).}} - %c_new = xetile.tile_mma %a_vec, %b_vec, %c_vec - : vector<8x16x8x32xf16>, vector<16x8x8x8xf16>, vector<8x8x8x8xf32> -> vector<8x8x8x8xf32> -} - -// ----- -func.func @tile_mma_incompatible_mma_shapes_2d(%a_vec : vector<8x16xf16>, +func.func @tile_mma_incompatible_mma_shapes(%a_vec : vector<8x16xf16>, %b_vec : vector<8x8xf16>, %c_vec : vector<8x8xf32>) { // the two input vectors must have the same element type // expected-error@+1 {{incompatible A, B and output sizes for 2D tile mma op. 2D tile mma should have the shape (m x k) x (k x n) = (m x n).}} @@ -128,48 +93,6 @@ func.func @tile_mma_input_c_shape_mismatch(%a_vec : vector<8x16xf16>, %c_new = xetile.tile_mma %a_vec, %b_vec, %c_vec : vector<8x16xf16>, vector<16x8xf16>, vector<16x8xf32> -> vector<8x8xf32> } -// ----- -func.func @tile_pack_invalid_element_types(%in : vector<32x64xf16>) { - // input and output element types must match - // expected-error@+1 {{input and output vector element type mismatch.}} - %out = xetile.tile_pack %in {inner_blocks = array} : vector<32x64xf16> -> vector<4x4x8x16xf32> -} - -// ----- -func.func @tile_pack_invalid_inner_blocks(%in : vector<32x64xf16>) { - // innermost two dims of output must match inner_blocks shape - // expected-error@+1 {{innermost 2 dimensions of output vector must satisfy : outVecShape[2] == innerBlocks[0] && outVecShape[3] == innerBlocks[1]}} - %out = xetile.tile_pack %in {inner_blocks = array} : vector<32x64xf16> -> vector<4x4x8x16xf16> -} - -// ----- -func.func @tile_pack_invalid_output_shape(%in : vector<32x64xf16>) { - // outermost 2 dims of output must be consistent with input shape. - // expected-error@+1 {{outermost 2 dimensions of the output vector must satisfy : outVecShape[0] == inVecShape[0]/innerBlocks[0] && outVecShape[1] == inVecShape[1]/innerBlocks[1]}} - %out = xetile.tile_pack %in {inner_blocks = array} : vector<32x64xf16> -> vector<4x4x16x16xf16> -} - -// ----- -func.func @tile_unpack_invalid_element_types(%in : vector<4x4x8x16xf16>) { - // input and output element types must match - // expected-error@+1 {{input and output vector element type mismatch.}} - %out = xetile.tile_unpack %in {inner_blocks = array} : vector<4x4x8x16xf16> -> vector<32x64xf32> -} - -// ----- -func.func @tile_unpack_invalid_inner_blocks(%in : vector<4x4x8x16xf16>) { - // innermost two dims of input must match inner_blocks shape - // expected-error@+1 {{innermost 2 dimensions of the input vector must satisfy : inVecShape[2] == innerBlocks[0] && inVecShape[3] == innerBlocks[1]}} - %out = xetile.tile_unpack %in {inner_blocks = array} : vector<4x4x8x16xf16> -> vector<32x64xf16> -} - -// ----- -func.func @tile_unpack_invalid_output_shape(%in : vector<4x4x16x16xf16>) { - // output shape must be consistent with inputshape and inner_blocks - // expected-error@+1 {{output vector must satisfy : outVecShape[0] == inVecShape[0] * innerBlocks[0] && outVecShape[1] == inVecShape[1] * innerBlocks[1]}} - %out = xetile.tile_unpack %in {inner_blocks = array} : vector<4x4x16x16xf16> -> vector<32x64xf16> -} - // ----- func.func @test_init_tile_with_mismatch_memory_space(%a: memref<1024x1024xf16, 3>) { // expected-error@+1 {{memory space of the tile doesn't match with the source}} @@ -186,10 +109,8 @@ func.func @test_init_tile_with_mismatch_memory_space(%a: memref<1024x1024xf16, 3 #wg_map_1 = #xetile.wg_map // expected-error@+1 {{expect integer array of size 2 for sg_data}} #wg_map_2 = #xetile.wg_map -// expected-error@+1 {{expect integer array of size 2 for non empty inner_blocks attribute}} -#wg_map_3 = #xetile.tile_attr // expected-error@+1 {{expect integer array of size 2 for order}} -#wg_map_4 = #xetile.tile_attr +#wg_map_3 = #xetile.tile_attr // ----- diff --git a/test/Dialect/XeTile/IR/ops.mlir b/test/Dialect/XeTile/IR/ops.mlir index 939d8fed3..0a7e22b2d 100644 --- a/test/Dialect/XeTile/IR/ops.mlir +++ b/test/Dialect/XeTile/IR/ops.mlir @@ -7,7 +7,6 @@ #sg_map = #xetile.sg_map #wg_map = #xetile.wg_map #tile_attr = #xetile.tile_attr -#tile_attr_w_inner_blocks = #xetile.tile_attr #tile_attr_w_order = #xetile.tile_attr @@ -158,8 +157,7 @@ func.func @test_init_tile_using_addr(%src: i64, %dim0_size : index, %dim1_size : // CHECK-LABEL: func @test_load_tile({{.*}}) { func.func @test_load_tile(%src: !xetile.tile<64x32xf16>, %src1 : !xetile.tile<128x128xf16, #tile_attr>, - %src2 : !xetile.tile<64x64xf16, #tile_attr_w_inner_blocks>, - %src3 : !xetile.tile<64x32xf16, #tile_attr_w_order>) { + %src2 : !xetile.tile<64x32xf16, #tile_attr_w_order>) { // CHECK: xetile.load_tile // CHECK-SAME: : !xetile.tile<64x32xf16> -> vector<64x32xf16> %1 = xetile.load_tile %src : !xetile.tile<64x32xf16> -> vector<64x32xf16> @@ -175,16 +173,10 @@ func.func @test_load_tile(%src: !xetile.tile<64x32xf16>, %src1 : !xetile.tile<12 %6 = xetile.load_tile %src1 { padding = 0.1 : f32 } : !xetile.tile<128x128xf16, #tile_attr> -> vector<128x128xf16> - // CHECK: xetile.load_tile - // CHECK-SAME: : !xetile.tile<64x64xf16, - // CHECK-SAME: #xetile.tile_attr> -> vector<8x4x8x16xf16> - %7 = xetile.load_tile %src2 : !xetile.tile<64x64xf16, #tile_attr_w_inner_blocks> - -> vector<8x4x8x16xf16> - // CHECK: xetile.load_tile // CHECK-SAME: : !xetile.tile<64x32xf16, // CHECK-SAME: #xetile.tile_attr> -> vector<64x32xf16> - %8 = xetile.load_tile %src3 : !xetile.tile<64x32xf16, #tile_attr_w_order> + %8 = xetile.load_tile %src2 : !xetile.tile<64x32xf16, #tile_attr_w_order> -> vector<64x32xf16> return @@ -194,8 +186,7 @@ func.func @test_load_tile(%src: !xetile.tile<64x32xf16>, %src1 : !xetile.tile<12 func.func @test_store_tile(%value1 : vector<64x32xf16>, %value2 : vector<8x4x8x16xf16>, %value3 : vector<128x128xf16>, %dst: !xetile.tile<64x32xf16>, %dst1 : !xetile.tile<128x128xf16, #tile_attr>, - %dst2 : !xetile.tile<64x64xf16, #tile_attr_w_inner_blocks>, - %dst3 : !xetile.tile<64x32xf16, #tile_attr_w_order>) { + %dst2 : !xetile.tile<64x32xf16, #tile_attr_w_order>) { // CHECK: xetile.store_tile // CHECK-SAME: vector<64x32xf16>, !xetile.tile<64x32xf16> @@ -206,13 +197,9 @@ func.func @test_store_tile(%value1 : vector<64x32xf16>, // CHECK-SAME: , wg_map = >> xetile.store_tile %value3, %dst1 : vector<128x128xf16>, !xetile.tile<128x128xf16, #tile_attr> - // CHECK: xetile.store_tile - // CHECK-SAME: vector<8x4x8x16xf16>, !xetile.tile<64x64xf16, #xetile.tile_attr> - xetile.store_tile %value2, %dst2 : vector<8x4x8x16xf16>, !xetile.tile<64x64xf16, #tile_attr_w_inner_blocks> - // CHECK: xetile.store_tile // CHECK-SAME: vector<64x32xf16>, !xetile.tile<64x32xf16, #xetile.tile_attr> - xetile.store_tile %value1, %dst3 : vector<64x32xf16>, !xetile.tile<64x32xf16, #tile_attr_w_order> + xetile.store_tile %value1, %dst2 : vector<64x32xf16>, !xetile.tile<64x32xf16, #tile_attr_w_order> return } @@ -233,10 +220,7 @@ func.func @test_prefetch_tile(%src: !xetile.tile<64x64xf16>, %src1: !xetile.tile // CHECK-LABEL: func @test_tile_mma({{.*}}) { -func.func @test_tile_mma(%a: !xetile.tile<64x32xf16>, %b: !xetile.tile<32x128xf16>, %c : !xetile.tile<64x128xf16>, - %a_tiled: !xetile.tile<64x32xf16, #xetile.tile_attr>, - %b_tiled: !xetile.tile<32x128xf16, #xetile.tile_attr>, - %c_tiled: !xetile.tile<64x128xf16, #xetile.tile_attr>) { +func.func @test_tile_mma(%a: !xetile.tile<64x32xf16>, %b: !xetile.tile<32x128xf16>, %c : !xetile.tile<64x128xf16>) { // CHECK: xetile.load_tile // CHECK-SAME: : !xetile.tile<64x32xf16> -> vector<64x32xf16> @@ -260,31 +244,6 @@ func.func @test_tile_mma(%a: !xetile.tile<64x32xf16>, %b: !xetile.tile<32x128xf1 %c_new_ = xetile.tile_mma %a_vec, %b_vec, %c_vec : vector<64x32xf16>, vector<32x128xf16>, vector<64x128xf16> -> vector<64x128xf16> - // CHECK: xetile.load_tile - // CHECK-SAME: !xetile.tile<64x32xf16, #xetile.tile_attr> -> vector<8x4x8x8xf16> - %a_vec_1 = xetile.load_tile %a_tiled: !xetile.tile<64x32xf16, #xetile.tile_attr> - -> vector<8x4x8x8xf16> - - // CHECK: xetile.load_tile - // CHECK-SAME: !xetile.tile<32x128xf16, #xetile.tile_attr> -> vector<4x8x8x16xf16> - %b_vec_1 = xetile.load_tile %b_tiled: !xetile.tile<32x128xf16, #xetile.tile_attr> - -> vector<4x8x8x16xf16> - - // CHECK: xetile.load_tile - // CHECK-SAME: !xetile.tile<64x128xf16, #xetile.tile_attr> -> vector<8x8x8x16xf16> - %c_vec_1 = xetile.load_tile %c_tiled: !xetile.tile<64x128xf16, #xetile.tile_attr> - -> vector<8x8x8x16xf16> - - // CHECK: xetile.tile_mma - // CHECK-SAME: vector<8x4x8x8xf16>, vector<4x8x8x16xf16> -> vector<8x8x8x16xf16> - %c_new_1 = xetile.tile_mma %a_vec_1, %b_vec_1 - : vector<8x4x8x8xf16>, vector<4x8x8x16xf16> -> vector<8x8x8x16xf16> - - // CHECK: xetile.tile_mma - // CHECK-SAME: vector<8x4x8x8xf16>, vector<4x8x8x16xf16>, vector<8x8x8x16xf16> -> vector<8x8x8x16xf16> - %c_new_1_ = xetile.tile_mma %a_vec_1, %b_vec_1, %c_vec_1 - : vector<8x4x8x8xf16>, vector<4x8x8x16xf16>, vector<8x8x8x16xf16> -> vector<8x8x8x16xf16> - return } @@ -347,24 +306,6 @@ func.func @test_update_tile_offset_scattered(%a: memref<1024xf16>, %indices: vec return } -// CHECK-LABEL: func @test_tile_pack({{.*}}) { -func.func @test_tile_pack(%source : vector<32x64xf16>) { - // CHECK: xetile.tile_pack - // CHECK-SAME: {inner_blocks = array} - // CHECK-SAME: vector<32x64xf16> -> vector<2x4x16x16xf16> - %1 = xetile.tile_pack %source {inner_blocks = array} : vector<32x64xf16> -> vector<2x4x16x16xf16> - return -} - -// CHECK-LABEL: func @test_tile_unpack({{.*}}) { -func.func @test_tile_unpack(%source : vector<2x4x16x16xf16>) { - // CHECK: xetile.tile_unpack - // CHECK-SAME: {inner_blocks = array} - // CHECK-SAME: vector<2x4x16x16xf16> -> vector<32x64xf16> - %1 = xetile.tile_unpack %source {inner_blocks = array} : vector<2x4x16x16xf16> -> vector<32x64xf16> - return -} - // CHECK-LABEL: func @test_atomic_rmw({{.*}}) { func.func @test_atomic_rmw(%tile : !xetile.tile<8x16xf16>, %value : vector<8x16xf16>) { // CHECK: xetile.atomic_rmw addf diff --git a/test/Dialect/XeTile/Transforms/WgToSg/broadcast.mlir b/test/Dialect/XeTile/Transforms/WgToSg/broadcast.mlir index 090d135df..999c46dc6 100644 --- a/test/Dialect/XeTile/Transforms/WgToSg/broadcast.mlir +++ b/test/Dialect/XeTile/Transforms/WgToSg/broadcast.mlir @@ -12,13 +12,13 @@ gpu.module @test_broadcast { %c32 = arith.constant 32 : index %cst = arith.constant {map = #xetile.wg_map} dense<0.000000e+00> : vector<256x512xf32> %c0 = arith.constant 0 : index - %0 = xetile.init_tile %arg0[%c0, %c0] : memref<256x384xf16> -> !xetile.tile<256x32xf16, #xetile.tile_attr, inner_blocks = []>> - %1 = xetile.init_tile %arg1[%c0, %c0] : memref<1x384xf16> -> !xetile.tile<1x32xf16, #xetile.tile_attr, inner_blocks = []>> - %2:3 = scf.for %arg15 = %c0 to %c384 step %c32 iter_args(%arg16 = %cst, %arg17 = %0, %arg18 = %1) -> (vector<256x512xf32>, !xetile.tile<256x32xf16, #xetile.tile_attr, inner_blocks = []>>, !xetile.tile<1x32xf16, #xetile.tile_attr, inner_blocks = []>>) { - %4 = xetile.update_tile_offset %arg18, [%c0, %c32] : !xetile.tile<1x32xf16, #xetile.tile_attr, inner_blocks = []>> - %5 = xetile.update_tile_offset %arg17, [%c0, %c32] : !xetile.tile<256x32xf16, #xetile.tile_attr, inner_blocks = []>> - %6 = xetile.load_tile %arg17 { padding = 0.000000e+00 : f32 } : !xetile.tile<256x32xf16, #xetile.tile_attr, inner_blocks = []>> -> vector<256x32xf16> - %7 = xetile.load_tile %arg18 { padding = 0.000000e+00 : f32 } : !xetile.tile<1x32xf16, #xetile.tile_attr, inner_blocks = []>> -> vector<1x32xf16> + %0 = xetile.init_tile %arg0[%c0, %c0] : memref<256x384xf16> -> !xetile.tile<256x32xf16, #xetile.tile_attr>> + %1 = xetile.init_tile %arg1[%c0, %c0] : memref<1x384xf16> -> !xetile.tile<1x32xf16, #xetile.tile_attr>> + %2:3 = scf.for %arg15 = %c0 to %c384 step %c32 iter_args(%arg16 = %cst, %arg17 = %0, %arg18 = %1) -> (vector<256x512xf32>, !xetile.tile<256x32xf16, #xetile.tile_attr>>, !xetile.tile<1x32xf16, #xetile.tile_attr>>) { + %4 = xetile.update_tile_offset %arg18, [%c0, %c32] : !xetile.tile<1x32xf16, #xetile.tile_attr>> + %5 = xetile.update_tile_offset %arg17, [%c0, %c32] : !xetile.tile<256x32xf16, #xetile.tile_attr>> + %6 = xetile.load_tile %arg17 { padding = 0.000000e+00 : f32 } : !xetile.tile<256x32xf16, #xetile.tile_attr>> -> vector<256x32xf16> + %7 = xetile.load_tile %arg18 { padding = 0.000000e+00 : f32 } : !xetile.tile<1x32xf16, #xetile.tile_attr>> -> vector<1x32xf16> //CHECK: %[[TRANSPOSE:.*]] = vector.transpose {{%.*}}, [1, 0] : vector<1x32xf16> to vector<32x1xf16> %8 = vector.transpose %7, [1, 0] {map = #xetile.wg_map} : vector<1x32xf16> to vector<32x1xf16> //CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[TRANSPOSE]] : vector<32x1xf16> to vector<32x64xf16> @@ -27,10 +27,10 @@ gpu.module @test_broadcast { %10 = xetile.tile_mma %6, %9, %cst {wg_map_a =#xetile.wg_map, wg_map_b =#xetile.wg_map, wg_map_c =#xetile.wg_map} : vector<256x32xf16>, vector<32x512xf16>, vector<256x512xf32> -> vector<256x512xf32> xegpu.compile_hint %11 = arith.addf %arg16, %10 {map = #xetile.wg_map} : vector<256x512xf32> - scf.yield %11, %5, %4 : vector<256x512xf32>, !xetile.tile<256x32xf16, #xetile.tile_attr, inner_blocks = []>>, !xetile.tile<1x32xf16, #xetile.tile_attr, inner_blocks = []>> + scf.yield %11, %5, %4 : vector<256x512xf32>, !xetile.tile<256x32xf16, #xetile.tile_attr>>, !xetile.tile<1x32xf16, #xetile.tile_attr>> } - %3 = xetile.init_tile %arg2[%c0, %c0] : memref<256x512xf32> -> !xetile.tile<256x512xf32, #xetile.tile_attr, inner_blocks = []>> - xetile.store_tile %2#0, %3 : vector<256x512xf32>, !xetile.tile<256x512xf32, #xetile.tile_attr, inner_blocks = []>> + %3 = xetile.init_tile %arg2[%c0, %c0] : memref<256x512xf32> -> !xetile.tile<256x512xf32, #xetile.tile_attr>> + xetile.store_tile %2#0, %3 : vector<256x512xf32>, !xetile.tile<256x512xf32, #xetile.tile_attr>> gpu.terminator } gpu.return diff --git a/test/Dialect/XeTile/Transforms/WgToSg/btranspose.mlir b/test/Dialect/XeTile/Transforms/WgToSg/btranspose.mlir index 8d3d552b1..d1bedbe3c 100644 --- a/test/Dialect/XeTile/Transforms/WgToSg/btranspose.mlir +++ b/test/Dialect/XeTile/Transforms/WgToSg/btranspose.mlir @@ -51,11 +51,11 @@ gpu.module @test_gemm_btranspose{ %4 = arith.muli %block_id_x, %c2048 : index %5 = arith.muli %0, %c256 : index %6 = arith.addi %4, %5 : index - %7 = xetile.init_tile %arg2[%6, %3] : memref<16384x1536xf32> -> !xetile.tile<256x256xf32, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32>> + %7 = xetile.init_tile %arg2[%6, %3] : memref<16384x1536xf32> -> !xetile.tile<256x256xf32, #xetile.tile_attr, memory_space = 0 : i32>> %8 = arith.muli %block_id_x, %c2048 : index %9 = arith.muli %0, %c256 : index %10 = arith.addi %8, %9 : index - %11 = xetile.init_tile %arg0[%10, %c0] : memref<16384x12288xf16> -> !xetile.tile<256x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32>> + %11 = xetile.init_tile %arg0[%10, %c0] : memref<16384x12288xf16> -> !xetile.tile<256x32xf16, #xetile.tile_attr, memory_space = 0 : i32>> //CHECK: %[[R7:.*]] = index.floordivs %[[R6]], %[[c8]] //CHECK: %[[R8:.*]] = index.remu %[[R6]], %[[c8]] @@ -69,16 +69,16 @@ gpu.module @test_gemm_btranspose{ //CHECK: %[[R16:.*]] = index.add %[[R15]], %[[c0]] //CHECK: %[[INITTILE:.*]] = xetile.init_tile %[[arg1]][%[[R12]], %[[R16]]] : memref<1536x12288xf16> -> !xetile.tile<64x32xf16> - %12 = xetile.init_tile %arg1[%2, %c0] : memref<1536x12288xf16> -> !xetile.tile<256x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32>> - %13:2 = scf.for %arg15 = %c0 to %c2 step %c1_1 iter_args(%arg16 = %7, %arg17 = %11) -> (!xetile.tile<256x256xf32, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32>>, !xetile.tile<256x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32>>) { - %14 = xetile.update_tile_offset %arg17, [%c1024, %c0] : !xetile.tile<256x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32>> - %15 = xetile.update_tile_offset %arg16, [%c1024, %c0] : !xetile.tile<256x256xf32, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32>> - %16:3 = scf.for %arg18 = %c0 to %c12288 step %c32_2 iter_args(%arg19 = %cst, %arg20 = %arg17, %arg21 = %12) -> (vector<256x256xf32>, !xetile.tile<256x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32>>, !xetile.tile<256x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32>>) { - %18 = xetile.update_tile_offset %arg21, [%c0, %c32_2] : !xetile.tile<256x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32>> - %19 = xetile.update_tile_offset %arg20, [%c0, %c32_2] : !xetile.tile<256x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32>> - %20 = xetile.load_tile %arg20 {padding = 0.000000e+00 : f32} : !xetile.tile<256x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32>> -> vector<256x32xf16> + %12 = xetile.init_tile %arg1[%2, %c0] : memref<1536x12288xf16> -> !xetile.tile<256x32xf16, #xetile.tile_attr, memory_space = 0 : i32>> + %13:2 = scf.for %arg15 = %c0 to %c2 step %c1_1 iter_args(%arg16 = %7, %arg17 = %11) -> (!xetile.tile<256x256xf32, #xetile.tile_attr, memory_space = 0 : i32>>, !xetile.tile<256x32xf16, #xetile.tile_attr, memory_space = 0 : i32>>) { + %14 = xetile.update_tile_offset %arg17, [%c1024, %c0] : !xetile.tile<256x32xf16, #xetile.tile_attr, memory_space = 0 : i32>> + %15 = xetile.update_tile_offset %arg16, [%c1024, %c0] : !xetile.tile<256x256xf32, #xetile.tile_attr, memory_space = 0 : i32>> + %16:3 = scf.for %arg18 = %c0 to %c12288 step %c32_2 iter_args(%arg19 = %cst, %arg20 = %arg17, %arg21 = %12) -> (vector<256x256xf32>, !xetile.tile<256x32xf16, #xetile.tile_attr, memory_space = 0 : i32>>, !xetile.tile<256x32xf16, #xetile.tile_attr, memory_space = 0 : i32>>) { + %18 = xetile.update_tile_offset %arg21, [%c0, %c32_2] : !xetile.tile<256x32xf16, #xetile.tile_attr, memory_space = 0 : i32>> + %19 = xetile.update_tile_offset %arg20, [%c0, %c32_2] : !xetile.tile<256x32xf16, #xetile.tile_attr, memory_space = 0 : i32>> + %20 = xetile.load_tile %arg20 {padding = 0.000000e+00 : f32} : !xetile.tile<256x32xf16, #xetile.tile_attr, memory_space = 0 : i32>> -> vector<256x32xf16> %21 = math.exp %20 {map = #xetile.wg_map} : vector<256x32xf16> - %22 = xetile.load_tile %arg21 {padding = 0.000000e+00 : f32} : !xetile.tile<256x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32>> -> vector<256x32xf16> + %22 = xetile.load_tile %arg21 {padding = 0.000000e+00 : f32} : !xetile.tile<256x32xf16, #xetile.tile_attr, memory_space = 0 : i32>> -> vector<256x32xf16> //CHECK: %[[TRANSPOSE:.*]] vector.transpose {{%.*}}, [1, 0] : vector<64x32xf16> to vector<32x64xf16> %23 = vector.transpose %22, [1, 0] {map = #xetile.wg_map} : vector<256x32xf16> to vector<32x256xf16> %24 = math.exp %23 {map = #xetile.wg_map} : vector<32x256xf16> @@ -86,11 +86,11 @@ gpu.module @test_gemm_btranspose{ %25 = xetile.tile_mma %21, %24, %cst {wg_map_a =#xetile.wg_map, wg_map_b =#xetile.wg_map, wg_map_c =#xetile.wg_map} : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf32> -> vector<256x256xf32> xegpu.compile_hint %26 = arith.addf %arg19, %25 {map = #xetile.wg_map} : vector<256x256xf32> - scf.yield %26, %19, %18 : vector<256x256xf32>, !xetile.tile<256x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32>>, !xetile.tile<256x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32>> + scf.yield %26, %19, %18 : vector<256x256xf32>, !xetile.tile<256x32xf16, #xetile.tile_attr, memory_space = 0 : i32>>, !xetile.tile<256x32xf16, #xetile.tile_attr, memory_space = 0 : i32>> } %17 = math.exp %16#0 {map = #xetile.wg_map} : vector<256x256xf32> - xetile.store_tile %17, %arg16 : vector<256x256xf32>, !xetile.tile<256x256xf32, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32>> - scf.yield %15, %14 : !xetile.tile<256x256xf32, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32>>, !xetile.tile<256x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32>> + xetile.store_tile %17, %arg16 : vector<256x256xf32>, !xetile.tile<256x256xf32, #xetile.tile_attr, memory_space = 0 : i32>> + scf.yield %15, %14 : !xetile.tile<256x256xf32, #xetile.tile_attr, memory_space = 0 : i32>>, !xetile.tile<256x32xf16, #xetile.tile_attr, memory_space = 0 : i32>> } gpu.terminator } diff --git a/test/Dialect/XeTile/Transforms/WgToSg/gemm_batch.mlir b/test/Dialect/XeTile/Transforms/WgToSg/gemm_batch.mlir index 6bb1df0ec..1f66275db 100644 --- a/test/Dialect/XeTile/Transforms/WgToSg/gemm_batch.mlir +++ b/test/Dialect/XeTile/Transforms/WgToSg/gemm_batch.mlir @@ -40,22 +40,22 @@ module attributes {gpu.container_module} { %13 = arith.remsi %8, %c2 : index //CHECK: %[[INITTILE:.*]] = xetile.init_tile {{%.*}}[{{%.*}}, {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}}] : memref<4x3x2x128x96xf16> -> !xetile.tile<32x32xf16> //CHECK: %[[INITTILE:.*]] = xetile.init_tile {{%.*}}[{{%.*}}, {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}}] : memref<4x3x2x64x96xf16> -> !xetile.tile<32x32xf16> - %14 = xetile.init_tile %arg0[%10, %12, %13, %6, %c0] : memref<4x3x2x128x96xf16> -> !xetile.tile<64x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> - %15 = xetile.init_tile %arg1[%10, %12, %13, %7, %c0] : memref<4x3x2x64x96xf16> -> !xetile.tile<32x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> - %16:3 = scf.for %arg4 = %c0 to %c96 step %c32 iter_args(%arg5 = %cst, %arg6 = %14, %arg7 = %15) -> (vector<64x32xf32>, !xetile.tile<64x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>>, !xetile.tile<32x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>>) { - %18 = xetile.update_tile_offset %arg7, [%c0, %c32] : !xetile.tile<32x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> - %19 = xetile.update_tile_offset %arg6, [%c0, %c32] : !xetile.tile<64x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> - %20 = xetile.load_tile %arg6 : !xetile.tile<64x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> -> vector<64x32xf16> - %21 = xetile.load_tile %arg7 : !xetile.tile<32x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> -> vector<32x32xf16> + %14 = xetile.init_tile %arg0[%10, %12, %13, %6, %c0] : memref<4x3x2x128x96xf16> -> !xetile.tile<64x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> + %15 = xetile.init_tile %arg1[%10, %12, %13, %7, %c0] : memref<4x3x2x64x96xf16> -> !xetile.tile<32x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> + %16:3 = scf.for %arg4 = %c0 to %c96 step %c32 iter_args(%arg5 = %cst, %arg6 = %14, %arg7 = %15) -> (vector<64x32xf32>, !xetile.tile<64x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>>, !xetile.tile<32x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>>) { + %18 = xetile.update_tile_offset %arg7, [%c0, %c32] : !xetile.tile<32x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> + %19 = xetile.update_tile_offset %arg6, [%c0, %c32] : !xetile.tile<64x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> + %20 = xetile.load_tile %arg6 : !xetile.tile<64x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> -> vector<64x32xf16> + %21 = xetile.load_tile %arg7 : !xetile.tile<32x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> -> vector<32x32xf16> %22 = vector.transpose %21, [1, 0] {map = #xetile.wg_map} : vector<32x32xf16> to vector<32x32xf16> xegpu.compile_hint %23 = xetile.tile_mma %20, %22, %cst {wg_map_a = #xetile.wg_map, wg_map_b = #xetile.wg_map, wg_map_c = #xetile.wg_map} : vector<64x32xf16>, vector<32x32xf16>, vector<64x32xf32> -> vector<64x32xf32> xegpu.compile_hint %24 = arith.addf %arg5, %23 {map = #xetile.wg_map} : vector<64x32xf32> - scf.yield %24, %19, %18 : vector<64x32xf32>, !xetile.tile<64x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>>, !xetile.tile<32x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> + scf.yield %24, %19, %18 : vector<64x32xf32>, !xetile.tile<64x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>>, !xetile.tile<32x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> } - %17 = xetile.init_tile %arg2[%10, %12, %13, %6, %7] : memref<4x3x2x128x64xf32> -> !xetile.tile<64x32xf32, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> - xetile.store_tile %16#0, %17 : vector<64x32xf32>, !xetile.tile<64x32xf32, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> + %17 = xetile.init_tile %arg2[%10, %12, %13, %6, %7] : memref<4x3x2x128x64xf32> -> !xetile.tile<64x32xf32, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> + xetile.store_tile %16#0, %17 : vector<64x32xf32>, !xetile.tile<64x32xf32, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> } gpu.return } diff --git a/test/Dialect/XeTile/Transforms/WgToSg/gemm_batch_oob.mlir b/test/Dialect/XeTile/Transforms/WgToSg/gemm_batch_oob.mlir index 1d15cea31..2e574f873 100644 --- a/test/Dialect/XeTile/Transforms/WgToSg/gemm_batch_oob.mlir +++ b/test/Dialect/XeTile/Transforms/WgToSg/gemm_batch_oob.mlir @@ -44,22 +44,22 @@ module attributes {gpu.container_module} { %15 = arith.remsi %10, %c3 : index //CHECK: %[[INITTILE:.*]] = xetile.init_tile {{%.*}}[{{%.*}}, {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}}] : memref<2x3x3x128x96xf16> -> !xetile.tile<32x32xf16> //CHECK: %[[INITTILE:.*]] = xetile.init_tile {{%.*}}[{{%.*}}, {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}}] : memref<2x3x3x64x96xf16> -> !xetile.tile<32x32xf16> - %16 = xetile.init_tile %arg0[%12, %14, %15, %8, %c0] : memref<2x3x3x128x96xf16> -> !xetile.tile<64x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> - %17 = xetile.init_tile %arg1[%12, %14, %15, %9, %c0] : memref<2x3x3x64x96xf16> -> !xetile.tile<32x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> - %18:3 = scf.for %arg4 = %c0 to %c96 step %c32 iter_args(%arg5 = %cst, %arg6 = %16, %arg7 = %17) -> (vector<64x32xf32>, !xetile.tile<64x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>>, !xetile.tile<32x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>>) { - %20 = xetile.update_tile_offset %arg7, [%c0, %c32] : !xetile.tile<32x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> - %21 = xetile.update_tile_offset %arg6, [%c0, %c32] : !xetile.tile<64x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> - %22 = xetile.load_tile %arg6 : !xetile.tile<64x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> -> vector<64x32xf16> - %23 = xetile.load_tile %arg7 : !xetile.tile<32x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> -> vector<32x32xf16> + %16 = xetile.init_tile %arg0[%12, %14, %15, %8, %c0] : memref<2x3x3x128x96xf16> -> !xetile.tile<64x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> + %17 = xetile.init_tile %arg1[%12, %14, %15, %9, %c0] : memref<2x3x3x64x96xf16> -> !xetile.tile<32x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> + %18:3 = scf.for %arg4 = %c0 to %c96 step %c32 iter_args(%arg5 = %cst, %arg6 = %16, %arg7 = %17) -> (vector<64x32xf32>, !xetile.tile<64x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>>, !xetile.tile<32x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>>) { + %20 = xetile.update_tile_offset %arg7, [%c0, %c32] : !xetile.tile<32x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> + %21 = xetile.update_tile_offset %arg6, [%c0, %c32] : !xetile.tile<64x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> + %22 = xetile.load_tile %arg6 : !xetile.tile<64x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> -> vector<64x32xf16> + %23 = xetile.load_tile %arg7 : !xetile.tile<32x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> -> vector<32x32xf16> %24 = vector.transpose %23, [1, 0] {map = #xetile.wg_map} : vector<32x32xf16> to vector<32x32xf16> xegpu.compile_hint %25 = xetile.tile_mma %22, %24, %cst {wg_map_a = #xetile.wg_map, wg_map_b = #xetile.wg_map, wg_map_c = #xetile.wg_map} : vector<64x32xf16>, vector<32x32xf16>, vector<64x32xf32> -> vector<64x32xf32> xegpu.compile_hint %26 = arith.addf %arg5, %25 {map = #xetile.wg_map} : vector<64x32xf32> - scf.yield %26, %21, %20 : vector<64x32xf32>, !xetile.tile<64x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>>, !xetile.tile<32x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> + scf.yield %26, %21, %20 : vector<64x32xf32>, !xetile.tile<64x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>>, !xetile.tile<32x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> } - %19 = xetile.init_tile %arg2[%12, %14, %15, %8, %9] : memref<2x3x3x128x64xf32> -> !xetile.tile<64x32xf32, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> - xetile.store_tile %18#0, %19 : vector<64x32xf32>, !xetile.tile<64x32xf32, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> + %19 = xetile.init_tile %arg2[%12, %14, %15, %8, %9] : memref<2x3x3x128x64xf32> -> !xetile.tile<64x32xf32, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> + xetile.store_tile %18#0, %19 : vector<64x32xf32>, !xetile.tile<64x32xf32, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> } gpu.return } diff --git a/test/Dialect/XeTile/Transforms/WgToSg/gemm_postop.mlir b/test/Dialect/XeTile/Transforms/WgToSg/gemm_postop.mlir index d4721b1cb..2a211050b 100644 --- a/test/Dialect/XeTile/Transforms/WgToSg/gemm_postop.mlir +++ b/test/Dialect/XeTile/Transforms/WgToSg/gemm_postop.mlir @@ -1,11 +1,11 @@ // RUN: imex-opt --split-input-file --xetile-wg-to-sg --cse %s -verify-diagnostics | FileCheck %s #wg_map_a = #xetile.wg_map -#tile_attr_a = #xetile.tile_attr +#tile_attr_a = #xetile.tile_attr #wg_map_b = #xetile.wg_map -#tile_attr_b = #xetile.tile_attr +#tile_attr_b = #xetile.tile_attr #wg_map_c = #xetile.wg_map -#tile_attr_c = #xetile.tile_attr +#tile_attr_c = #xetile.tile_attr #map = affine_map<() -> (0)> #map1 = affine_map<() -> (12288)> diff --git a/test/Dialect/XeTile/Transforms/WgToSg/prefetch.mlir b/test/Dialect/XeTile/Transforms/WgToSg/prefetch.mlir index a909a63e5..9248a87a2 100644 --- a/test/Dialect/XeTile/Transforms/WgToSg/prefetch.mlir +++ b/test/Dialect/XeTile/Transforms/WgToSg/prefetch.mlir @@ -20,54 +20,54 @@ gpu.module @test_prefetch{ %0 = arith.muli %block_id_x, %c256 : index %1 = arith.muli %block_id_y, %c128 : index %2 = arith.addi %0, %1 : index - %3 = xetile.init_tile %arg0[%2, %c0] : memref<512x4096xf16> -> !xetile.tile<128x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> - %4 = xetile.init_tile %arg1[%c0, %c0] : memref<256x4096xf16> -> !xetile.tile<256x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> - %5:2 = scf.for %arg16 = %c0 to %c320 step %c32 iter_args(%arg17 = %3, %arg18 = %4) -> (!xetile.tile<128x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>>, !xetile.tile<256x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>>) { - %18 = xetile.update_tile_offset %arg18, [%c0, %c32] : !xetile.tile<256x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> - %19 = xetile.update_tile_offset %arg17, [%c0, %c32] : !xetile.tile<128x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> + %3 = xetile.init_tile %arg0[%2, %c0] : memref<512x4096xf16> -> !xetile.tile<128x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> + %4 = xetile.init_tile %arg1[%c0, %c0] : memref<256x4096xf16> -> !xetile.tile<256x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> + %5:2 = scf.for %arg16 = %c0 to %c320 step %c32 iter_args(%arg17 = %3, %arg18 = %4) -> (!xetile.tile<128x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>>, !xetile.tile<256x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>>) { + %18 = xetile.update_tile_offset %arg18, [%c0, %c32] : !xetile.tile<256x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> + %19 = xetile.update_tile_offset %arg17, [%c0, %c32] : !xetile.tile<128x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> //CHECK: xetile.prefetch_tile {{%.*}} {l1_hint = #xetile.cache_hint, l2_hint = #xetile.cache_hint} : !xetile.tile<4x32xf16> //CHECK: xetile.prefetch_tile {{%.*}} {l1_hint = #xetile.cache_hint, l2_hint = #xetile.cache_hint} : !xetile.tile<8x32xf16> - xetile.prefetch_tile %arg17 {l1_hint = #xetile.cache_hint, l2_hint = #xetile.cache_hint} : !xetile.tile<128x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> - xetile.prefetch_tile %arg18 {l1_hint = #xetile.cache_hint, l2_hint = #xetile.cache_hint} : !xetile.tile<256x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> - scf.yield %19, %18 : !xetile.tile<128x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>>, !xetile.tile<256x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> + xetile.prefetch_tile %arg17 {l1_hint = #xetile.cache_hint, l2_hint = #xetile.cache_hint} : !xetile.tile<128x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> + xetile.prefetch_tile %arg18 {l1_hint = #xetile.cache_hint, l2_hint = #xetile.cache_hint} : !xetile.tile<256x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> + scf.yield %19, %18 : !xetile.tile<128x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>>, !xetile.tile<256x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> } - %6 = xetile.init_tile %arg0[%2, %c320] : memref<512x4096xf16> -> !xetile.tile<128x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> - %7 = xetile.init_tile %arg1[%c0, %c320] : memref<256x4096xf16> -> !xetile.tile<256x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> - %8 = xetile.init_tile %arg0[%2, %c0] : memref<512x4096xf16> -> !xetile.tile<128x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> - %9 = xetile.init_tile %arg1[%c0, %c0] : memref<256x4096xf16> -> !xetile.tile<256x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> - %10:5 = scf.for %arg16 = %c0 to %c4096 step %c32 iter_args(%arg17 = %cst, %arg18 = %6, %arg19 = %7, %arg20 = %8, %arg21 = %9) -> (vector<128x256xf32>, !xetile.tile<128x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>>, !xetile.tile<256x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>>, !xetile.tile<128x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>>, !xetile.tile<256x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>>) { - %18 = xetile.update_tile_offset %arg21, [%c0, %c32] : !xetile.tile<256x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> - %19 = xetile.update_tile_offset %arg20, [%c0, %c32] : !xetile.tile<128x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> - %20 = xetile.update_tile_offset %arg19, [%c0, %c32] : !xetile.tile<256x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> - %21 = xetile.update_tile_offset %arg18, [%c0, %c32] : !xetile.tile<128x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> + %6 = xetile.init_tile %arg0[%2, %c320] : memref<512x4096xf16> -> !xetile.tile<128x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> + %7 = xetile.init_tile %arg1[%c0, %c320] : memref<256x4096xf16> -> !xetile.tile<256x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> + %8 = xetile.init_tile %arg0[%2, %c0] : memref<512x4096xf16> -> !xetile.tile<128x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> + %9 = xetile.init_tile %arg1[%c0, %c0] : memref<256x4096xf16> -> !xetile.tile<256x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> + %10:5 = scf.for %arg16 = %c0 to %c4096 step %c32 iter_args(%arg17 = %cst, %arg18 = %6, %arg19 = %7, %arg20 = %8, %arg21 = %9) -> (vector<128x256xf32>, !xetile.tile<128x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>>, !xetile.tile<256x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>>, !xetile.tile<128x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>>, !xetile.tile<256x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>>) { + %18 = xetile.update_tile_offset %arg21, [%c0, %c32] : !xetile.tile<256x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> + %19 = xetile.update_tile_offset %arg20, [%c0, %c32] : !xetile.tile<128x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> + %20 = xetile.update_tile_offset %arg19, [%c0, %c32] : !xetile.tile<256x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> + %21 = xetile.update_tile_offset %arg18, [%c0, %c32] : !xetile.tile<128x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> %22 = arith.addi %arg16, %c320 : index %23 = arith.cmpi sge, %22, %c4096 : index scf.if %23 { } else { //CHECK: xetile.prefetch_tile {{%.*}} {l1_hint = #xetile.cache_hint, l2_hint = #xetile.cache_hint} : !xetile.tile<4x32xf16> //CHECK: xetile.prefetch_tile {{%.*}} {l1_hint = #xetile.cache_hint, l2_hint = #xetile.cache_hint} : !xetile.tile<8x32xf16> - xetile.prefetch_tile %arg18 {l1_hint = #xetile.cache_hint, l2_hint = #xetile.cache_hint} : !xetile.tile<128x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> - xetile.prefetch_tile %arg19 {l1_hint = #xetile.cache_hint, l2_hint = #xetile.cache_hint} : !xetile.tile<256x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> + xetile.prefetch_tile %arg18 {l1_hint = #xetile.cache_hint, l2_hint = #xetile.cache_hint} : !xetile.tile<128x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> + xetile.prefetch_tile %arg19 {l1_hint = #xetile.cache_hint, l2_hint = #xetile.cache_hint} : !xetile.tile<256x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> } - %24 = xetile.load_tile %arg20 : !xetile.tile<128x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> -> vector<128x32xf16> + %24 = xetile.load_tile %arg20 : !xetile.tile<128x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> -> vector<128x32xf16> %25 = arith.addf %24, %24 {map = #xetile.wg_map} : vector<128x32xf16> - %26 = xetile.load_tile %arg21 : !xetile.tile<256x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> -> vector<256x32xf16> + %26 = xetile.load_tile %arg21 : !xetile.tile<256x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> -> vector<256x32xf16> %27 = vector.transpose %26, [1, 0] {map = #xetile.wg_map} : vector<256x32xf16> to vector<32x256xf16> %28 = math.exp %27 {map = #xetile.wg_map} : vector<32x256xf16> xegpu.compile_hint %29 = xetile.tile_mma %25, %28, %cst {wg_map_a = #xetile.wg_map, wg_map_b = #xetile.wg_map, wg_map_c = #xetile.wg_map} : vector<128x32xf16>, vector<32x256xf16>, vector<128x256xf32> -> vector<128x256xf32> xegpu.compile_hint %30 = arith.addf %arg17, %29 {map = #xetile.wg_map} : vector<128x256xf32> - scf.yield %30, %21, %20, %19, %18 : vector<128x256xf32>, !xetile.tile<128x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>>, !xetile.tile<256x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>>, !xetile.tile<128x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>>, !xetile.tile<256x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> + scf.yield %30, %21, %20, %19, %18 : vector<128x256xf32>, !xetile.tile<128x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>>, !xetile.tile<256x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>>, !xetile.tile<128x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>>, !xetile.tile<256x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> } %11 = arith.muli %block_id_x, %c256 : index %12 = arith.muli %block_id_y, %c128 : index %13 = arith.addi %11, %12 : index - %14 = xetile.init_tile %arg2[%13, %c0] : memref<512x256xf32> -> !xetile.tile<128x256xf32, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> - %15 = xetile.load_tile %14 : !xetile.tile<128x256xf32, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> -> vector<128x256xf32> + %14 = xetile.init_tile %arg2[%13, %c0] : memref<512x256xf32> -> !xetile.tile<128x256xf32, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> + %15 = xetile.load_tile %14 : !xetile.tile<128x256xf32, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> -> vector<128x256xf32> %16 = arith.addf %10#0, %15 {map = #xetile.wg_map} : vector<128x256xf32> - %17 = xetile.init_tile %arg3[%13, %c0] : memref<512x256xf32> -> !xetile.tile<128x256xf32, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> - xetile.store_tile %16, %17 : vector<128x256xf32>, !xetile.tile<128x256xf32, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> + %17 = xetile.init_tile %arg3[%13, %c0] : memref<512x256xf32> -> !xetile.tile<128x256xf32, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> + xetile.store_tile %16, %17 : vector<128x256xf32>, !xetile.tile<128x256xf32, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> gpu.terminator } gpu.return diff --git a/test/Dialect/XeTile/Transforms/WgToSg/unit_tests.mlir b/test/Dialect/XeTile/Transforms/WgToSg/unit_tests.mlir index ebb120480..34734b196 100644 --- a/test/Dialect/XeTile/Transforms/WgToSg/unit_tests.mlir +++ b/test/Dialect/XeTile/Transforms/WgToSg/unit_tests.mlir @@ -3,8 +3,8 @@ gpu.module @test_arith_extf { gpu.func @test_kernel(%arg0: memref<128x32xf16>) { %c0 = arith.constant 0 : index - %tile = xetile.init_tile %arg0[%c0, %c0] : memref<128x32xf16> -> !xetile.tile<128x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> - %load_tile = xetile.load_tile %tile : !xetile.tile<128x32xf16, #xetile.tile_attr, inner_blocks = [], memory_space = 0 : i32, scattered = false>> -> vector<128x32xf16> + %tile = xetile.init_tile %arg0[%c0, %c0] : memref<128x32xf16> -> !xetile.tile<128x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> + %load_tile = xetile.load_tile %tile : !xetile.tile<128x32xf16, #xetile.tile_attr, memory_space = 0 : i32, scattered = false>> -> vector<128x32xf16> //CHECK: arith.extf {{%.*}} : vector<32x32xf16> to vector<32x32xf32> //CHECK: arith.truncf {{%.*}} : vector<32x32xf32> to vector<32x32xf16> %extf = arith.extf %load_tile {map = #xetile.wg_map} : vector<128x32xf16> to vector<128x32xf32>