From 227a0f79ab54cb807fec49362c4973b0fb3ca7af Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Mon, 9 Dec 2024 10:04:43 -0600 Subject: [PATCH] [Blocking] Rewrite blocking pass to generate small 2D Xetile Ops (#978) --- include/imex/Dialect/XeTile/IR/XeTileTypes.td | 4 + .../XeTile/Transforms/BlockingAnalysis.h | 6 +- .../imex/Dialect/XeTile/Transforms/Passes.td | 5 +- include/imex/Utils/XeCommon.h | 24 +- .../XeTileToXeGPU/ArithOpConversion.cpp | 20 +- .../XeTileToXeGPU/XeTileOpConversion.cpp | 117 +- lib/Dialect/XeTile/Transforms/Blocking.cpp | 1778 ++++++++++++++++- .../XeTile/Transforms/BlockingAnalysis.cpp | 89 +- lib/Utils/XeCommon.cpp | 92 +- .../Transforms/Blocking/unit_tests.mlir | 5 + .../Blocking/unit_tests_transform.mlir | 1753 ++++++++++++++++ 11 files changed, 3640 insertions(+), 253 deletions(-) create mode 100644 test/Dialect/XeTile/Transforms/Blocking/unit_tests_transform.mlir diff --git a/include/imex/Dialect/XeTile/IR/XeTileTypes.td b/include/imex/Dialect/XeTile/IR/XeTileTypes.td index b0238266b..7cd2dc4df 100644 --- a/include/imex/Dialect/XeTile/IR/XeTileTypes.td +++ b/include/imex/Dialect/XeTile/IR/XeTileTypes.td @@ -89,6 +89,10 @@ def XeTile : XeTile_Type<"Tile", "tile", [ShapedTypeInterface], return llvm::cast(cloneWith(getShape(), elementType)); } + TileType clone(llvm::ArrayRef shape) { + return llvm::cast(cloneWith(shape, getElementType())); + } + xetile::SubGroupMapAttr getSgMap() { auto encoding = llvm::dyn_cast_if_present(getEncoding()); if (encoding) diff --git a/include/imex/Dialect/XeTile/Transforms/BlockingAnalysis.h b/include/imex/Dialect/XeTile/Transforms/BlockingAnalysis.h index 48bd95523..e6a4c1faf 100644 --- a/include/imex/Dialect/XeTile/Transforms/BlockingAnalysis.h +++ b/include/imex/Dialect/XeTile/Transforms/BlockingAnalysis.h @@ -41,10 +41,6 @@ class Block { llvm::raw_ostream &operator<<(llvm::raw_ostream &os, Block blk); -// A pair of operator and operand index number representing -// the use point of a value. -typedef std::pair UsePoint; - class BlockingAnalysis { public: explicit BlockingAnalysis(std::shared_ptr uArch) { @@ -54,7 +50,7 @@ class BlockingAnalysis { mlir::LogicalResult run(mlir::Operation *op); - Block getUseBlockSize(mlir::Value val, UsePoint point) const; + Block getUseBlockSize(mlir::Value val, mlir::OpOperand &point) const; Block getDefBlockSize(mlir::Value val) const; void printAnalysisResult(); diff --git a/include/imex/Dialect/XeTile/Transforms/Passes.td b/include/imex/Dialect/XeTile/Transforms/Passes.td index 83c141718..4f4883aec 100644 --- a/include/imex/Dialect/XeTile/Transforms/Passes.td +++ b/include/imex/Dialect/XeTile/Transforms/Passes.td @@ -91,7 +91,10 @@ def XeTileBlocking : Pass<"xetile-blocking", "::mlir::gpu::GPUModuleOp">{ let options = [ Option<"device", "device", "std::string", /*default=*/"\"pvc\"", - "gpu platform architecture where these ops are running"> + "gpu platform architecture where these ops are running">, + Option<"EnableTransform", "enable-2d-transform", "bool", + /*default=*/"false", + "Using 2D transform or 4D Conversion."> ]; } diff --git a/include/imex/Utils/XeCommon.h b/include/imex/Utils/XeCommon.h index 77fdfdcd2..797f6be2d 100644 --- a/include/imex/Utils/XeCommon.h +++ b/include/imex/Utils/XeCommon.h @@ -28,8 +28,18 @@ #include #include using namespace mlir::xegpu; + namespace imex { +using PackFuncTy = std::function( + mlir::Value, mlir::Value, mlir::Location, mlir::OpBuilder &)>; + +// A wrapper function to merge small vectors into a big one. 
It takes a
+// range of mlir::Value objects with mlir::VectorType, and merges them
+// into a big vector using the provided transformation function.
+mlir::Value packVectorsWith(mlir::ValueRange ins, PackFuncTy op,
+                            mlir::Location loc, mlir::OpBuilder &builder);
+
 // Combine vectors vertically while keeping the logical data layout.
 // As an example, given two vectors (2x4xf16) p and q, it will merge
 // them in to a 4x4xf16 vector.
@@ -40,7 +50,19 @@ namespace imex {
 // q5, q6, q7, q8
 mlir::TypedValue stack(mlir::Value vecUp, mlir::Value vecDown,
                        mlir::Location loc,
-                       mlir::PatternRewriter &rewriter);
+                       mlir::OpBuilder &builder);
+
+// Merge vectors horizontally while keeping the logical data layout.
+// 1 2 3 4  +  10 11 12  =  1 2 3 4 10 11 12
+// 5 6 7 8     13 14 15     5 6 7 8 13 14 15
+// Since there is no direct op for this in MLIR, we use
+// ShapeCast and Shuffle to mimic it. This comes at the
+// cost of complex shuffle masks; the mask for the example
+// above is: 0 1 2 3 8 9 10
+//           4 5 6 7 11 12 13
+mlir::TypedValue concat(mlir::Value lhs, mlir::Value rhs,
+                        mlir::Location loc,
+                        mlir::OpBuilder &builder);
 
 // It checks each GPUFuncOp in the module to see
 // whether they have arguments and outputs with
diff --git a/lib/Conversion/XeTileToXeGPU/ArithOpConversion.cpp b/lib/Conversion/XeTileToXeGPU/ArithOpConversion.cpp
index 1102a8b64..08aa74a6b 100644
--- a/lib/Conversion/XeTileToXeGPU/ArithOpConversion.cpp
+++ b/lib/Conversion/XeTileToXeGPU/ArithOpConversion.cpp
@@ -17,21 +17,6 @@
 namespace imex {
-using VectorTypedValue = mlir::TypedValue;
-using funcTy = VectorTypedValue(mlir::Value, mlir::Value, mlir::Location,
-                                mlir::PatternRewriter &);
-
-// see its description in XeTileOpConversion.cpp
-extern VectorTypedValue concat(mlir::Value v1, mlir::Value v2,
-                               mlir::Location loc,
-                               mlir::PatternRewriter &rewriter);
-
-// see its description in XeTileOpConversion.cpp
-extern mlir::Value mergeVectorsWrapper(mlir::ValueRange ins,
-                                       std::function transFunc,
-                                       mlir::Location loc,
-                                       XeOneToNPatternRewriter &rewriter);
-
 static mlir::Value createBinOp(mlir::vector::CombiningKind kind,
                                mlir::Value lhs, mlir::Value rhs,
                                mlir::Type elemTy, mlir::Location &loc,
@@ -318,8 +303,7 @@ class SgVectorMultiDimReductionOpPattern
       // TODO: need a better way to represent the result (align with
       // unpack/pack logic). currently we just shuffle them and cast it to the
       // type/shape in xetile program.
-      auto reducedVal =
-          mergeVectorsWrapper(intermediates, concat, loc, rewriter);
+      auto reducedVal = packVectorsWith(intermediates, concat, loc, rewriter);
       auto targetTy = mlir::VectorType::get({shape[1], shape[3]}, elemTy);
       auto newOp = rewriter.create(loc, targetTy, reducedVal);
@@ -338,7 +322,7 @@ class SgVectorMultiDimReductionOpPattern
       // currently we just shuffle them and cast it to the type/shape in
       // xetile program.
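       // For illustration, assuming `intermediates` holds four small row
       // vectors, packVectorsWith merges them pairwise with the provided
       // callback until a single value remains, roughly:
       //   auto v01 = concat(intermediates[0], intermediates[1], loc, rewriter);
       //   auto v23 = concat(intermediates[2], intermediates[3], loc, rewriter);
       //   auto all = concat(v01, v23, loc, rewriter);
       // where each concat is built from a vector.shape_cast and a
       // vector.shuffle with a row-interleaving mask.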
auto reductionVal = - mergeVectorsWrapper(intermediates, concat, loc, rewriter); + packVectorsWith(intermediates, concat, loc, rewriter); auto targetTy = mlir::VectorType::get({shape[0], shape[2]}, elemTy); auto newOp = rewriter.create(loc, targetTy, reductionVal); diff --git a/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp b/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp index 00c7fc5b9..80d6b1fcd 100644 --- a/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp +++ b/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp @@ -38,106 +38,6 @@ using mlir::vector::ShapeCastOp; using mlir::vector::ShuffleOp; using mlir::vector::SplatOp; -using VectorTypedValue = mlir::TypedValue; -using funcTy = VectorTypedValue(mlir::Value, mlir::Value, mlir::Location, - mlir::PatternRewriter &); - -// generate linearized shuffle mask for concat. -static llvm::SmallVector -getShuffleMask(llvm::ArrayRef shape1, llvm::ArrayRef shape2) { - assert(shape1.size() == shape2.size() && shape1.size() <= 2 && - "only 1D/2D shape are supported."); - assert(shape1.drop_back() == shape2.drop_back() && - "the row dim of the shapes should match."); - int64_t size1 = std::accumulate(shape1.begin(), shape1.end(), 1, - std::multiplies()); - int64_t size2 = std::accumulate(shape2.begin(), shape2.end(), 1, - std::multiplies()); - llvm::SmallVector mask(size1 + size2); - auto rows = shape1.size() == 1 ? 1 : shape1[0]; - auto cols1 = shape1.size() == 1 ? shape1[0] : shape1[1]; - auto cols2 = shape2.size() == 1 ? shape2[0] : shape2[1]; - for (int64_t i = 0; i < rows; i++) { - int64_t s = i * (cols1 + cols2); - int64_t m = s + cols1; - int64_t e = m + cols2; - int64_t v1 = i * cols1; - int64_t v2 = size1 + i * cols2; - std::iota(mask.begin() + s, mask.begin() + m, v1); - std::iota(mask.begin() + m, mask.begin() + e, v2); - } - return mask; -} - -// merge vectors horizontally while keep the logical data layout. -// 1 2 3 4 + 10 11 12 = 1 2 3 4 10 11 12 -// 5 6 7 8 13 14 15 5 6 7 8 13 14 15 -// since there is no direct op in mlir exists, we will -// using ShapeCast and Shuffle to mimic it. It comes with -// cost of complex shuffle masks. the mask for the above one -// will be like this: 0 1 2 3 8 9 10 -// 4 5 6 7 11 12 13 -VectorTypedValue concat(mlir::Value vecLeft, mlir::Value vecRight, - mlir::Location loc, mlir::PatternRewriter &rewriter) { - auto vecLeftTy = llvm::cast(vecLeft.getType()); - auto vecRightTy = llvm::cast(vecRight.getType()); - - assert(vecLeftTy.getShape()[0] == vecLeftTy.getShape()[0] && - "Operands of concat() do not have the same number of rows."); - assert(vecLeftTy.getRank() <= 2 && - vecRightTy.getRank() == vecLeftTy.getRank() && - "Currently concat only works on 1D/2D vector."); - - auto elemTy = vecLeftTy.getElementType(); - auto leftSize = vecLeftTy.getNumElements(); - auto leftShape = vecLeftTy.getShape(); - auto leftFlatTy = mlir::VectorType::get({vecLeftTy.getNumElements()}, elemTy); - - auto rightSize = vecRightTy.getNumElements(); - auto rightShape = vecRightTy.getShape(); - auto rightFlatTy = - mlir::VectorType::get({vecRightTy.getNumElements()}, elemTy); - - auto newShape = vecLeftTy.getRank() == 1 - ? 
llvm::SmallVector({leftSize + rightSize}) - : llvm::SmallVector( - {leftShape[0], leftShape[1] + rightShape[1]}); - auto castLeft = rewriter.create(loc, leftFlatTy, vecLeft); - auto castRight = rewriter.create(loc, rightFlatTy, vecRight); - auto mask = getShuffleMask(leftShape, rightShape); - auto shuffleOp = rewriter.create(loc, castLeft, castRight, mask); - auto targetTy = mlir::VectorType::get(newShape, elemTy); - auto newOp = rewriter.create(loc, targetTy, shuffleOp); - return newOp; -} - -// A wrapper function to merge small vectors into a big one. It takes a -// range of mlir::Value objects with mlir::VectorType, and merge them -// into a big vector using the provided transformation function. -mlir::Value mergeVectorsWrapper(mlir::ValueRange ins, - std::function transFunc, - mlir::Location loc, - XeOneToNPatternRewriter &rewriter) { - llvm::SmallVector shuffleOps(ins.begin(), ins.end()); - while (shuffleOps.size() > 1) { - auto curr = shuffleOps; - shuffleOps.clear(); - size_t currPairStartIdx{0}; - while (currPairStartIdx < curr.size() - 1) { - size_t leftIdx{currPairStartIdx++}; - size_t rightIdx{currPairStartIdx++}; - auto newOp = transFunc(curr[leftIdx], curr[rightIdx], loc, rewriter); - shuffleOps.push_back(newOp); - } - if (currPairStartIdx < curr.size()) { - assert(currPairStartIdx == curr.size() - 1); - shuffleOps.push_back(curr[curr.size() - 1]); - } - } - - return shuffleOps[0]; -} - // Check that lowerUnpackOrPack will be able to evenly combine/split the input // grid into the output grid. static bool isUnpackPackCompatible(xetile::TileUnpackOp unpackOp, @@ -164,7 +64,7 @@ static bool isUnpackPackCompatible(xetile::TileUnpackOp unpackOp, // a unified function lowering Unpack and Pack ops. static llvm::SmallVector -lowerUnpackOrPack(XeOneToNPatternRewriter &rewriter, mlir::Operation *op, +lowerUnpackOrPack(mlir::PatternRewriter &rewriter, mlir::Location loc, mlir::ValueRange inputs, mlir::DenseI64ArrayAttr inBlkSizes, mlir::DenseI64ArrayAttr outBlkSizes, llvm::ArrayRef inGrids, @@ -183,8 +83,7 @@ lowerUnpackOrPack(XeOneToNPatternRewriter &rewriter, mlir::Operation *op, auto idx = i * inGrids[1] + j; valSet.push_back(inputs[idx]); if (valSet.size() == static_cast(nums)) { - auto newOp = - mergeVectorsWrapper(valSet, stack, op->getLoc(), rewriter); + auto newOp = packVectorsWith(valSet, stack, loc, rewriter); intermediates[i / nums * inGrids[1] + j] = newOp; valSet.clear(); } @@ -205,7 +104,7 @@ lowerUnpackOrPack(XeOneToNPatternRewriter &rewriter, mlir::Operation *op, for (auto k = 0; k < nums; k++) { llvm::SmallVector offsets({k * blkSizes[0], 0}); auto newOp = rewriter.create( - op->getLoc(), v, offsets, blkSizes, strides); + loc, v, offsets, blkSizes, strides); auto idx = startPos + k * inGrids[1]; intermediates[idx] = newOp; } @@ -228,8 +127,7 @@ lowerUnpackOrPack(XeOneToNPatternRewriter &rewriter, mlir::Operation *op, for (auto j = 0; j < interGrids[1]; j++) { valSet.push_back(intermediates[i * interGrids[1] + j]); if (valSet.size() == nums) { - auto newOp = - mergeVectorsWrapper(valSet, concat, op->getLoc(), rewriter); + auto newOp = packVectorsWith(valSet, concat, loc, rewriter); newOps.push_back(newOp); valSet.clear(); } @@ -245,7 +143,7 @@ lowerUnpackOrPack(XeOneToNPatternRewriter &rewriter, mlir::Operation *op, for (int64_t k = 0; k < nums; k++) { llvm::SmallVector offsets({0, k * blkSizes[1]}); auto newOp = rewriter.create( - op->getLoc(), v, offsets, blkSizes, strides); + loc, v, offsets, blkSizes, strides); newOps.push_back(newOp); } } @@ -291,7 +189,7 @@ class 
SgTileUnpackOpPattern : public XeOneToNConversion {
     }
     rewriter.setInsertionPoint(op);
-    auto newOps = lowerUnpackOrPack(rewriter, op, inputs, inBlkSizes,
+    auto newOps = lowerUnpackOrPack(rewriter, op->getLoc(), inputs, inBlkSizes,
                                     outBlkSizes, inGrids, outGrids);
 
     if (op->hasOneUse() && packOp && isUnpackPackCompatible(op, packOp)) {
@@ -327,7 +225,8 @@ class SgTilePackOpPattern : public XeOneToNConversion {
     auto outGrids = outTy.getShape().take_front(2);
     auto outBlkSizes = op.getInnerBlocksAttr();
 
-    auto newOps = lowerUnpackOrPack(rewriter, op, {input}, inBlkSizes,
+    rewriter.setInsertionPoint(op);
+    auto newOps = lowerUnpackOrPack(rewriter, op->getLoc(), {input}, inBlkSizes,
                                     outBlkSizes, inGrids, outGrids);
 
     // it is simple one-to-one mapping
diff --git a/lib/Dialect/XeTile/Transforms/Blocking.cpp b/lib/Dialect/XeTile/Transforms/Blocking.cpp
index 3cc88910b..88c9ef616 100644
--- a/lib/Dialect/XeTile/Transforms/Blocking.cpp
+++ b/lib/Dialect/XeTile/Transforms/Blocking.cpp
@@ -31,9 +31,12 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
+#include
+#include
 #include
 #include
@@ -58,9 +61,1506 @@ namespace imex {
 #include "imex/Dialect/XeTile/Transforms/Passes.h.inc"
 } // namespace imex
 
+// TODO: Remove this flag after consolidation.
+bool Enable2DBlockingTransform = false;
+
 namespace imex {
 
+// Blocking decomposes ops working on big tile or vector sizes into
+// a set of ops working on smaller tile or vector sizes that can be
+// mapped to hardware instructions. The original implementation uses a
+// 4D tile/vector type to represent the blocking result, with the
+// outer 2 dimensions corresponding to the grid size, i.e., the number
+// of instructions and their organization, and the inner 2 dimensions
+// corresponding to the block size that can be handled by a single
+// instruction. The new implementation drops this 4D representation
+// and generates a set of xetile or vector ops working on the block
+// size directly.
 namespace Blocking {
 
+template
+class RewriteXeTileOp : public mlir::OpRewritePattern {
+public:
+  using OpPatternRewriter = typename mlir::PatternRewriter;
+
+  RewriteXeTileOp(mlir::MLIRContext *context, AnalysisT &analysis)
+      : mlir::OpRewritePattern(context), analysis(analysis) {}
+
+protected:
+  AnalysisT &analysis;
+};
+
+template