diff --git a/include/imex/Dialect/XeTile/Transforms/Passes.td b/include/imex/Dialect/XeTile/Transforms/Passes.td index 4f4883aec..83c141718 100644 --- a/include/imex/Dialect/XeTile/Transforms/Passes.td +++ b/include/imex/Dialect/XeTile/Transforms/Passes.td @@ -91,10 +91,7 @@ def XeTileBlocking : Pass<"xetile-blocking", "::mlir::gpu::GPUModuleOp">{ let options = [ Option<"device", "device", "std::string", /*default=*/"\"pvc\"", - "gpu platform architecture where these ops are running">, - Option<"EnableTransform", "enable-2d-transform", "bool", - /*default=*/"false", - "Using 2D transform or 4D Conversion."> + "gpu platform architecture where these ops are running"> ]; } diff --git a/lib/Dialect/XeTile/Transforms/Blocking.cpp b/lib/Dialect/XeTile/Transforms/Blocking.cpp index 2d36cd00a..337f18f22 100644 --- a/lib/Dialect/XeTile/Transforms/Blocking.cpp +++ b/lib/Dialect/XeTile/Transforms/Blocking.cpp @@ -60,9 +60,6 @@ namespace imex { #include "imex/Dialect/XeTile/Transforms/Passes.h.inc" } // namespace imex -// TODO: Remove it after consolidation -bool Enable2DBlockingTransform = false; - namespace imex { // Blocking is to decompose ops working on big tile or vector size // into a set of ops working on smaller tile or vector size, that @@ -106,11 +103,10 @@ static const char *const packAttrName = "__xetile_blocking_pack__"; static const char *const unpackAttrName = "__xetile_blocking_unpack__"; static const char *const blockAttrName = "__xetile_blocking_inner_block__"; -static mlir::Value -unpackWithUnrealizedCastOp(mlir::ValueRange srcs, mlir::Type destTy, - llvm::ArrayRef innerBlock, - mlir::Location loc, - mlir::PatternRewriter &rewriter) { +static mlir::Value addUnpackOp(mlir::ValueRange srcs, mlir::Type destTy, + llvm::ArrayRef innerBlock, + mlir::Location loc, + mlir::PatternRewriter &rewriter) { auto attr = mlir::NamedAttribute(rewriter.getStringAttr(unpackAttrName), rewriter.getUnitAttr()); auto innerBlkAttr = @@ -122,10 +118,10 @@ unpackWithUnrealizedCastOp(mlir::ValueRange srcs, mlir::Type destTy, return castOp.getResult(0); } -static mlir::ValueRange -packWithUnrealizedCastOp(mlir::Value src, mlir::TypeRange destTypes, - llvm::ArrayRef innerBlock, mlir::Location loc, - mlir::PatternRewriter &rewriter) { +static mlir::ValueRange addPackOp(mlir::Value src, mlir::TypeRange destTypes, + llvm::ArrayRef innerBlock, + mlir::Location loc, + mlir::PatternRewriter &rewriter) { auto attr = mlir::NamedAttribute(rewriter.getStringAttr(packAttrName), rewriter.getUnitAttr()); auto innerBlkAttr = @@ -580,8 +576,8 @@ class RewriteArithConstantOp newOps.push_back(newOp); } } - auto castOp = unpackWithUnrealizedCastOp( - newOps, value.getType(), blockSize.asArrayRef(), loc, rewriter); + auto castOp = addUnpackOp(newOps, value.getType(), blockSize.asArrayRef(), + loc, rewriter); rewriter.replaceOp(op, castOp); return mlir::success(); } @@ -620,8 +616,8 @@ class RewriteInitTileOp auto convertedTileTypes = convertTypes(tileTy, blockSize.asArrayRef()); auto newIndicesTypes = convertTypes(indicesTy, blockSize.asArrayRef()); - auto subIndices = packWithUnrealizedCastOp( - indices, newIndicesTypes, blockSize.asArrayRef(), loc, rewriter); + auto subIndices = addPackOp(indices, newIndicesTypes, + blockSize.asArrayRef(), loc, rewriter); for (auto [t, i] : llvm::zip(convertedTileTypes, subIndices)) { llvm::SmallVector operands({op.getSource(), i}); @@ -676,8 +672,8 @@ class RewriteInitTileOp } } } - auto castOp = unpackWithUnrealizedCastOp( - newOps, tileTy, blockSize.asArrayRef(), loc, rewriter); + auto castOp = + addUnpackOp(newOps, tileTy, blockSize.asArrayRef(), loc, rewriter); rewriter.replaceOp(op, castOp); return mlir::success(); } @@ -703,8 +699,8 @@ class RewritePrefetchTileOp if (!blockSize || shape == blockSize.asArrayRef()) return failure(); auto convertedTileTypes = convertTypes(tileTy, blockSize.asArrayRef()); - auto convertedTiles = packWithUnrealizedCastOp( - tile, convertedTileTypes, blockSize.asArrayRef(), loc, rewriter); + auto convertedTiles = addPackOp(tile, convertedTileTypes, + blockSize.asArrayRef(), loc, rewriter); for (auto [t, ty] : llvm::zip_equal(convertedTiles, convertedTileTypes)) { rewriter.create(loc, ty, t, op->getAttrs()); @@ -735,8 +731,8 @@ class RewriteLoadTileOp return failure(); auto convertedTileTypes = convertTypes(tileTy, blockSize.asArrayRef()); - auto convertedTiles = packWithUnrealizedCastOp( - tile, convertedTileTypes, blockSize.asArrayRef(), loc, rewriter); + auto convertedTiles = addPackOp(tile, convertedTileTypes, + blockSize.asArrayRef(), loc, rewriter); auto vecTy = ::mlir::VectorType::get(blockSize.asArrayRef(), tileTy.getElementType()); @@ -748,8 +744,8 @@ class RewriteLoadTileOp newOps.push_back(newOp); } - auto castOp = unpackWithUnrealizedCastOp( - newOps, op.getType(), blockSize.asArrayRef(), loc, rewriter); + auto castOp = addUnpackOp(newOps, op.getType(), blockSize.asArrayRef(), loc, + rewriter); rewriter.replaceOp(op, castOp); return mlir::success(); @@ -779,10 +775,10 @@ class RewriteStoreTileOp auto convertedValTypes = convertTypes(valTy, blockSize.asArrayRef()); auto convertedTileTypes = convertTypes(tileTy, blockSize.asArrayRef()); - auto convertedValues = packWithUnrealizedCastOp( - value, convertedValTypes, blockSize.asArrayRef(), loc, rewriter); - auto convertedTiles = packWithUnrealizedCastOp( - tile, convertedTileTypes, blockSize.asArrayRef(), loc, rewriter); + auto convertedValues = addPackOp(value, convertedValTypes, + blockSize.asArrayRef(), loc, rewriter); + auto convertedTiles = addPackOp(tile, convertedTileTypes, + blockSize.asArrayRef(), loc, rewriter); for (auto [v, t] : llvm::zip(convertedValues, convertedTiles)) { rewriter.create(loc, v, t, op.getL1HintAttr(), @@ -819,10 +815,10 @@ class RewriteLoadGatherOp auto convertedMaskTypes = convertTypes(mask.getType(), blockSize.asArrayRef()); - auto tiles = packWithUnrealizedCastOp( - tile, convertedTileTypes, blockSize.asArrayRef(), loc, rewriter); - auto masks = packWithUnrealizedCastOp( - mask, convertedMaskTypes, blockSize.asArrayRef(), loc, rewriter); + auto tiles = addPackOp(tile, convertedTileTypes, blockSize.asArrayRef(), + loc, rewriter); + auto masks = addPackOp(mask, convertedMaskTypes, blockSize.asArrayRef(), + loc, rewriter); auto newValueTy = mlir::VectorType::get(blockSize.asArrayRef(), elemTy); llvm::SmallVector newOps; for (auto [t, m] : llvm::zip(tiles, masks)) { @@ -832,8 +828,8 @@ class RewriteLoadGatherOp newOps.push_back(newOp); } - auto castOp = unpackWithUnrealizedCastOp( - newOps, op.getType(), blockSize.asArrayRef(), loc, rewriter); + auto castOp = addUnpackOp(newOps, op.getType(), blockSize.asArrayRef(), loc, + rewriter); rewriter.replaceOp(op, castOp); return mlir::success(); } @@ -869,12 +865,12 @@ class RewriteStoreScatterOp auto convertedMaskTypes = convertTypes(mask.getType(), blockSize.asArrayRef()); - auto values = packWithUnrealizedCastOp( - value, convertedValTypes, blockSize.asArrayRef(), loc, rewriter); - auto tiles = packWithUnrealizedCastOp( - tile, convertedTileTypes, blockSize.asArrayRef(), loc, rewriter); - auto masks = packWithUnrealizedCastOp( - mask, convertedMaskTypes, blockSize.asArrayRef(), loc, rewriter); + auto values = addPackOp(value, convertedValTypes, blockSize.asArrayRef(), + loc, rewriter); + auto tiles = addPackOp(tile, convertedTileTypes, blockSize.asArrayRef(), + loc, rewriter); + auto masks = addPackOp(mask, convertedMaskTypes, blockSize.asArrayRef(), + loc, rewriter); for (auto [v, t, m] : llvm::zip(values, tiles, masks)) { (void)rewriter.create( @@ -909,8 +905,8 @@ class RewriteUpdateTileOffsetOp return mlir::failure(); auto convertedTileTypes = convertTypes(tileTy, blockSize.asArrayRef()); - auto convertedTiles = packWithUnrealizedCastOp( - tile, convertedTileTypes, blockSize.asArrayRef(), loc, rewriter); + auto convertedTiles = addPackOp(tile, convertedTileTypes, + blockSize.asArrayRef(), loc, rewriter); llvm::SmallVector newOps; @@ -922,9 +918,8 @@ class RewriteUpdateTileOffsetOp auto convertedIndicesTypes = convertTypes(indicesTy, blockSize.asArrayRef()); - auto convertedIndices = - packWithUnrealizedCastOp(indices, convertedIndicesTypes, - blockSize.asArrayRef(), loc, rewriter); + auto convertedIndices = addPackOp(indices, convertedIndicesTypes, + blockSize.asArrayRef(), loc, rewriter); for (auto [t, i] : llvm::zip(convertedTiles, convertedIndices)) { auto newOp = rewriter.create( @@ -939,8 +934,8 @@ class RewriteUpdateTileOffsetOp } } - auto castOp = unpackWithUnrealizedCastOp( - newOps, op.getType(), blockSize.asArrayRef(), loc, rewriter); + auto castOp = addUnpackOp(newOps, op.getType(), blockSize.asArrayRef(), loc, + rewriter); rewriter.replaceOp(op, castOp); return mlir::success(); } @@ -985,8 +980,7 @@ class RewriteTileMMAOp if (type.getShape() == blockSize) return llvm::SmallVector({val}); auto convertedTypes = convertTypes(type, blockSize); - auto values = packWithUnrealizedCastOp(val, convertedTypes, blockSize, - loc, rewriter); + auto values = addPackOp(val, convertedTypes, blockSize, loc, rewriter); return llvm::to_vector(values); }; @@ -1031,8 +1025,8 @@ class RewriteTileMMAOp newOps.push_back(tmpC); } } - auto castOp = unpackWithUnrealizedCastOp( - newOps, resultTy, Block().asArrayRef(), loc, rewriter); + auto castOp = + addUnpackOp(newOps, resultTy, Block().asArrayRef(), loc, rewriter); rewriter.replaceOp(op, castOp); return mlir::success(); } @@ -1076,8 +1070,8 @@ class RewriteTileReductionOp return rewriter.notifyMatchFailure(op, "Invalid blocking size"); auto convertedSrcTypes = convertTypes(srcTy, blkSize.asArrayRef()); - auto convertedSrcs = packWithUnrealizedCastOp( - src, convertedSrcTypes, blkSize.asArrayRef(), loc, rewriter); + auto convertedSrcs = + addPackOp(src, convertedSrcTypes, blkSize.asArrayRef(), loc, rewriter); int64_t grid[2] = {shape[0] / blkSize[0], shape[1] / blkSize[1]}; @@ -1108,8 +1102,8 @@ class RewriteTileReductionOp } blkSize[dims[0]] = 1; - auto castOp = unpackWithUnrealizedCastOp( - newOps, op.getType(), blkSize.asArrayRef(), loc, rewriter); + auto castOp = + addUnpackOp(newOps, op.getType(), blkSize.asArrayRef(), loc, rewriter); rewriter.replaceOp(op, castOp); return mlir::success(); @@ -1144,8 +1138,8 @@ class RewriteTileBroadcastOp return rewriter.notifyMatchFailure(op, "No need to block"); auto convertedSrcTypes = convertTypes(srcTy, srcBlkSize.asArrayRef()); - auto convertedSrcs = packWithUnrealizedCastOp( - src, convertedSrcTypes, srcBlkSize.asArrayRef(), loc, rewriter); + auto convertedSrcs = addPackOp(src, convertedSrcTypes, + srcBlkSize.asArrayRef(), loc, rewriter); auto resTy = op.getResult().getType(); int64_t resultGrid[2] = {resTy.getShape()[0] / resBlkSize[0], @@ -1202,8 +1196,8 @@ class RewriteTileBroadcastOp } else { return mlir::failure(); } - auto castOp = unpackWithUnrealizedCastOp( - newOps, resTy, resBlkSize.asArrayRef(), loc, rewriter); + auto castOp = + addUnpackOp(newOps, resTy, resBlkSize.asArrayRef(), loc, rewriter); rewriter.replaceOp(op, castOp); return mlir::success(); } @@ -1243,8 +1237,8 @@ class RewriteTileTransposeOp auto convertedResultTypes = convertTypes(resultTy, outBlockSize.asArrayRef()); - auto convertedInputs = packWithUnrealizedCastOp( - input, convertedInputTypes, inBlockSize.asArrayRef(), loc, rewriter); + auto convertedInputs = addPackOp(input, convertedInputTypes, + inBlockSize.asArrayRef(), loc, rewriter); int64_t grids[2] = {resultShape[0] / outBlockSize[0], resultShape[1] / outBlockSize[1]}; @@ -1258,8 +1252,8 @@ class RewriteTileTransposeOp newOps.push_back(res); } } - auto castOp = unpackWithUnrealizedCastOp( - newOps, resultTy, outBlockSize.asArrayRef(), loc, rewriter); + auto castOp = + addUnpackOp(newOps, resultTy, outBlockSize.asArrayRef(), loc, rewriter); rewriter.replaceOp(op, castOp); return mlir::success(); } @@ -1301,8 +1295,8 @@ class RewriteVectorizableOp if (!oprTy || oprTy.getRank() != 2) newOperands.emplace_back(opr); auto convertedTypes = convertTypes(oprTy, blockSize.asArrayRef()); - auto convertedValues = packWithUnrealizedCastOp( - opr, convertedTypes, blockSize.asArrayRef(), loc, rewriter); + auto convertedValues = + addPackOp(opr, convertedTypes, blockSize.asArrayRef(), loc, rewriter); newOperands.push_back(convertedValues); } @@ -1330,8 +1324,8 @@ class RewriteVectorizableOp } } - auto castOp = unpackWithUnrealizedCastOp( - newOps, resType, blockSize.asArrayRef(), loc, rewriter); + auto castOp = + addUnpackOp(newOps, resType, blockSize.asArrayRef(), loc, rewriter); rewriter.replaceOp(op, castOp); return mlir::success(); @@ -1378,8 +1372,8 @@ class RewriteSCFForOp } else { auto newTypes = convertTypes(type, blockSZ.asArrayRef()); argConversion.addInputs(i, newTypes); - auto values = packWithUnrealizedCastOp( - v, newTypes, blockSZ.asArrayRef(), loc, rewriter); + auto values = + addPackOp(v, newTypes, blockSZ.asArrayRef(), loc, rewriter); convertedInitArgs.append(values.begin(), values.end()); } } @@ -1410,7 +1404,7 @@ class RewriteSCFForOp if (!inputMap || inputMap->size == 1) { castArgs.push_back(convertedArgs[inputMap->inputNo]); } else { - auto arg = unpackWithUnrealizedCastOp( + auto arg = addUnpackOp( convertedArgs.slice(inputMap->inputNo, inputMap->size), regionArgs[i].getType(), blockSZs[i].asArrayRef(), loc, rewriter); castArgs.push_back(arg); @@ -1426,7 +1420,7 @@ class RewriteSCFForOp if (!inputMap || inputMap->size == 1) { castResults.push_back(convertedResults[inputMap->inputNo]); } else { - auto res = unpackWithUnrealizedCastOp( + auto res = addUnpackOp( convertedResults.slice(inputMap->inputNo, inputMap->size), results[i].getType(), blockSZs[i].asArrayRef(), loc, rewriter); castResults.push_back(res); @@ -1455,8 +1449,8 @@ class RewriteSCFYieldOp auto type = mlir::dyn_cast(res.getType()); if (blockSZ && type && type.getShape() != blockSZ.asArrayRef()) { auto newTypes = convertTypes(type, blockSZ.asArrayRef()); - auto values = packWithUnrealizedCastOp( - res, newTypes, blockSZ.asArrayRef(), loc, rewriter); + auto values = + addPackOp(res, newTypes, blockSZ.asArrayRef(), loc, rewriter); convertedResults.append(values.begin(), values.end()); } else { convertedResults.push_back(res); @@ -1517,8 +1511,8 @@ class RewriteCreateMaskOp } x = sub(x, blockSize[0]); } - auto castOp = unpackWithUnrealizedCastOp( - newOps, resTy, blockSize.asArrayRef(), loc, rewriter); + auto castOp = + addUnpackOp(newOps, resTy, blockSize.asArrayRef(), loc, rewriter); rewriter.replaceOp(op, castOp); return mlir::success(); } @@ -1549,911 +1543,26 @@ class RewriteSplatOp loc, newTy, op->getOperands(), op->getAttrs()); auto numOps = resTy.getNumElements() / newTy.getNumElements(); llvm::SmallVector newOps(numOps, newOp); - auto castOp = unpackWithUnrealizedCastOp( - newOps, resTy, blockSize.asArrayRef(), loc, rewriter); + auto castOp = + addUnpackOp(newOps, resTy, blockSize.asArrayRef(), loc, rewriter); rewriter.replaceOp(op, castOp); return mlir::success(); } }; - -// ====================== Old 4D blocking patterns ======================== -static xetile::TileUnpackOp -addUnpackOp(mlir::Value src, mlir::ConversionPatternRewriter &rewriter) { - auto srcTy = llvm::dyn_cast_if_present(src.getType()); - assert(srcTy && srcTy.getRank() == 4); - auto shape = srcTy.getShape(); - auto grids = shape.take_front(2); - auto innerBlocks = shape.take_back(2); - llvm::SmallVector unpackShape( - {grids[0] * innerBlocks[0], grids[1] * innerBlocks[1]}); - - auto unpackTy = mlir::VectorType::get(unpackShape, srcTy.getElementType()); - return rewriter.create( - src.getLoc(), unpackTy, src, - mlir::DenseI64ArrayAttr::get(src.getContext(), innerBlocks)); -} - -static mlir::Value addPackOp(mlir::Value src, - llvm::ArrayRef targetBlkSizes, - mlir::ConversionPatternRewriter &rewriter) { - auto srcTy = mlir::dyn_cast(src.getType()); - assert(srcTy && targetBlkSizes.size() == 2); - auto shape = srcTy.getShape(); - llvm::SmallVector packShape({shape[0] / targetBlkSizes[0], - shape[1] / targetBlkSizes[1], - targetBlkSizes[0], targetBlkSizes[1]}); - - auto packTy = mlir::VectorType::get(packShape, srcTy.getElementType()); - auto packOp = rewriter.create( - src.getLoc(), packTy, src, - mlir::DenseI64ArrayAttr::get(src.getContext(), targetBlkSizes)); - return packOp; -} - -/// OpConversionPatternWithAnalysis is a wrapper around OpConversionPattern -/// but takes an extra AnalysisT object as an argument, such that patterns -/// can leverage the analysis results. -template -class OpConversionPatternWithAnalysis - : public mlir::OpConversionPattern { -public: - using OpPatternRewriter = typename mlir::ConversionPatternRewriter; - - OpConversionPatternWithAnalysis(mlir::MLIRContext *context, - AnalysisT &analysis) - : mlir::OpConversionPattern(context), analysis(analysis) {} - -protected: - AnalysisT &analysis; -}; - -/// OpTraitConversionPatternWithAnalysis is a wrapper around -/// OpTraitConversionPattern but takes an extra AnalysisT object as an argument, -/// such that patterns can leverage the analysis results. -template